# blob: ed5b3463e276bc1d90907e81d54eba55341a1dc7 [file] [log] [blame]
import sys
from typing import Any, Dict, List, Optional, Type
import numpy as np
import pandas as pd
import pytest
import hamilton.function_modifiers
from hamilton import function_modifiers, node
from hamilton.function_modifiers import base, expanders
from hamilton.function_modifiers.dependencies import (
GroupedDependency,
GroupedDictDependency,
GroupedListDependency,
group,
source,
value,
)
from hamilton.htypes import Collect, Parallelizable
from hamilton.node import DependencyType
def test_parametrized_invalid_params():
    """parameterize_values fails validation when the function lacks the target parameter."""
    decorator = function_modifiers.parameterize_values(
        parameter="non_existant",
        assigned_output={("invalid_node_name", "invalid_doc"): "invalid_value"},
    )

    def no_param_node():
        pass

    def wrong_param_node(valid_value):
        pass

    # Neither function declares a parameter called "non_existant", so both
    # the parameterless one and the mis-named one must be rejected.
    for candidate in (no_param_node, wrong_param_node):
        with pytest.raises(base.InvalidDecoratorException):
            decorator.validate(candidate)
def test_parametrized_single_param_breaks_without_docs():
    """Constructing parameterize_values with a bare node name as key (no doc) is rejected."""
    undocumented_output = {"only_node_name": "only_value"}
    with pytest.raises(base.InvalidDecoratorException):
        function_modifiers.parameterize_values(
            parameter="parameter",
            assigned_output=undocumented_output,
        )
def test_parametrized_single_param():
    """A single (name, doc) -> value entry expands into exactly one node with the value bound."""
    decorator = function_modifiers.parameterize_values(
        parameter="parameter",
        assigned_output={("only_node_name", "only_doc"): "only_value"},
    )

    def identity(parameter: Any) -> Any:
        return parameter

    expanded = decorator.expand_node(node.Node.from_fn(identity), {}, identity)
    assert len(expanded) == 1

    only_node = expanded[0]
    assert only_node.name == "only_node_name"
    assert only_node.type == Any
    assert only_node.documentation == "only_doc"
    # The parameter is pre-bound, so the node callable takes no arguments.
    result = only_node.callable()
    assert result == "only_value"
def test_parametrized_single_param_expanded():
    """Two assigned outputs expand into two nodes, each with its own doc and bound value."""
    decorator = function_modifiers.parameterize_values(
        parameter="parameter",
        assigned_output={
            ("node_name_1", "doc1"): "value_1",
            ("node_value_2", "doc2"): "value_2",
        },
    )

    def identity(parameter: Any) -> Any:
        return parameter

    expanded = decorator.expand_node(node.Node.from_fn(identity), {}, identity)
    assert len(expanded) == 2

    first_node, second_node = expanded[0], expanded[1]
    first_result = first_node.callable()
    second_result = second_node.callable()
    assert first_node.documentation == "doc1"
    assert second_node.documentation == "doc2"
    assert first_result == "value_1"
    assert second_result == "value_2"
def test_parametrized_with_multiple_params():
    """Only the parameterized argument is bound; the remaining argument stays a call-time input."""
    decorator = function_modifiers.parameterize_values(
        parameter="parameter",
        assigned_output={
            ("node_name_1", "doc1"): "value_1",
            ("node_value_2", "doc2"): "value_2",
        },
    )

    def identity(parameter: Any, static: Any) -> Any:
        return parameter, static

    expanded = decorator.expand_node(node.Node.from_fn(identity), {}, identity)
    assert len(expanded) == 2

    # "static" is still supplied by the caller for each expanded node.
    first_result = expanded[0].callable(static="static_param")
    second_result = expanded[1].callable(static="static_param")
    assert first_result == ("value_1", "static_param")
    assert second_result == ("value_2", "static_param")
def test_parametrized_input():
    """parametrized_input yields one node per variable input, each keeping the static input."""
    decorator = function_modifiers.parametrized_input(
        parameter="parameter",
        variable_inputs={
            "input_1": ("test_1", "Function with first column as input"),
            "input_2": ("test_2", "Function with second column as input"),
        },
    )

    def identity(parameter: Any, static: Any) -> Any:
        return parameter, static

    expanded = decorator.expand_node(node.Node.from_fn(identity), {}, identity)
    assert len(expanded) == 2

    # Sort by name so the assertions do not depend on expansion order.
    ordered = sorted(expanded, key=lambda n: n.name)
    assert [n.name for n in ordered] == ["test_1", "test_2"]
    assert set(ordered[0].input_types.keys()) == {"static", "input_1"}
    assert set(ordered[1].input_types.keys()) == {"static", "input_2"}
