blob: 62610f89860ab02a2668716e7a9872b245eff526 [file] [log] [blame]
import collections
import dataclasses
import functools
import inspect
import typing
from typing import Any, Callable, Collection, Dict, Tuple, Union
import typing_extensions
import typing_inspect
from hamilton import htypes, node, registry
from hamilton.dev_utils import deprecation
from hamilton.function_modifiers import base
from hamilton.function_modifiers.dependencies import (
ParametrizedDependency,
ParametrizedDependencySource,
source,
value,
)
"""Decorators that enable DRY code by expanding one node into many."""
class parameterize(base.NodeExpander):
    """Decorator to use to create many functions.

    Expands a single function into n, each of which correspond to a function in which the
    parameter value is replaced either by:

    #. A specified literal value, denoted value('literal_value').
    #. The output from a specified upstream function (i.e. node), denoted
       source('upstream_function_name').

    Note that ``parameterize`` can take the place of ``@parameterize_sources`` or
    ``@parameterize_values`` decorators below. In fact, they delegate to this!

    Examples expressing different syntax:

    .. code-block:: python

        @parameterize(
            # tuple of assignments (consisting of literals/upstream specifications), and docstring.
            replace_no_parameters=({}, 'fn with no parameters replaced'),
        )
        def no_param_function() -> Any:
            ...

        @parameterize(
            # tuple of assignments (consisting of literals/upstream specifications), and docstring.
            replace_just_upstream_parameter=(
                {'upstream_source': source('foo_source')},
                'fn with upstream_parameter set to node foo'
            ),
        )
        def param_is_upstream_function(upstream_source: Any) -> Any:
            '''Doc string that can also be parameterized: {upstream_source}.'''
            ...

        @parameterize(
            replace_just_literal_parameter={'literal_parameter': value('bar')},
        )
        def param_is_literal_value(literal_parameter: Any) -> Any:
            '''Doc string that can also be parameterized: {literal_parameter}.'''
            ...

        @parameterize(
            replace_both_parameters={
                'upstream_parameter': source('foo_source'),
                'literal_parameter': value('bar')
            }
        )
        def concat(upstream_parameter: Any, literal_parameter: str) -> Any:
            '''Adding {literal_parameter} to {upstream_parameter} to create {output_name}.'''
            return upstream_parameter + literal_parameter

    You also have the capability to "group" parameters, which will combine them into a list.

    .. code-block:: python

        @parameterize(
            a_plus_b_plus_c={
                'to_concat' : group(source('a'), value('b'), source('c'))
            }
        )
        def concat(to_concat: List[str]) -> Any:
            '''Concatenating the grouped values to create {output_name}.'''
            return sum(to_concat, '')
    """

    RESERVED_KWARG = "output_name"
    # This is a kwarg that replaces it with the name of the function
    # Double underscore means it will not be provided as user-base kwargs
    # as hamilton is not OK with these output names
    # We need this as we need to know the name of the function
    # for the `@inject` usage but its not provided at
    # construction time, so we provide a placeholder
    PLACEHOLDER_PARAM_NAME = "__<function_name>"

    def __init__(
        self,
        **parametrization: Union[
            Dict[str, ParametrizedDependency],
            Tuple[Dict[str, ParametrizedDependency], str],
        ],
    ):
        """Decorator to use to create many functions.

        :param parametrization: `**kwargs` with one of two things:

            - a tuple of assignments (consisting of literals/upstream specifications), and docstring.
            - just assignments, in which case it parametrizes the existing docstring.
        :raises base.InvalidDecoratorException: if any assignment is not a ParametrizedDependency.
        """
        # Strip the optional docstring: only the assignment mapping is stored here.
        self.parameterization = {
            key: (spec[0] if isinstance(spec, tuple) else spec)
            for key, spec in parametrization.items()
        }
        bad_values = []
        for _assigned_output, mapping in self.parameterization.items():
            for _parameter, val in mapping.items():
                if not isinstance(val, ParametrizedDependency):
                    bad_values.append(val)
        if bad_values:
            raise base.InvalidDecoratorException(
                f"@parameterize must specify a dependency type -- either source() or value(). "
                f"The following are not allowed: {bad_values}."
            )
        # Docstrings supplied explicitly via the (assignments, docstring) tuple form.
        self.specified_docstrings = {
            key: spec[1] for key, spec in parametrization.items() if isinstance(spec, tuple)
        }

    def split_parameterizations(
        self, parameterizations: Dict[str, ParametrizedDependency]
    ) -> Dict[ParametrizedDependencySource, Dict[str, ParametrizedDependency]]:
        """Split parameterizations into groups keyed by dependency type.

        The result is a defaultdict, so looking up a dependency type that has no entries
        yields an empty dict.

        :param parameterizations: Passed into @parameterize
        :return: The parameterizations grouped by dependency type
        """
        out = collections.defaultdict(dict)
        for param_name, replacement in parameterizations.items():
            out[replacement.get_dependency_type()][param_name] = replacement
        return out

    def _get_grouped_list_name(self, index: int, arg_name: str):
        """Gets the name of the arg for a given index in a list of args, using grouped"""
        return f"__{arg_name}_{index}"

    def expand_node(
        self, node_: node.Node, config: Dict[str, Any], fn: Callable
    ) -> Collection[node.Node]:
        """Produces one node per parameterization, each wrapping the original callable.

        :param node_: Node to expand.
        :param config: Configuration (not read by this method).
        :param fn: The decorated function; used for docstring formatting.
        :return: The expanded nodes, one per entry in the parameterization.
        """
        nodes = []
        for (
            output_node,
            parametrization_with_optional_docstring,
        ) in self.parameterization.items():
            if output_node == parameterize.PLACEHOLDER_PARAM_NAME:
                output_node = node_.name
            if isinstance(parametrization_with_optional_docstring, tuple):
                # (assignments, docstring) form -- keep only the assignments, same as validate().
                parameterization = parametrization_with_optional_docstring[0]
            else:
                parameterization = parametrization_with_optional_docstring
            docstring = self.format_doc_string(fn, output_node)
            parameterization_splits = self.split_parameterizations(parameterization)
            upstream_dependencies = parameterization_splits[ParametrizedDependencySource.UPSTREAM]
            literal_dependencies = parameterization_splits[ParametrizedDependencySource.LITERAL]
            grouped_list_dependencies = parameterization_splits[
                ParametrizedDependencySource.GROUPED_LIST
            ]
            grouped_dict_dependencies = parameterization_splits[
                ParametrizedDependencySource.GROUPED_DICT
            ]

            def replacement_function(
                *args,
                upstream_dependencies=upstream_dependencies,
                literal_dependencies=literal_dependencies,
                grouped_list_dependencies=grouped_list_dependencies,
                grouped_dict_dependencies=grouped_dict_dependencies,
                former_inputs=list(node_.input_types.keys()),  # noqa
                **kwargs,
            ):
                """This function rewrites what is passed in kwargs to the right kwarg for the function.
                The passed in kwargs are all the dependencies of this node. Note that we actually have the "former inputs",
                which are what the node declares as its dependencies. So, we just have to loop through all of them to
                get the "new" value. This "new" value comes from the parameterization.

                Note that much of this code should *probably* live within the source/value/grouped functions, but
                it is here as we're not 100% sure about the abstraction.

                TODO -- think about how the grouped/source/literal functions should be able to grab the values from kwargs/args.
                Should be easy -- they should just have something like a "resolve(**kwargs)" function that they can call.
                """
                new_kwargs = {}
                for node_input in former_inputs:
                    if node_input in upstream_dependencies:
                        # If the node is specified by `source`, then we get the value from the kwargs
                        new_kwargs[node_input] = kwargs[upstream_dependencies[node_input].source]
                    elif node_input in literal_dependencies:
                        # If the node is specified by `value`, then we get the literal value (no need for kwargs)
                        new_kwargs[node_input] = literal_dependencies[node_input].value
                    elif node_input in grouped_list_dependencies:
                        # If the node is specified by `group`, then we get the list of values from the kwargs or the literal
                        new_kwargs[node_input] = []
                        for replacement in grouped_list_dependencies[node_input].sources:
                            resolved_value = (
                                kwargs[replacement.source]
                                if replacement.get_dependency_type()
                                == ParametrizedDependencySource.UPSTREAM
                                else replacement.value
                            )
                            new_kwargs[node_input].append(resolved_value)
                    elif node_input in grouped_dict_dependencies:
                        # If the node is specified by `group`, then we get the dict of values from the kwargs or the literal
                        new_kwargs[node_input] = {}
                        for dependency, replacement in grouped_dict_dependencies[
                            node_input
                        ].sources.items():
                            resolved_value = (
                                kwargs[replacement.source]
                                if replacement.get_dependency_type()
                                == ParametrizedDependencySource.UPSTREAM
                                else replacement.value
                            )
                            new_kwargs[node_input][dependency] = resolved_value
                    elif node_input in kwargs:
                        new_kwargs[node_input] = kwargs[node_input]
                    # This case is left blank for optional parameters. If we error here, we'll break
                    # the (supported) case of optionals. We do know whether its optional but for
                    # now the error will be clear enough
                return node_.callable(*args, **new_kwargs)

            new_input_types = {}
            grouped_dependencies = {
                **grouped_list_dependencies,
                **grouped_dict_dependencies,
            }
            for param, val in node_.input_types.items():
                if param in upstream_dependencies:
                    new_input_types[upstream_dependencies[param].source] = (
                        val  # We replace with the upstream_dependencies
                    )
                elif param in grouped_dependencies:
                    # These are the components of the individual sequence
                    # E.G. if the parameter is List[int], the individual type is just int
                    grouped_dependency_spec = grouped_dependencies[param]
                    sequence_component_type = grouped_dependency_spec.resolve_dependency_type(
                        val[0], param
                    )
                    unpacked_dependencies = (
                        grouped_dependency_spec.sources
                        if grouped_dependency_spec.get_dependency_type()
                        == ParametrizedDependencySource.GROUPED_LIST
                        else grouped_dependency_spec.sources.values()
                    )
                    for dep in unpacked_dependencies:
                        if dep.get_dependency_type() == ParametrizedDependencySource.UPSTREAM:
                            # TODO -- think through what happens if we have optional pieces...
                            # I think that we shouldn't allow it...
                            new_input_types[dep.source] = (
                                sequence_component_type,
                                val[1],
                            )
                elif param not in literal_dependencies:
                    new_input_types[param] = (
                        val  # We just use the standard one, nothing is getting replaced
                    )
            nodes.append(
                node_.copy_with(
                    name=output_node,
                    doc_string=docstring,  # TODO -- change docstring
                    callabl=functools.partial(
                        replacement_function,
                        **{parameter: val.value for parameter, val in literal_dependencies.items()},
                    ),
                    input_types=new_input_types,
                    include_refs=False,  # Include refs is here as this is earlier than compile time
                    # TODO -- figure out why this isn't getting replaced later...
                )
            )
        return nodes

    def validate(self, fn: Callable):
        """Checks that the parameterization is compatible with the decorated function.

        Verifies that the docstring can be formatted for every output, that the function does
        not use the reserved kwarg as a parameter, that every replaced parameter exists on the
        function, and that grouped parameters carry a suitable list/dict annotation.

        :param fn: Function to validate.
        :raises base.InvalidDecoratorException: if any of the checks above fails.
        """
        signature = inspect.signature(fn)
        func_param_names = set(signature.parameters.keys())
        try:
            for output_name, _mappings in self.parameterization.items():
                # TODO -- separate out into the two dependency-types
                if output_name == self.PLACEHOLDER_PARAM_NAME:
                    output_name = fn.__name__
                self.format_doc_string(fn, output_name)
        except KeyError as e:
            raise base.InvalidDecoratorException(
                f"Function docstring templating is incorrect. "
                f"Please fix up the docstring {fn.__module__}.{fn.__name__}."
            ) from e

        if self.RESERVED_KWARG in func_param_names:
            raise base.InvalidDecoratorException(
                f"Error function {fn.__module__}.{fn.__name__} cannot have '{self.RESERVED_KWARG}' "
                f"as a parameter it is reserved."
            )
        missing_parameters = set()
        for mapping in self.parameterization.values():
            for param_to_replace in mapping:
                if param_to_replace not in func_param_names:
                    missing_parameters.add(param_to_replace)
        if missing_parameters:
            # sorted() keeps the message deterministic (set iteration order is not).
            raise base.InvalidDecoratorException(
                f"Parametrization is invalid: the following parameters don't appear in the function itself: {', '.join(sorted(missing_parameters))}"
            )
        type_hints = typing.get_type_hints(fn)
        for _output_name, mapping in self.parameterization.items():
            # TODO -- look a the origin type and determine that its a sequence
            # We can just use the GroupedListDependency to do this
            invalid_types = []
            if isinstance(mapping, tuple):
                mapping = mapping[0]
            for param, replacement_value in mapping.items():
                param_annotation = type_hints[param]
                if typing_inspect.is_optional_type(param_annotation):
                    param_annotation = typing_inspect.get_args(param_annotation)[0]
                dependency_type = replacement_value.get_dependency_type()
                is_grouped_list = dependency_type == ParametrizedDependencySource.GROUPED_LIST
                is_grouped_dict = dependency_type == ParametrizedDependencySource.GROUPED_DICT
                if not (is_grouped_list or is_grouped_dict):
                    continue  # only grouped dependencies constrain the annotation
                if not typing_inspect.is_generic_type(param_annotation):
                    invalid_types.append((param, param_annotation))
                    continue
                origin = typing_inspect.get_origin(param_annotation)
                # 3.9 + this works
                # 3.8 they changed it, so it gives false positives, but we're OK not fixing
                # for older versions of python
                args = typing_inspect.get_args(param_annotation)
                if is_grouped_list:
                    is_valid = origin == list and len(args) == 1
                else:
                    is_valid = origin == dict and len(args) == 2 and args[0] == str
                if not is_valid:
                    # Each offending parameter is reported once.
                    invalid_types.append((param, param_annotation))
            if invalid_types:
                raise base.InvalidDecoratorException(
                    f"Validation for fn: {fn.__qualname__} All parameters with a group() parameterization must be annotated as a list "
                    f"(or as a dict with str keys for keyword groups): "
                    f"the following are not: {', '.join([f'{param} ({annotation})' for param, annotation in invalid_types])}"
                )

    def format_doc_string(self, fn: Callable, output_name: str) -> typing.Optional[str]:
        """Helper function to format a function documentation string.

        If a docstring was explicitly specified for ``output_name`` it is returned as is.
        Otherwise the docstring of ``fn`` is used as a template; template fields that are not
        part of the parameterization are replaced by their own name.

        :param fn: the function whose docstring is used as the template
        :param output_name: the output name of the function
        :return: formatted string, or None if the function has no docstring
        :raises: KeyError if there is no parameterization registered for ``output_name``.
        """

        class IdentityDict(dict):
            # quick hack to allow for formatting of missing parameters
            def __missing__(self, key):
                return key

        if output_name in self.specified_docstrings:
            return self.specified_docstrings[output_name]
        doc = fn.__doc__
        if doc is None:
            return None
        parameterizations = self.parameterization.copy()
        if self.PLACEHOLDER_PARAM_NAME in parameterizations:
            # The placeholder stands in for the function's own name.
            parameterizations[fn.__name__] = parameterizations.pop(self.PLACEHOLDER_PARAM_NAME)
        parametrization = parameterizations[output_name]
        upstream_dependencies = {
            parameter: replacement.source
            for parameter, replacement in parametrization.items()
            if replacement.get_dependency_type() == ParametrizedDependencySource.UPSTREAM
        }
        literal_dependencies = {
            parameter: replacement.value
            for parameter, replacement in parametrization.items()
            if replacement.get_dependency_type() == ParametrizedDependencySource.LITERAL
        }
        return doc.format_map(
            IdentityDict(
                **{self.RESERVED_KWARG: output_name},
                **{**upstream_dependencies, **literal_dependencies},
            )
        )