def test_parametrize_sources_validate_param_name():
    """Tests validate function of parameterize_sources capturing bad param name usage.

    ``parameterfoo`` is not a parameter of ``identity``, so ``validate`` is
    expected to raise ``InvalidDecoratorException``.
    """
    # Output names are supplied as keyword arguments, matching the other
    # parameterize_sources tests in this module. Passing the mapping under a
    # literal ``parameterization=`` keyword presumably makes that keyword the
    # output name, so "parameterfoo" would never be the parameter validated.
    decorator = function_modifiers.parameterize_sources(
        **{
            "test_1": dict(parameterfoo="input_1"),
        }
    )

    def identity(parameter1: str, parameter2: str, static: str) -> str:
        """Function with {parameter1} as first input"""
        return parameter1 + parameter2 + static

    with pytest.raises(base.InvalidDecoratorException):
        decorator.validate(identity)
def test_parametrized_inputs_validate_reserved_param():
    """Validation rejects a function that declares the ``output_name`` parameter."""
    parameterization = {"test_1": dict(parameter2="input_1")}
    decorator = function_modifiers.parameterize_sources(**parameterization)

    def identity(output_name: str, parameter2: str, static: str) -> str:
        """Function with {parameter2} as second input"""
        return output_name + parameter2 + static

    with pytest.raises(base.InvalidDecoratorException):
        decorator.validate(identity)
def test_parametrized_inputs_validate_bad_doc_string():
    """Validation rejects a function whose docstring template is not accepted by the decorator."""
    parameterization = {"test_1": dict(parameter2="input_1")}
    decorator = function_modifiers.parameterize_sources(**parameterization)

    def identity(output_name: str, parameter2: str, static: str) -> str:
        """Function with {foo} as second input"""
        return output_name + parameter2 + static

    with pytest.raises(base.InvalidDecoratorException):
        decorator.validate(identity)
def test_parametrized_inputs():
    """parameterize_sources rewires parameters to upstream nodes and formats the docstring."""
    decorator = function_modifiers.parameterize_sources(
        test_1=dict(parameter1="input_1", parameter2="input_2"),
        test_2=dict(parameter1="input_2", parameter2="input_1"),
    )

    def identity(parameter1: str, parameter2: str, static: str) -> str:
        """Function with {parameter1} as first input"""
        return parameter1 + parameter2 + static

    expanded = decorator.expand_node(node.Node.from_fn(identity), {}, identity)
    assert len(expanded) == 2

    # Sort by name so the assertions do not depend on expansion order.
    ordered = sorted(expanded, key=lambda n: n.name)
    assert [n.name for n in ordered] == ["test_1", "test_2"]
    straight, swapped = ordered[0], ordered[1]

    assert set(straight.input_types.keys()) == {"static", "input_1", "input_2"}
    assert straight.documentation == "Function with input_1 as first input"
    assert set(swapped.input_types.keys()) == {"static", "input_1", "input_2"}
    assert swapped.documentation == "Function with input_2 as first input"

    # Both nodes receive the same upstream values; only the wiring differs.
    upstream_values = {"input_1": "1", "input_2": "2", "static": "3"}
    result1 = straight.callable(**upstream_values)
    assert result1 == "123"
    result2 = swapped.callable(**upstream_values)
    assert result2 == "213"
def test_invalid_column_extractor():
    """extract_columns validation rejects a function annotated to return ``int``."""
    decorator = function_modifiers.extract_columns("dummy_column")

    def no_param_node() -> int:
        pass

    with pytest.raises(base.InvalidDecoratorException):
        decorator.validate(no_param_node)
def test_extract_columns_invalid_passing_list_to_column_extractor():
    """Passing the columns as one list argument (instead of varargs) is rejected."""
    columns_as_list = ["a", "b", "c"]
    with pytest.raises(base.InvalidDecoratorException):
        function_modifiers.extract_columns(columns_as_list)
def test_extract_columns_empty_args():
    """Calling extract_columns without any column raises InvalidDecoratorException."""
    with pytest.raises(base.InvalidDecoratorException):
        function_modifiers.extract_columns()
def test_extract_columns_happy():
    """Plain names and (name, doc) tuples are both accepted as columns."""
    function_modifiers.extract_columns("a", ("b", "some doc"), "c")
def test_valid_column_extractor():
    """extract_columns emits the dataframe node plus one series node per column.

    Also checks that a column given as ``(name, doc)`` gets its own
    documentation while a plain column inherits the function's docstring.
    """
    decorator = function_modifiers.extract_columns("col_1", ("col_2", "col2_doc"))

    def dummy_df_generator() -> pd.DataFrame:
        """dummy doc"""
        return pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]})

    produced = list(
        decorator.transform_node(node.Node.from_fn(dummy_df_generator), {}, dummy_df_generator)
    )
    assert len(produced) == 3
    df_node, col_1_node, col_2_node = produced[0], produced[1], produced[2]

    # The first node is the untouched dataframe-producing node.
    assert df_node == node.Node(
        name=dummy_df_generator.__name__,
        typ=pd.DataFrame,
        doc_string=dummy_df_generator.__doc__,
        callabl=dummy_df_generator,
        tags={"module": "tests.function_modifiers.test_expanders"},
    )

    assert col_1_node.name == "col_1"
    assert col_1_node.type == pd.Series
    assert col_1_node.documentation == "dummy doc"  # we default to base function doc.
    assert col_1_node.input_types == {
        dummy_df_generator.__name__: (pd.DataFrame, DependencyType.REQUIRED)
    }

    assert col_2_node.name == "col_2"
    assert col_2_node.type == pd.Series
    assert col_2_node.documentation == "col2_doc"
    assert col_2_node.input_types == {
        dummy_df_generator.__name__: (pd.DataFrame, DependencyType.REQUIRED)
    }
def test_column_extractor_fill_with():
    """With ``fill_with`` set, a missing column is filled in the dataframe and extracted."""

    def dummy_df() -> pd.DataFrame:
        """dummy doc"""
        return pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]})

    decorator = function_modifiers.extract_columns("col_3", fill_with=0)
    df_node, column_node = decorator.transform_node(node.Node.from_fn(dummy_df), {}, dummy_df)

    original_df = df_node.callable()
    extracted_column = column_node.callable(dummy_df=original_df)

    pd.testing.assert_series_equal(
        extracted_column,
        pd.Series([0, 0, 0, 0]),
        check_names=False,
    )
    # it has to be in there now
    pd.testing.assert_series_equal(
        original_df["col_3"],
        pd.Series([0, 0, 0, 0]),
        check_names=False,
    )
def test_column_extractor_no_fill_with():
    """Without ``fill_with``, extracting a column absent from the dataframe raises."""

    def dummy_df_generator() -> pd.DataFrame:
        """dummy doc"""
        return pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]})

    decorator = function_modifiers.extract_columns("col_3")
    produced = list(
        decorator.transform_node(node.Node.from_fn(dummy_df_generator), {}, dummy_df_generator)
    )
    missing_column_node = produced[1]
    with pytest.raises(base.InvalidDecoratorException):
        missing_column_node.callable(dummy_df_generator=dummy_df_generator())
@pytest.mark.parametrize(
    "fields",
    [
        None,  # empty
        "string_input",  # not a dict
        ["string_input"],  # not a dict
        {},  # empty dict
        {1: "string", "field": str},  # invalid dict
        {"field": lambda x: x, "field2": int},  # invalid dict
    ],
)
def test_validate_extract_fields(fields):
    """Each malformed ``fields`` spec must be rejected by _validate_extract_fields."""
    expected_error = base.InvalidDecoratorException
    with pytest.raises(expected_error):
        expanders._validate_extract_fields(fields)
@pytest.mark.parametrize(
    "fields",
    [
        {"field": np.ndarray, "field2": str},
        {
            "field": dict,
            "field2": int,
            "field3": list,
            "field4": float,
            "field5": str,
        },
        {"field": List[str]},
        {"field": Dict[str, List[str]]},
    ],
)
def test_extract_fields_constructor_happy(fields):
    """Well-formed field specs pass _validate_extract_fields without raising."""
    expanders._validate_extract_fields(fields)
@pytest.mark.parametrize(
    "return_type",
    [dict, Dict, Dict[str, str], Dict[str, Any]],
)
def test_extract_fields_validate_happy(return_type):
    """extract_fields accepts functions annotated with any of the dict return types."""

    def return_dict() -> return_type:
        return {}

    decorator = function_modifiers.extract_fields({"test": int})
    # Passing means validate() returns without raising.
    decorator.validate(return_dict)