class parameterize_values(parameterize):
    """Expands a single function into n, each of which corresponds to a function in which the
    parameter value is replaced by that `specific value`.

    .. code-block:: python

        import pandas as pd
        from hamilton.function_modifiers import parameterize_values
        import internal_package_with_logic

        ONE_OFF_DATES = {
            #output name        # doc string               # input value to function
            ('D_ELECTION_2016', 'US Election 2016 Dummy'): '2016-11-12',
            ('SOME_OUTPUT_NAME', 'Doc string for this thing'): 'value to pass to function',
        }
        # parameter matches the name of the argument in the function below
        @parameterize_values(parameter='one_off_date', assigned_output=ONE_OFF_DATES)
        def create_one_off_dates(date_index: pd.Series, one_off_date: str) -> pd.Series:
            '''Given a date index, produces a series where a 1 is placed at the date index that would contain that event.'''
            one_off_dates = internal_package_with_logic.get_business_week(one_off_date)
            return internal_package_with_logic.bool_to_int(date_index.isin([one_off_dates]))
    """

    def __init__(self, parameter: str, assigned_output: Dict[Tuple[str, str], Any]):
        """Constructor for a modifier that expands a single function into n, each of which
        corresponds to a function in which the parameter value is replaced by that *specific value*.

        :param parameter: Parameter to expand on.
        :param assigned_output: A map of tuple of [output name, documentation] to the literal
            value passed for ``parameter``.
        :raises base.InvalidDecoratorException: if a key of ``assigned_output`` is not a tuple.
        """
        for node_ in assigned_output:
            # Use the builtin for the runtime check; typing.Tuple is an annotation alias.
            if not isinstance(node_, tuple):
                raise base.InvalidDecoratorException(
                    f"assigned_output key is incorrect: {node_}. The parameterized decorator needs a dict of "
                    "[name, doc string] -> value to function."
                )
        # Delegate to @parameterize: every output gets ({parameter: value(...)}, docstring).
        super(parameterize_values, self).__init__(
            **{
                output: ({parameter: value(literal_value)}, documentation)
                for (output, documentation), literal_value in assigned_output.items()
            }
        )