@pytest.mark.parametrize("return_type", [int, list, np.ndarray, pd.DataFrame])
def test_extract_fields_validate_errors(return_type):
    """extract_fields rejects functions whose return annotation is not a dict type."""

    def return_dict() -> return_type:
        return {}

    decorator = function_modifiers.extract_fields({"test": int})
    with pytest.raises(base.InvalidDecoratorException):
        decorator.validate(return_dict)
def test_valid_extract_fields():
    """Tests whole extract_fields decorator.

    Expects the original dict-producing node followed by one node per
    declared field, each typed as declared, documented with the function's
    docstring, and depending only on the dict node.
    """
    annotation = function_modifiers.extract_fields(
        {"col_1": list, "col_2": int, "col_3": np.ndarray}
    )

    def dummy_dict_generator() -> dict:
        """dummy doc"""
        # np.array builds an array holding these values; np.ndarray(...) would
        # instead interpret the list as a shape and leave the data uninitialised.
        return {"col_1": [1, 2, 3, 4], "col_2": 1, "col_3": np.array([1, 2, 3, 4])}

    nodes = list(
        annotation.transform_node(node.Node.from_fn(dummy_dict_generator), {}, dummy_dict_generator)
    )
    assert len(nodes) == 4
    assert nodes[0] == node.Node(
        name=dummy_dict_generator.__name__,
        typ=dict,
        doc_string=dummy_dict_generator.__doc__,
        callabl=dummy_dict_generator,
        tags={"module": "tests.function_modifiers.test_expanders"},
    )

    # Every extracted field depends solely on the dict-producing node.
    expected_inputs = {dummy_dict_generator.__name__: (dict, DependencyType.REQUIRED)}

    assert nodes[1].name == "col_1"
    assert nodes[1].type == list
    assert nodes[1].documentation == "dummy doc"  # we default to base function doc.
    assert nodes[1].input_types == expected_inputs

    assert nodes[2].name == "col_2"
    assert nodes[2].type == int
    assert nodes[2].documentation == "dummy doc"
    assert nodes[2].input_types == expected_inputs

    assert nodes[3].name == "col_3"
    assert nodes[3].type == np.ndarray
    assert nodes[3].documentation == "dummy doc"
    assert nodes[3].input_types == expected_inputs
def test_extract_fields_fill_with():
    """With ``fill_with`` set, a present field is extracted and a missing one gets the fill value."""

    def dummy_dict() -> dict:
        """dummy doc"""
        # np.array builds an array holding these values; np.ndarray(...) would
        # instead interpret the list as a shape and leave the data uninitialised.
        return {"col_1": [1, 2, 3, 4], "col_2": 1, "col_3": np.array([1, 2, 3, 4])}

    annotation = function_modifiers.extract_fields({"col_2": int, "col_4": float}, fill_with=1.0)
    original_node, extracted_field_node, missing_field_node = annotation.transform_node(
        node.Node.from_fn(dummy_dict), {}, dummy_dict
    )

    original_dict = original_node.callable()
    extracted_field = extracted_field_node.callable(dummy_dict=original_dict)
    missing_field = missing_field_node.callable(dummy_dict=original_dict)

    assert extracted_field == 1  # "col_2" exists in the dict
    assert missing_field == 1.0  # "col_4" is absent, so fill_with is used
def test_extract_fields_no_fill_with():
    """Without ``fill_with``, extracting a field absent from the dict raises."""

    def dummy_dict() -> dict:
        """dummy doc"""
        # np.array builds an array holding these values; np.ndarray(...) would
        # instead interpret the list as a shape and leave the data uninitialised.
        return {"col_1": [1, 2, 3, 4], "col_2": 1, "col_3": np.array([1, 2, 3, 4])}

    annotation = function_modifiers.extract_fields({"col_4": int})
    nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {}, dummy_dict))
    # nodes[1] is the extractor for "col_4", which the dict does not contain.
    with pytest.raises(base.InvalidDecoratorException):
        nodes[1].callable(dummy_dict=dummy_dict())
def concat(upstream_parameter: str, literal_parameter: str) -> Any:
    """Concatenates {upstream_parameter} with literal_parameter"""
    # The docstring above is a format template that the tests in this module
    # fill in via ``concat.__doc__.format(...)``; keep its text unchanged.
    return "{}{}".format(upstream_parameter, literal_parameter)
def test_parametrized_full_no_replacement():
    """An empty parameterization leaves both inputs of ``concat`` untouched."""
    decorator = function_modifiers.parameterize(replace_no_parameters={})
    [node_] = decorator.expand_node(node.Node.from_fn(concat), {}, concat)

    assert node_.callable(upstream_parameter="foo", literal_parameter="bar") == "foobar"
    expected_inputs = {
        "literal_parameter": (str, DependencyType.REQUIRED),
        "upstream_parameter": (str, DependencyType.REQUIRED),
    }
    assert node_.input_types == expected_inputs
    # With nothing replaced, the doc placeholder resolves to the parameter's own name.
    expected_doc = concat.__doc__.format(upstream_parameter="upstream_parameter")
    assert node_.documentation == expected_doc
def test_parametrized_full_replace_just_upstream():
    """Replacing only the upstream parameter renames that input to the source node."""
    decorator = function_modifiers.parameterize(
        replace_just_upstream_parameter={"upstream_parameter": source("foo_source")},
    )
    [node_] = decorator.expand_node(node.Node.from_fn(concat), {}, concat)

    expected_inputs = {
        "literal_parameter": (str, DependencyType.REQUIRED),
        "foo_source": (str, DependencyType.REQUIRED),
    }
    assert node_.input_types == expected_inputs
    assert node_.callable(foo_source="foo", literal_parameter="bar") == "foobar"
    expected_doc = concat.__doc__.format(upstream_parameter="foo_source")
    assert node_.documentation == expected_doc
def test_parametrized_full_replace_just_literal():
    """Binding the literal parameter removes it from the node's inputs."""
    decorator = function_modifiers.parameterize(
        replace_just_literal_parameter={"literal_parameter": value("bar")}
    )
    [node_] = decorator.expand_node(node.Node.from_fn(concat), {}, concat)

    expected_inputs = {"upstream_parameter": (str, DependencyType.REQUIRED)}
    assert node_.input_types == expected_inputs
    assert node_.callable(upstream_parameter="foo") == "foobar"
    expected_doc = concat.__doc__.format(upstream_parameter="upstream_parameter")
    assert node_.documentation == expected_doc
def test_parametrized_full_replace_both():
    """Replacing both parameters leaves only the source node as an input."""
    replacements = {
        "upstream_parameter": source("foo_source"),
        "literal_parameter": value("bar"),
    }
    decorator = function_modifiers.parameterize(replace_both_parameters=replacements)
    [node_] = decorator.expand_node(node.Node.from_fn(concat), {}, concat)

    assert node_.input_types == {"foo_source": (str, DependencyType.REQUIRED)}
    assert node_.callable(foo_source="foo") == "foobar"
    expected_doc = concat.__doc__.format(upstream_parameter="foo_source")
    assert node_.documentation == expected_doc
def test_parametrized_full_multiple_replacements():
    """Each (replacements, doc) pair yields a node carrying its explicitly supplied doc."""
    parameterization = {
        "replace_no_parameters": ({}, "fn with no parameters replaced"),
        "replace_just_upstream_parameter": (
            {"upstream_parameter": source("foo_source")},
            "fn with upstream_parameter set to node foo",
        ),
        "replace_just_literal_parameter": (
            {"literal_parameter": value("bar")},
            "fn with upstream_parameter set to node foo",
        ),
        "replace_both_parameters": (
            {"upstream_parameter": source("foo_source"), "literal_parameter": value("bar")},
            "fn with both parameters replaced",
        ),
    }
    decorator = function_modifiers.parameterize(**parameterization)
    expanded = decorator.expand_node(node.Node.from_fn(concat), {}, concat)
    assert len(expanded) == 4

    # test out that documentation is assigned correctly
    actual_docs = [node_.documentation for node_ in expanded]
    expected_docs = [parameterization[node_.name][1] for node_ in expanded]
    assert actual_docs == expected_docs