# Deprecated spelling kept for backwards compatibility: the decorator below points users at
# `parameterize_values` (see `use_this`), and the class body adds nothing to its parent.
# NOTE(review): the explanation string mentions "@parameterize_inputs", while the class defined
# in this file is `parameterize_sources` -- confirm which name the message should advertise.
@deprecation.deprecated(
    warn_starting=(1, 10, 0),
    fail_starting=(2, 0, 0),
    use_this=parameterize_values,
    explanation="We now support three parametrize decorators. @parameterize, @parameterize_values, and @parameterize_inputs",
    migration_guide="https://github.com/dagworks-inc/hamilton/blob/main/decorators.md#migrating-parameterized",
)
class parametrized(parameterize_values):
    # Pure alias of parameterize_values; no behavior of its own.
    pass
class parameterize_sources(parameterize):
    """Expands a single function into `n`, each of which corresponds to a function in which the
    parameters specified are mapped to the specified inputs. Note this decorator and
    ``@parameterize_values`` are quite similar, except that the input here is another DAG
    node(s), i.e. column/input, rather than a specific scalar/static value.

    .. code-block:: python

        import pandas as pd
        from hamilton.function_modifiers import parameterize_sources

        @parameterize_sources(
            D_ELECTION_2016_shifted=dict(one_off_date='D_ELECTION_2016'),
            SOME_OUTPUT_NAME=dict(one_off_date='SOME_INPUT_NAME')
        )
        def date_shifter(one_off_date: pd.Series) -> pd.Series:
            '''{one_off_date} shifted by 1 to create {output_name}'''
            return one_off_date.shift(1)
    """

    def __init__(self, **parameterization: Dict[str, str]):
        """Builds a modifier that creates one output per keyword argument, each replacing a
        subset of the function's parameters with specific upstream nodes.

        Every keyword argument has the structure:

            OUTPUT_NAME = mapping of function argument -> name of the input that feeds it.

        The documentation for each output is taken from the decorated function. Its docstring
        can be templatized with the function's parameter names and the reserved value
        `output_name`; those are replaced with the corresponding values from the
        parameterization.

        :param parameterization: kwargs of output name to dict of parameter mappings.
        :raises ValueError: if no outputs are given, or if any output has an empty mapping.
        """
        # The raw, user-supplied mapping is kept as-is on the instance.
        self.parametrization = parameterization
        if not parameterization:
            raise ValueError("Cannot pass empty/None dictionary to parameterize_sources")
        for output_name, argument_mapping in parameterization.items():
            if argument_mapping:
                continue
            raise ValueError(
                f"Error, {output_name} has a none/empty dictionary mapping. Please fill it."
            )

        # Wrap every upstream node name in source() and hand over to @parameterize.
        as_sources = {}
        for output_name, argument_mapping in parameterization.items():
            as_sources[output_name] = {
                argument: source(upstream_name)
                for argument, upstream_name in argument_mapping.items()
            }
        super(parameterize_sources, self).__init__(**as_sources)
@deprecation.deprecated(
    warn_starting=(1, 10, 0),
    fail_starting=(2, 0, 0),
    use_this=parameterize_sources,
    explanation="We now support three parametrize decorators. @parameterize, "
    "@parameterize_values, and @parameterize_inputs",
    migration_guide="https://github.com/dagworks-inc/hamilton/blob/main/decorators.md#migrating"
    "-parameterized",
)
class parametrized_input(parameterize):
    def __init__(self, parameter: str, variable_inputs: Dict[str, Tuple[str, str]]):
        """Constructor for a modifier that expands a single function into n, each of which
        corresponds to the specified parameter replaced by a *specific input column*.

        Note this decorator and `@parametrized` are quite similar, except that the input here is another DAG node,
        i.e. column, rather than some specific value.

        The `parameterized_input` allows you keep your code DRY by reusing the same function but replace the inputs
        to create multiple corresponding distinct outputs. The _parameter_ key word argument has to match one of the
        arguments in the function. The rest of the arguments are pulled from items inside the DAG.
        The _variable_inputs_ key word argument takes in a dictionary of
        input_column -> tuple(Output Name, Documentation string).

        :param parameter: Parameter to expand on.
        :param variable_inputs: A map of input column name to a tuple of
            [output name, documentation].
        :raises base.InvalidDecoratorException: if a value of ``variable_inputs`` is not a tuple.
        """
        for val in variable_inputs.values():
            if not isinstance(val, tuple):
                # Report the offending value itself (previously the `node` module was
                # interpolated by mistake).
                raise base.InvalidDecoratorException(
                    f"variable_inputs value is incorrect: {val}. The parameterized decorator needs a dict of "
                    "input column -> [name, description] to function."
                )
        # Delegate to @parameterize: every output gets ({parameter: source(...)}, docstring).
        super(parametrized_input, self).__init__(
            **{
                output: ({parameter: source(input_column)}, documentation)
                for input_column, (output, documentation) in variable_inputs.items()
            }
        )