def test_parameterized_extract_columns():
    """parameterize_extract_columns builds one dataframe node plus column extractors per set."""
    first_extract = function_modifiers.ParameterizedExtract(
        ("outseries1a", "outseries2a"),
        {"input1": source("inseries1a"), "input2": source("inseries1b"), "input3": value(10)},
    )
    second_extract = function_modifiers.ParameterizedExtract(
        ("outseries1b", "outseries2b"),
        {"input1": source("inseries2a"), "input2": source("inseries2b"), "input3": value(100)},
    )
    decorator = function_modifiers.parameterize_extract_columns(first_extract, second_extract)

    def fn(input1: pd.Series, input2: pd.Series, input3: float) -> pd.DataFrame:
        return pd.concat([input1 * input2 * input3, input1 + input2 + input3], axis=1)

    expanded = decorator.expand_node(node.Node.from_fn(fn), {}, fn)
    # For each parameterized set, we have two outputs and the dataframe node
    assert len(expanded) == 6
    nodes_by_name = {node_.name: node_ for node_ in expanded}

    # Test that it produces the expected results
    first_frame = nodes_by_name["fn__0"](inseries1a=pd.Series([1]), inseries1b=pd.Series([1]))
    pd.testing.assert_frame_equal(
        first_frame,
        pd.DataFrame.from_dict({"outseries1a": [10], "outseries2a": [12]}),
    )
    second_frame = nodes_by_name["fn__1"](inseries2a=pd.Series([1]), inseries2b=pd.Series([1]))
    pd.testing.assert_frame_equal(
        second_frame,
        pd.DataFrame.from_dict({"outseries1b": [100], "outseries2b": [102]}),
    )

    # test that each of the "extractor" nodes produces exactly what we expect
    assert nodes_by_name["outseries1a"](fn__0=pd.DataFrame({"outseries1a": [10]}))[0] == 10
    assert nodes_by_name["outseries2a"](fn__0=pd.DataFrame({"outseries2a": [20]}))[0] == 20
    assert nodes_by_name["outseries1b"](fn__1=pd.DataFrame({"outseries1b": [30]}))[0] == 30
    assert nodes_by_name["outseries2b"](fn__1=pd.DataFrame({"outseries2b": [40]}))[0] == 40
def test_parametrized_full_replace_groups_with_literal():
    """A group of literal values is injected as a list, leaving the node with no inputs."""

    def add_n(grouped_parameter: List[int]) -> int:
        return sum(grouped_parameter)

    annotation = function_modifiers.parameterize(
        replace_just_literal_parameter={"grouped_parameter": group(value(1), value(2), value(3))}
    )
    # expand_node is given the same function the node was built from (add_n),
    # not the unrelated module-level ``concat`` helper.
    (node_,) = annotation.expand_node(node.Node.from_fn(add_n), {}, add_n)
    assert node_.input_types == {}
    assert node_.callable() == 6
def test_parametrized_full_replace_groups_with_sources():
    """A group of sources turns each source into an ``int`` input of the expanded node."""

    def add_n(grouped_parameter: List[int]) -> int:
        return sum(grouped_parameter)

    annotation = function_modifiers.parameterize(
        replace_just_literal_parameter={
            "grouped_parameter": group(source("foo"), source("bar"), source("baz"))
        }
    )
    # expand_node is given the same function the node was built from (add_n),
    # not the unrelated module-level ``concat`` helper.
    (node_,) = annotation.expand_node(node.Node.from_fn(add_n), {}, add_n)
    assert node_.input_types == {
        "foo": (int, DependencyType.REQUIRED),
        "bar": (int, DependencyType.REQUIRED),
        "baz": (int, DependencyType.REQUIRED),
    }
    assert node_.callable(foo=1, bar=2, baz=3) == 6
def test_parameterized_validate_group():
    """A grouped dependency validates against a ``List[int]`` parameter without raising."""

    def add_n(grouped_parameter: List[int]) -> int:
        return sum(grouped_parameter)

    grouped_sources = group(source("foo"), source("bar"), source("baz"))
    decorator = function_modifiers.parameterize(
        replace_just_literal_parameter={"grouped_parameter": grouped_sources}
    )
    # Passing means validate() returns without raising.
    decorator.validate(add_n)
@pytest.mark.parametrize("annotation", [list, int, pd.Series, float])
def test_parameterized_validate_group_fails(annotation):
    """A grouped dependency is rejected for each of the parametrized parameter annotations."""

    # The annotation is evaluated when add_n is defined, so it captures the
    # parametrized type supplied by pytest.
    def add_n(grouped_parameter: annotation) -> int:
        return sum(grouped_parameter)

    grouped_sources = group(source("foo"), source("bar"), source("baz"))
    decorator = function_modifiers.parameterize(
        replace_just_literal_parameter={"grouped_parameter": grouped_sources}
    )
    with pytest.raises(base.InvalidDecoratorException):
        decorator.validate(add_n)