# Deprecated name kept for backwards compatibility: the decorator below points users at
# `parameterize_sources` (see `use_this`), and the class body adds nothing to its parent.
@deprecation.deprecated(
    warn_starting=(1, 10, 0),
    fail_starting=(2, 0, 0),
    use_this=parameterize_sources,
    explanation="We now support three parametrize decorators. @parameterize, @parameterize_values, and @parameterize_inputs",
    migration_guide="https://github.com/dagworks-inc/hamilton/blob/main/decorators.md#migrating-parameterized",
)
class parameterized_inputs(parameterize_sources):
    # Pure alias of parameterize_sources; no behavior of its own.
    pass
class extract_columns(base.SingleNodeNodeTransformer):
    """Expands a dataframe-producing function into the dataframe node plus one node per column."""

    def __init__(self, *columns: Union[Tuple[str, str], str], fill_with: Any = None):
        """Constructor for a modifier that expands a single function into the following nodes:

        - n functions, each of which take in the original dataframe and output a specific column
        - 1 function that outputs the original dataframe

        :param columns: Columns to extract, that can be a list of tuples of (name, documentation) or just names.
        :param fill_with: If you want to extract a column that doesn't exist, do you want to fill it with a default
            value? Or do you want to error out? Leave empty/None to error out, set fill_value to dynamically create a
            column.
        :raises base.InvalidDecoratorException: if no columns are given, or if the first argument
            is a list (i.e. the caller forgot to unpack it with `*`).
        """
        super(extract_columns, self).__init__()
        if not columns:
            raise base.InvalidDecoratorException(
                "Error empty arguments passed to extract_columns decorator."
            )
        elif isinstance(columns[0], list):
            raise base.InvalidDecoratorException(
                "Error list passed in. Please `*` in front of it to expand it."
            )
        self.columns = columns
        self.fill_with = fill_with

    @staticmethod
    def validate_return_type(fn: Callable):
        """Validates that the return type of the function is a dataframe type known to the registry.

        :param fn: Function to validate
        :raises base.InvalidDecoratorException: if the registry has no column type for the
            function's return annotation.
        """
        output_type = typing.get_type_hints(fn).get("return")
        try:
            registry.get_column_type_from_df_type(output_type)
        except NotImplementedError as e:
            raise base.InvalidDecoratorException(
                # TODO: capture was dataframe libraries are supported and print here.
                f"Error {fn} does not output a type we know about. Is it a dataframe type we "
                f"support? "
            ) from e

    def validate(self, fn: Callable):
        """A function is invalid if it does not output a dataframe.

        :param fn: Function to validate.
        :raises: InvalidDecoratorException If the function does not output a Dataframe
        """
        extract_columns.validate_return_type(fn)

    def _fill_missing_columns(self, df_generated: Any) -> Any:
        """Adds every requested-but-absent column to the dataframe, filled with ``fill_with``.

        Does nothing when ``fill_with`` is None. Column specs given as (name, documentation)
        tuples are reduced to their name first, so the lookup and fill use the real column name.

        :param df_generated: Dataframe produced by the decorated function.
        :return: The same dataframe object that was passed in.
        """
        if self.fill_with is None:
            return df_generated
        for col in self.columns:
            column_name = col[0] if isinstance(col, tuple) else col
            if column_name not in df_generated:
                registry.fill_with_scalar(df_generated, column_name, self.fill_with)
                assert column_name in df_generated
        return df_generated

    def transform_node(
        self, node_: node.Node, config: Dict[str, Any], fn: Callable
    ) -> Collection[node.Node]:
        """For each column to extract, output a node that extracts that column. Also, output the original dataframe
        generator.

        :param node_: Node to transform
        :param config: Config to use
        :param fn: Function to extract columns from. Must output a dataframe.
        :return: A collection of nodes --
            one for the original dataframe generator, and another for each column to extract.
        """
        fn = node_.callable
        base_doc = node_.documentation
        # Checked once: the wrappers must be coroutines iff the wrapped callable is one.
        is_async = inspect.iscoroutinefunction(fn)

        if is_async:

            async def df_generator(*args, **kwargs) -> Any:
                return self._fill_missing_columns(await fn(*args, **kwargs))

        else:

            def df_generator(*args, **kwargs) -> Any:
                return self._fill_missing_columns(fn(*args, **kwargs))

        def extract_from_inputs(column_to_extract: str, inputs: Dict[str, Any]) -> Any:
            """Pulls one column out of the dataframe passed in under this node's name."""
            df = inputs[node_.name]
            if column_to_extract not in df:
                raise base.InvalidDecoratorException(
                    f"No such column: {column_to_extract} produced by {node_.name}. "
                    f"It only produced {str(df.columns)}"
                )
            return registry.get_column(df, column_to_extract)

        output_nodes = [node_.copy_with(callabl=df_generator)]
        output_type = node_.type
        series_type = registry.get_column_type_from_df_type(output_type)
        for column in self.columns:
            doc_string = base_doc  # default doc string of base function.
            if isinstance(column, tuple):  # Expand tuple into constituents
                column, doc_string = column

            # The column is bound as a default argument to avoid problems with closures.
            if is_async:

                async def extractor_fn(column_to_extract: str = column, **kwargs) -> Any:
                    return extract_from_inputs(column_to_extract, kwargs)

            else:

                def extractor_fn(column_to_extract: str = column, **kwargs) -> Any:
                    return extract_from_inputs(column_to_extract, kwargs)

            output_nodes.append(
                node.Node(
                    column,
                    series_type,
                    doc_string,
                    extractor_fn,
                    input_types={node_.name: output_type},
                    tags=node_.tags.copy(),
                )
            )
        return output_nodes
def _validate_extract_fields(fields: dict):
"""Validates the fields dict for extract field.
Rules are:
- All keys must be strings
- All values must be types
- It must not be empty
:param fields: Constructor argument to extract_fields
:raises InvalidDecoratorException: If the fields dict is invalid.
"""
if not fields:
raise base.InvalidDecoratorException(
"Error an empty dict, or no dict, passed to extract_fields decorator."
)
elif not isinstance(fields, dict):
raise base.InvalidDecoratorException(f"Error, please pass in a dict, not {type(fields)}")
else:
errors = []
for field, field_type in fields.items():
if not isinstance(field, str):
errors.append(f"{field} is not a string. All keys must be strings.")
# second condition needed because isinstance(Any, type) == False for Python <3.11
if not (
isinstance(field_type, type)
or field_type is Any
or typing_inspect.is_generic_type(field_type)
or typing_inspect.is_union_type(field_type)
):
errors.append(f"{field} does not declare a type. Instead it passes {field_type}.")
if errors:
raise base.InvalidDecoratorException(
f"Error, found these {errors}. " f"Please pass in a dict of string to types. "
)
class extract_fields(base.SingleNodeNodeTransformer):
    """Extracts fields from a dictionary of output."""

    def __init__(self, fields: dict = None, fill_with: Any = None):
        """Constructor for a modifier that expands a single function into the following nodes:

        - n functions, each of which take in the original dict and output a specific field
        - 1 function that outputs the original dict

        :param fields: Fields to extract. A dict of 'field_name' -> 'field_type'.
        :param fill_with: Value stored for a requested field that the produced dict lacks.
            Leave as None to error out on a missing field instead.
        """
        super().__init__()
        self.fill_with = fill_with
        self.fields = fields

    def validate(self, fn: Callable):
        """A function is invalid if it is not annotated with a dict, typing.Dict or TypedDict
        return type.

        For a TypedDict return type, ``fields`` defaults to all of its declared fields when
        none were given; otherwise the given fields must match a subset of the declared ones.

        :param fn: Function to validate.
        :raises: InvalidDecoratorException If the return annotation is not a dict type, or the
            requested fields are invalid.
        """
        output_type = typing.get_type_hints(fn).get("return")

        # Case 1: subscripted generic, e.g. Dict[str, int].
        if typing_inspect.is_generic_type(output_type):
            origin = typing_inspect.get_origin(output_type)
            if not (origin == dict or origin == Dict):
                raise base.InvalidDecoratorException(
                    f"For extracting fields, output type must be a dict or typing.Dict, not: {output_type}"
                )
            _validate_extract_fields(self.fields)
            return

        # Case 2: plain `dict`.
        if output_type == dict:
            _validate_extract_fields(self.fields)
            return

        # Case 3: TypedDict.
        if typing_extensions.is_typeddict(output_type):
            declared_fields = typing.get_type_hints(output_type)
            if self.fields is None:
                self.fields = declared_fields
            else:
                # check that fields is a subset of TypedDict that is defined
                for requested_name, requested_type in self.fields.items():
                    declared_type = declared_fields.get(requested_name, None)
                    compatible = declared_type == requested_type or (
                        declared_type is not None
                        and htypes.custom_subclass_check(requested_type, declared_type)
                    )
                    if not compatible:
                        raise base.InvalidDecoratorException(
                            f"Error {self.fields} did not match a subset of the TypedDict annotation's fields {declared_fields}."
                        )
            _validate_extract_fields(self.fields)
            return

        raise base.InvalidDecoratorException(
            f"For extracting fields, output type must be a dict or typing.Dict, not: {output_type}"
        )

    def transform_node(
        self, node_: node.Node, config: Dict[str, Any], fn: Callable
    ) -> Collection[node.Node]:
        """For each field to extract, output a node that extracts that field. Also, output the
        original dict generator.

        :param node_: Node to transform.
        :param config: Configuration (not read by this method).
        :param fn: Ignored; the node's own callable is used instead.
        :return: A collection of nodes --
            one for the original dict generator, and another for each field to extract.
        """
        wrapped = node_.callable
        default_doc = node_.documentation
        wrapped_is_async = inspect.iscoroutinefunction(wrapped)

        def with_defaults(produced):
            # Only fill when a fill value was configured; absent fields are otherwise left out.
            if self.fill_with is not None:
                for name in self.fields:
                    if name not in produced:
                        produced[name] = self.fill_with
            return produced

        if wrapped_is_async:

            async def dict_generator(*args, **kwargs):
                return with_defaults(await wrapped(*args, **kwargs))

        else:

            def dict_generator(*args, **kwargs):
                return with_defaults(wrapped(*args, **kwargs))

        def pick(field_to_extract, inputs):
            # The upstream dict arrives keyed by the original node's name.
            produced = inputs[node_.name]
            if field_to_extract not in produced:
                raise base.InvalidDecoratorException(
                    f"No such field: {field_to_extract} produced by {node_.name}. "
                    f"It only produced {list(produced.keys())}"
                )
            return produced[field_to_extract]

        output_nodes = [node_.copy_with(callabl=dict_generator)]
        for field, field_type in self.fields.items():
            # The field is bound as a default argument to avoid problems with closures.
            if wrapped_is_async:

                async def extractor_fn(field_to_extract: str = field, **kwargs) -> field_type:
                    return pick(field_to_extract, kwargs)

            else:

                def extractor_fn(field_to_extract: str = field, **kwargs) -> field_type:
                    return pick(field_to_extract, kwargs)

            output_nodes.append(
                node.Node(
                    field,
                    field_type,
                    default_doc,  # every field node uses the doc string of the base function
                    extractor_fn,
                    input_types={node_.name: dict},
                    tags=node_.tags.copy(),
                )
            )
        return output_nodes