def test_inject_list():
    """inject with a positional group of values supplies the parameter as a list."""
    decorator = function_modifiers.inject(nums=group(value(1), value(2), value(3)))

    def summation(nums: List[int]) -> int:
        return sum(nums)

    [node_] = decorator.expand_node(node.Node.from_fn(summation), {}, summation)
    total = node_()
    assert total == 6
def test_inject_dict():
    """inject with a keyword group of values supplies the parameter as a dict."""
    decorator = function_modifiers.inject(
        nums=group(one=value(1), two=value(2), three=value(3)),
    )

    def summation(nums: Dict[str, int]) -> int:
        return nums["one"] + nums["two"] + nums["three"]

    [node_] = decorator.expand_node(node.Node.from_fn(summation), {}, summation)
    total = node_()
    assert total == 6
def test_inject_list_source():
    """inject with a positional group of sources collects upstream values into a list."""
    grouped_sources = group(source("one_source"), source("two_source"), source("three_source"))
    decorator = function_modifiers.inject(nums=grouped_sources)

    def summation(nums: List[int]) -> int:
        return sum(nums)

    [node_] = decorator.expand_node(node.Node.from_fn(summation), {}, summation)
    total = node_(one_source=1, two_source=2, three_source=3)
    assert total == 6
def test_inject_dict_source():
    """inject with a keyword group of sources collects upstream values into a dict."""
    grouped_sources = group(
        one=source("one_source"),
        two=source("two_source"),
        three=source("three_source"),
    )
    decorator = function_modifiers.inject(nums=grouped_sources)

    def summation(nums: Dict[str, int]) -> int:
        return nums["one"] + nums["two"] + nums["three"]

    [node_] = decorator.expand_node(node.Node.from_fn(summation), {}, summation)
    total = node_(one_source=1, two_source=2, three_source=3)
    assert total == 6
def test_inject_multiple_things():
    """inject can mix list groups, dict groups and plain values while leaving other inputs alone."""

    def contrived_function(
        int_list: List[int],
        int_dict: Dict[str, int],
        int_value_injected: int,
        int_value_not_injected: int,
    ) -> int:
        return sum(int_list) + sum(int_dict.values()) + int_value_injected + int_value_not_injected

    decorator = function_modifiers.inject(
        int_list=group(value(1), value(2), source("three_source")),
        int_dict=group(one=value(4), two=value(5), three=source("six_source")),
        int_value_injected=value(7),
    )
    [node_] = decorator.expand_node(
        node.Node.from_fn(contrived_function), {}, contrived_function
    )

    # The injected and supplied values are exactly 1..8, so the result is their sum (36).
    expected_total = sum(range(1, 9))
    assert node_(int_value_not_injected=8, three_source=3, six_source=6) == expected_total
@pytest.mark.parametrize(
    "annotated_type,cls,expected",
    [
        # list-style groups resolve to the element type
        (List[int], GroupedListDependency, int),
        (List[List[int]], GroupedListDependency, List[int]),
        (List[pd.Series], GroupedListDependency, pd.Series),
        # dict-style groups resolve to the value type
        (Dict[str, pd.Series], GroupedDictDependency, pd.Series),
        (Dict[str, List[int]], GroupedDictDependency, List[int]),
        (Dict[str, int], GroupedDictDependency, int),
        # Optional wrappers resolve the same way
        (Optional[Dict[str, int]], GroupedDictDependency, int),
        (Optional[List[int]], GroupedListDependency, int),
    ],
)
def test_resolve_dependency_type_happy(
    annotated_type: Type[Type], cls: Type[GroupedDependency], expected: Type[Type]
):
    """resolve_dependency_type returns the expected inner type for each supported annotation."""
    resolved = cls.resolve_dependency_type(annotated_type, "test")
    assert resolved == expected
@pytest.mark.parametrize(
    "annotated_type,cls",
    [
        (int, GroupedDictDependency),
        (int, GroupedListDependency),
        (List[int], GroupedDictDependency),
        (Dict[str, int], GroupedListDependency),
        (pd.Series, GroupedDictDependency),
        (pd.Series, GroupedListDependency),
        (pd.DataFrame, GroupedDictDependency),
        (pd.DataFrame, GroupedListDependency),
        (Dict[int, str], GroupedDictDependency),
    ],
)
def test_resolve_dependency_type_sad(annotated_type: Type[Type], cls: Type[GroupedDependency]):
    """resolve_dependency_type raises for each annotation/dependency-class pair listed above."""
    expected_error = base.InvalidDecoratorException
    with pytest.raises(expected_error):
        cls.resolve_dependency_type(annotated_type, "test")