@dataclasses.dataclass
class ParameterizedExtract:
    """Dataclass to hold inputs for @parameterize and @parameterize_extract_columns.

    One instance pairs a set of output names with the parameterization used to produce them.

    :param outputs: A tuple of strings, each of which is the name of an output.
    :param input_mapping: A dictionary of string to ParametrizedDependency. The string is the name of the python
        parameter of the decorated function, and the value is a "source"/"value" which will be passed as input for that
        parameter to the function.
    """

    # Names of the outputs produced for this parameterization.
    outputs: Tuple[str, ...]
    # Function parameter name -> source()/value() dependency supplied for that parameter.
    input_mapping: Dict[str, ParametrizedDependency]
class parameterize_extract_columns(base.NodeExpander):
    """`@parameterize_extract_columns` gives you the power of both `@extract_columns` and
    `@parameterize` in one decorator.

    It takes in a list of `ParameterizedExtract` objects, each of which is composed of:

    1. A list of columns to extract, and
    2. A parameterization that gets used

    In the following case, we produce four columns, two for each parameterization:

    .. code-block:: python

        import pandas as pd
        from function_modifiers import parameterize_extract_columns, ParameterizedExtract, source, value

        @parameterize_extract_columns(
            ParameterizedExtract(
                ("outseries1a", "outseries2a"),
                {"input1": source("inseries1a"), "input2": source("inseries1b"), "input3": value(10)},
            ),
            ParameterizedExtract(
                ("outseries1b", "outseries2b"),
                {"input1": source("inseries2a"), "input2": source("inseries2b"), "input3": value(100)},
            ),
        )
        def fn(input1: pd.Series, input2: pd.Series, input3: float) -> pd.DataFrame:
            return pd.concat([input1 * input2 * input3, input1 + input2 + input3], axis=1)
    """

    def __init__(self, *extract_config: ParameterizedExtract, reassign_columns: bool = True):
        """Initializes a `parameterized_extract` decorator. Note this currently works for series,
        but the plan is to extend it to fields as well...

        :param extract_config: A configuration consisting of a list of ParameterizedExtract
            classes. These contain the information of a `@parameterized` and `@extract...`
            together.
        :param reassign_columns: Whether we want to reassign the columns as part of the function.
            When True, the columns of the frame returned by the decorated function are
            overwritten with the `outputs` of the corresponding ParameterizedExtract. When
            False, the decorated function is used as-is.
        """
        self.extract_config = extract_config
        self.reassign_columns = reassign_columns

    @staticmethod
    def _with_reassigned_columns(fn: Callable, output_columns: Tuple[str, ...]) -> Callable:
        """Wraps `fn` so the columns of the frame it returns are set to `output_columns`.

        :param fn: Function whose result has an assignable `columns` attribute.
        :param output_columns: Column names to assign to the result.
        :return: A wrapper that carries `fn`'s metadata (via `functools.wraps`).
        """

        # `_output_columns` is bound as a default so each wrapper keeps its own
        # column names instead of late-binding to a shared loop variable.
        @functools.wraps(fn)
        def wrapper_fn(*args, _output_columns=output_columns, **kwargs):
            df_out = fn(*args, **kwargs)
            df_out.columns = _output_columns
            return df_out

        return wrapper_fn

    def expand_node(
        self, node_: node.Node, config: Dict[str, Any], fn: Callable
    ) -> Collection[node.Node]:
        """Expands a node into multiple, given the extract_config passed to
        parameterize_extract_columns. Goes through all parameterizations,
        creates an extract_columns node for each, then delegates to that.
        Note this calls out to `@parameterize` and `@extract_columns` rather
        than reimplementing the logic.

        :param node_: Node to expand
        :param config: Config to use to expand
        :param fn: Original function
        :return: The nodes produced by this decorator.
        """
        output_nodes = []
        for i, parameterization in enumerate(self.extract_config):
            if self.reassign_columns:
                fn_to_call = self._with_reassigned_columns(fn, parameterization.outputs)
            else:
                fn_to_call = fn
            # The node's callable and the `fn` handed to `@parameterize` must be
            # the same object. Previously the node always carried the
            # column-reassigning wrapper, so `reassign_columns=False` was ignored
            # whenever the downstream decorator invoked the node's callable.
            new_node = node_.copy_with(callabl=fn_to_call)
            # We have to rename the underlying function so that we do not
            # get naming collisions. Using __ is cleaner than using a uuid
            # as it is easier to read/manage and naturally makes sense.
            parameterization_decorator = parameterize(
                **{node_.name + f"__{i}": parameterization.input_mapping}
            )
            # Exactly one parameterization was supplied, so exactly one node is expected back.
            (parameterized_node,) = parameterization_decorator.expand_node(
                new_node, config, fn_to_call
            )
            extract_columns_decorator = extract_columns(*parameterization.outputs)
            output_nodes.extend(
                extract_columns_decorator.transform_node(
                    parameterized_node, config, parameterized_node.callable
                )
            )
        return output_nodes

    def validate(self, fn: Callable):
        """Validates the decorated function by delegating to `extract_columns.validate_return_type`.

        :param fn: Function to validate.
        """
        extract_columns.validate_return_type(fn)
class inject(parameterize):
    """@inject lets you swap out parameters of a function for values you pass in.

    Conceptually it is a `@parameterize` call with exactly one parameterization,
    whose resulting output carries the name of the decorated function. For example:

    .. code-block:: python

        import pandas as pd
        from function_modifiers import inject, source, value, group

        @inject(nums=group(source('a'), value(10), source('b'), value(2)))
        def a_plus_10_plus_b_plus_2(nums: List[int]) -> int:
            return sum(nums)

    This would be equivalent to:

    .. code-block:: python

        @parameterize(
            a_plus_10_plus_b_plus_2={
                'nums': group(source('a'), value(10), source('b'), value(2))
        })
        def sum_numbers(nums: List[int]) -> int:
            return sum(nums)

    Something to note -- we currently do not support the case in which the same parameter is
    utilized multiple times as an injection. E.G. two lists, a list and a dict, two sources,
    etc... This is considered undefined behavior, and should be avoided.
    """

    def __init__(self, **key_mapping: ParametrizedDependency):
        """Builds an @inject decorator from the supplied key_mapping.

        :param key_mapping: A dictionary of string to dependency spec.
            This is the same as the input mapping in `@parameterize`.
        """
        # A single parameterization keyed by the placeholder name defined on `parameterize`.
        single_parameterization = {parameterize.PLACEHOLDER_PARAM_NAME: key_mapping}
        super().__init__(**single_parameterization)