def test_inject_misconfigured_param_type_list():
    """Injecting a positional group into an ``int`` parameter fails validation."""

    def foo(x: int) -> int:
        return x

    decorator = function_modifiers.inject(x=group(value(1), value(2)))
    expected_error = base.InvalidDecoratorException
    with pytest.raises(expected_error):
        decorator.validate(foo)
def test_inject_misconfigured_param_type_dict():
    """Injecting a keyword group into an ``int`` parameter fails validation."""

    def foo(x: int) -> int:
        return x

    decorator = function_modifiers.inject(x=group(foo=value(1), bar=value(2)))
    expected_error = base.InvalidDecoratorException
    with pytest.raises(expected_error):
        decorator.validate(foo)
@pytest.mark.skipif(
    sys.version_info < (3, 9, 0),
    reason="Stricter type-checking only works on python 3.9+",
)
def test_inject_misconfigured_param_untyped_generic_list():
    """Injecting a positional group into a bare ``List`` parameter fails validation."""

    def foo(x: List) -> int:
        return sum(x)

    decorator = function_modifiers.inject(x=group(value(1), value(2)))
    expected_error = base.InvalidDecoratorException
    with pytest.raises(expected_error):
        decorator.validate(foo)
def test_inject_misconfigured_param_untyped_generic_dict():
    """Injecting a keyword group into a bare ``Dict`` parameter fails validation."""

    def foo(x: Dict) -> int:
        return sum(x)

    decorator = function_modifiers.inject(x=group(a=value(1), b=value(2)))
    expected_error = base.InvalidDecoratorException
    with pytest.raises(expected_error):
        decorator.validate(foo)
def test_inject_validate_with_docstring():
    """A docstring without placeholders is carried over unchanged to the injected node."""

    def foo(x: int) -> int:
        """Docstring..."""
        return x

    decorator = function_modifiers.inject(x=value(1))
    decorator.validate(foo)
    [node_] = decorator.expand_node(node.Node.from_fn(foo), {}, foo)
    assert node_.documentation == "Docstring..."
def test_inject_validate_with_docstring_replacement():
    """A ``{x}`` placeholder in the docstring is filled with the injected value."""

    def foo(x: int) -> int:
        """Docstring. x={x} is injected."""
        return x

    decorator = function_modifiers.inject(x=value(1))
    decorator.validate(foo)
    [node_] = decorator.expand_node(node.Node.from_fn(foo), {}, foo)
    assert node_.documentation == "Docstring. x=1 is injected."
def test_parameterize_repeated_sources():
    """Two parameters may map to the same source; the single upstream value feeds both."""

    def foo(x: int, y: int) -> int:
        return x + y

    shared_source_mapping = dict(x=source("x"), y=source("x"))
    decorator = function_modifiers.parameterize(foo=shared_source_mapping)
    [node_] = decorator.expand_node(node.Node.from_fn(foo), {}, foo)
    # x=1 is used for both x and y, hence 1 + 1.
    assert node_(x=1) == 2
def test_parameterize_repeated_sources_with_group():
    """The same source may appear twice inside one group."""

    def foo(x: List[int]) -> int:
        return sum(x)

    repeated_group = group(source("x"), source("x"))
    decorator = function_modifiers.parameterize(foo=dict(x=repeated_group))
    [node_] = decorator.expand_node(node.Node.from_fn(foo), {}, foo)
    # x=1 appears twice in the grouped list, hence sum([1, 1]).
    assert node_(x=1) == 2
def test_parameterize_with_parallelizable():
    """A node expanded from a ``Parallelizable``-returning function keeps the EXPAND role."""

    def foo(n: int) -> Parallelizable[int]:
        for i in range(n):
            yield i

    source_mapping = dict(x=source("x"), y=source("x"))
    decorator = function_modifiers.parameterize(foo=source_mapping)
    [node_] = decorator.expand_node(node.Node.from_fn(foo), {}, foo)
    assert node_.node_role == node.NodeType.EXPAND
def test_parameterize_with_collect():
    """A node expanded from a function taking ``Collect`` input keeps the COLLECT role."""

    def foo(n: Collect[int]) -> int:
        return sum(n)

    source_mapping = dict(x=source("x"), y=source("x"))
    decorator = function_modifiers.parameterize(foo=source_mapping)
    [node_] = decorator.expand_node(node.Node.from_fn(foo), {}, foo)
    assert node_.node_role == node.NodeType.COLLECT