# blob: 0cdedff72c5cb68c6f902671388433719be2b054 [file] [log] [blame]
import inspect
from typing import Any, Dict, List, Set
import numpy as np
import pandas as pd
import pytest
from hamilton import function_modifiers, function_modifiers_base, models, node
from hamilton.data_quality.base import DataValidationError, ValidationResult
from hamilton.function_modifiers import (
DATA_VALIDATOR_ORIGINAL_OUTPUT_TAG,
IS_DATA_VALIDATOR_TAG,
LiteralDependency,
UpstreamDependency,
check_output,
check_output_custom,
does,
ensure_function_empty,
source,
value,
)
from hamilton.node import DependencyType
from tests.resources.dq_dummy_examples import (
DUMMY_VALIDATORS_FOR_TESTING,
SampleDataValidator2,
SampleDataValidator3,
)
def test_parametrized_invalid_params():
    """validate() rejects functions that lack the parameterized argument."""
    outputs = {("invalid_node_name", "invalid_doc"): "invalid_value"}
    decorator = function_modifiers.parameterize_values(
        parameter="non_existant", assigned_output=outputs
    )

    def no_param_node():
        pass

    def wrong_param_node(valid_value):
        pass

    # Neither candidate declares a parameter named "non_existant".
    for candidate in (no_param_node, wrong_param_node):
        with pytest.raises(function_modifiers.InvalidDecoratorException):
            decorator.validate(candidate)
def test_parametrized_single_param_breaks_without_docs():
    """Constructing parameterize_values with a bare-string key (no doc) raises."""
    undocumented_outputs = {"only_node_name": "only_value"}
    with pytest.raises(function_modifiers.InvalidDecoratorException):
        function_modifiers.parameterize_values(
            parameter="parameter", assigned_output=undocumented_outputs
        )
def test_parametrized_single_param():
    """A single (name, doc) -> value entry expands into exactly one node."""
    decorator = function_modifiers.parameterize_values(
        parameter="parameter",
        assigned_output={("only_node_name", "only_doc"): "only_value"},
    )

    def identity(parameter: Any) -> Any:
        return parameter

    expanded = decorator.expand_node(node.Node.from_fn(identity), {}, identity)
    assert len(expanded) == 1
    only_node = expanded[0]
    assert only_node.name == "only_node_name"
    assert only_node.type == Any
    assert only_node.documentation == "only_doc"
    # The assigned value is bound in, so the callable takes no arguments.
    assert only_node.callable() == "only_value"
def test_parametrized_single_param_expanded():
    """Two assigned outputs expand into two nodes, each with its own doc and value."""
    decorator = function_modifiers.parameterize_values(
        parameter="parameter",
        assigned_output={
            ("node_name_1", "doc1"): "value_1",
            ("node_value_2", "doc2"): "value_2",
        },
    )

    def identity(parameter: Any) -> Any:
        return parameter

    expanded = decorator.expand_node(node.Node.from_fn(identity), {}, identity)
    assert len(expanded) == 2
    results = [generated.callable() for generated in expanded]
    assert [generated.documentation for generated in expanded] == ["doc1", "doc2"]
    assert results == ["value_1", "value_2"]
def test_parametrized_with_multiple_params():
    """Only the parameterized argument is bound; other arguments stay as inputs."""
    decorator = function_modifiers.parameterize_values(
        parameter="parameter",
        assigned_output={
            ("node_name_1", "doc1"): "value_1",
            ("node_value_2", "doc2"): "value_2",
        },
    )

    def identity(parameter: Any, static: Any) -> Any:
        return parameter, static

    expanded = decorator.expand_node(node.Node.from_fn(identity), {}, identity)
    assert len(expanded) == 2
    results = [generated.callable(static="static_param") for generated in expanded]
    assert results == [("value_1", "static_param"), ("value_2", "static_param")]
def test_parametrized_input():
    """parametrized_input replaces the named parameter with each variable input."""
    variable_inputs = {
        "input_1": ("test_1", "Function with first column as input"),
        "input_2": ("test_2", "Function with second column as input"),
    }
    decorator = function_modifiers.parametrized_input(
        parameter="parameter", variable_inputs=variable_inputs
    )

    def identity(parameter: Any, static: Any) -> Any:
        return parameter, static

    expanded = decorator.expand_node(node.Node.from_fn(identity), {}, identity)
    assert len(expanded) == 2
    # Sort by name so the assertions do not depend on expansion order.
    first, second = sorted(expanded, key=lambda generated: generated.name)
    assert [first.name, second.name] == ["test_1", "test_2"]
    assert set(first.input_types) == {"static", "input_1"}
    assert set(second.input_types) == {"static", "input_2"}
def test_parametrize_sources_validate_param_name():
    """Tests validate function of parameterize_sources capturing bad param name usage."""
    # "parameterfoo" is not a parameter of the decorated function below.
    bad_mapping = {"test_1": {"parameterfoo": "input_1"}}
    decorator = function_modifiers.parameterize_sources(parameterization=bad_mapping)

    def identity(parameter1: str, parameter2: str, static: str) -> str:
        """Function with {parameter1} as first input"""
        return parameter1 + parameter2 + static

    with pytest.raises(function_modifiers.InvalidDecoratorException):
        decorator.validate(identity)
def test_parametrized_inputs_validate_reserved_param():
    """Tests validate function of parameterize_inputs catching reserved param usage."""
    decorator = function_modifiers.parameterize_sources(test_1={"parameter2": "input_1"})

    # The function under validation takes a parameter called "output_name".
    def identity(output_name: str, parameter2: str, static: str) -> str:
        """Function with {parameter2} as second input"""
        return output_name + parameter2 + static

    with pytest.raises(function_modifiers.InvalidDecoratorException):
        decorator.validate(identity)
def test_parametrized_inputs_validate_bad_doc_string():
    """Tests validate function of parameterize_inputs catching bad doc string."""
    decorator = function_modifiers.parameterize_sources(test_1={"parameter2": "input_1"})

    # The docstring placeholder {foo} does not correspond to any parameter.
    def identity(output_name: str, parameter2: str, static: str) -> str:
        """Function with {foo} as second input"""
        return output_name + parameter2 + static

    with pytest.raises(function_modifiers.InvalidDecoratorException):
        decorator.validate(identity)
def test_parametrized_inputs():
    """parameterize_sources rewires inputs and formats the docstring per output."""
    decorator = function_modifiers.parameterize_sources(
        test_1={"parameter1": "input_1", "parameter2": "input_2"},
        test_2={"parameter1": "input_2", "parameter2": "input_1"},
    )

    def identity(parameter1: str, parameter2: str, static: str) -> str:
        """Function with {parameter1} as first input"""
        return parameter1 + parameter2 + static

    expanded = decorator.expand_node(node.Node.from_fn(identity), {}, identity)
    assert len(expanded) == 2
    first, second = sorted(expanded, key=lambda generated: generated.name)
    assert [first.name, second.name] == ["test_1", "test_2"]

    all_inputs = {"static", "input_1", "input_2"}
    assert set(first.input_types) == all_inputs
    assert first.documentation == "Function with input_1 as first input"
    assert set(second.input_types) == all_inputs
    assert second.documentation == "Function with input_2 as first input"

    # Same upstream values for both nodes; only the wiring order differs.
    upstream_values = {"input_1": "1", "input_2": "2", "static": "3"}
    assert first.callable(**upstream_values) == "123"
    assert second.callable(**upstream_values) == "213"
def test_invalid_column_extractor():
    """extract_columns.validate rejects a function annotated to return int."""
    decorator = function_modifiers.extract_columns("dummy_column")

    def no_param_node() -> int:
        pass

    with pytest.raises(function_modifiers.InvalidDecoratorException):
        decorator.validate(no_param_node)
def test_extract_columns_invalid_passing_list_to_column_extractor():
    """Ensures that people cannot pass in a list."""
    columns_as_list = ["a", "b", "c"]
    with pytest.raises(function_modifiers.InvalidDecoratorException):
        function_modifiers.extract_columns(columns_as_list)
def test_extract_columns_empty_args():
    """Tests that we fail on empty arguments."""
    expected_error = function_modifiers.InvalidDecoratorException
    with pytest.raises(expected_error):
        function_modifiers.extract_columns()
def test_extract_columns_happy():
    """Tests that we are happy with good arguments."""
    # Columns may be plain names or (name, documentation) tuples.
    function_modifiers.extract_columns("a", ("b", "some doc"), "c")
def test_valid_column_extractor():
    """Tests that things work, and that you can provide optional documentation."""
    decorator = function_modifiers.extract_columns("col_1", ("col_2", "col2_doc"))

    def dummy_df_generator() -> pd.DataFrame:
        """dummy doc"""
        return pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]})

    expanded = list(
        decorator.expand_node(node.Node.from_fn(dummy_df_generator), {}, dummy_df_generator)
    )
    assert len(expanded) == 3
    frame_node, col_1_node, col_2_node = expanded

    # The first node is the untouched DataFrame-producing function.
    expected_frame_node = node.Node(
        name=dummy_df_generator.__name__,
        typ=pd.DataFrame,
        doc_string=dummy_df_generator.__doc__,
        callabl=dummy_df_generator,
        tags={"module": "tests.test_function_modifiers"},
    )
    assert frame_node == expected_frame_node

    # Each extracted column depends only on the DataFrame node.
    upstream = {dummy_df_generator.__name__: (pd.DataFrame, DependencyType.REQUIRED)}

    assert col_1_node.name == "col_1"
    assert col_1_node.type == pd.Series
    assert col_1_node.documentation == "dummy doc"  # we default to base function doc.
    assert col_1_node.input_types == upstream

    assert col_2_node.name == "col_2"
    assert col_2_node.type == pd.Series
    assert col_2_node.documentation == "col2_doc"
    assert col_2_node.input_types == upstream
def test_column_extractor_fill_with():
    """A missing column is created from fill_with, both extracted and in the frame."""

    def dummy_df() -> pd.DataFrame:
        """dummy doc"""
        return pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]})

    decorator = function_modifiers.extract_columns("col_3", fill_with=0)
    expanded = decorator.expand_node(node.Node.from_fn(dummy_df), {}, dummy_df)
    original_node, extracted_column_node = expanded

    original_df = original_node.callable()
    extracted_column = extracted_column_node.callable(dummy_df=original_df)

    zeros = pd.Series([0, 0, 0, 0])
    pd.testing.assert_series_equal(extracted_column, zeros, check_names=False)
    # it has to be in there now
    pd.testing.assert_series_equal(original_df["col_3"], zeros, check_names=False)
def test_column_extractor_no_fill_with():
    """Extracting a missing column without fill_with raises when the node runs."""

    def dummy_df_generator() -> pd.DataFrame:
        """dummy doc"""
        return pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]})

    decorator = function_modifiers.extract_columns("col_3")
    expanded = list(
        decorator.expand_node(node.Node.from_fn(dummy_df_generator), {}, dummy_df_generator)
    )
    missing_column_node = expanded[1]
    with pytest.raises(function_modifiers.InvalidDecoratorException):
        missing_column_node.callable(dummy_df_generator=dummy_df_generator())
def test_no_code_validator():
    """ensure_function_empty accepts pass-only bodies and rejects real code."""

    def no_code():
        pass

    def no_code_with_docstring():
        """This should still show up as having no code, even though it has a docstring"""
        pass

    def yes_code():
        """This should show up as having code"""
        a = 0
        return a

    # Empty functions, with or without a docstring, pass silently.
    for empty_fn in (no_code, no_code_with_docstring):
        ensure_function_empty(empty_fn)

    # A body with real statements must be rejected.
    with pytest.raises(function_modifiers.InvalidDecoratorException):
        ensure_function_empty(yes_code)
# Functions for @does -- these are the functions we're "replacing"
def _no_params() -> int:
    """Stub with no parameters; a target whose signature @does may replace."""
    pass
def _one_param(a: int) -> int:
    """Stub taking a single required parameter ``a``."""
    pass
def _two_params(a: int, b: int) -> int:
    """Stub taking two required parameters ``a`` and ``b``."""
    pass
def _three_params(a: int, b: int, c: int) -> int:
    """Stub taking three required parameters ``a``, ``b`` and ``c``."""
    pass
def _three_params_with_defaults(a: int, b: int = 1, c: int = 2) -> int:
    """Stub with one required parameter and two defaulted ones (b=1, c=2)."""
    pass
# functions we can/can't replace them with
def _empty() -> int:
    """Replacement candidate that takes nothing and always returns 1."""
    constant = 1
    return constant
def _kwargs(**kwargs: int) -> int:
    """Replacement candidate that sums every keyword argument it receives."""
    total = 0
    for amount in kwargs.values():
        total = total + amount
    return total
def _kwargs_with_a(a: int, **kwargs: int) -> int:
    """Replacement candidate: required ``a`` plus the sum of any extra keywords."""
    extras = sum(kwargs.values())
    return a + extras
def _just_a(a: int) -> int:
    """Replacement candidate that returns its only parameter ``a`` unchanged."""
    result = a
    return result
def _just_b(b: int) -> int:
    """Replacement candidate that returns its only parameter ``b`` unchanged."""
    result = b
    return result
def _a_b_c(a: int, b: int, c: int) -> int:
    """Replacement candidate that adds its three required parameters."""
    a_plus_b = a + b
    return a_plus_b + c
@pytest.mark.parametrize(
    "fn,replace_with,argument_mapping,matches",
    [
        (_no_params, _empty, {}, True),
        (_no_params, _kwargs, {}, True),
        (_no_params, _kwargs_with_a, {}, False),
        (_no_params, _just_a, {}, False),
        (_no_params, _a_b_c, {}, False),
        (_one_param, _empty, {}, False),
        (_one_param, _kwargs, {}, True),
        (_one_param, _kwargs_with_a, {}, True),
        (_one_param, _just_a, {}, True),
        (_one_param, _just_b, {}, False),
        (_one_param, _just_b, {"b": "a"}, True),  # Replacing a with b makes the signatures match
        (_one_param, _just_b, {"c": "a"}, False),  # No param c on the replacing function
        (_two_params, _empty, {}, False),
        (_two_params, _kwargs, {}, True),
        (_two_params, _kwargs_with_a, {}, True),  # b gets fed to kwargs
        (_two_params, _kwargs_with_a, {"foo": "b"}, True),  # Any kwargs work
        (_two_params, _kwargs_with_a, {"bar": "a"}, False),  # No param bar
        (_two_params, _just_a, {}, False),
        (_two_params, _just_b, {}, False),
        (_three_params, _a_b_c, {}, True),
        (_three_params, _a_b_c, {"d": "a"}, False),
        (_three_params, _a_b_c, {"a": "b", "b": "a"}, True),  # Weird case but why not?
        (_three_params, _kwargs_with_a, {}, True),
        (_three_params_with_defaults, _a_b_c, {}, True),
        (_three_params_with_defaults, _a_b_c, {"d": "a"}, False),
    ],
)
def test_ensure_function_signatures_compatible(fn, replace_with, argument_mapping, matches):
    """Checks does.test_function_signatures_compatible against the expected verdict.

    :param fn: The stub whose signature is being replaced.
    :param replace_with: The candidate replacing function.
    :param argument_mapping: Mapping handed through as the third argument.
    :param matches: Expected boolean result of the compatibility check.
    """
    compatible = does.test_function_signatures_compatible(
        inspect.signature(fn), inspect.signature(replace_with), argument_mapping
    )
    assert compatible == matches
def test_does_function_modifier():
    """@does swaps in the replacing function while keeping name and docstring."""

    def sum_(**kwargs: int) -> int:
        return sum(kwargs.values())

    def to_modify(param1: int, param2: int) -> int:
        """This sums the inputs it gets..."""
        pass

    # Named `generated` rather than `node` to avoid shadowing the imported module.
    generated = does(sum_).generate_node(to_modify, {})
    assert generated.name == "to_modify"
    assert generated.callable(param1=1, param2=1) == 2
    assert generated.documentation == to_modify.__doc__
def test_does_function_modifier_complex_types():
    """@does works when the replacing function uses generic type annotations."""

    def setify(**kwargs: List[int]) -> Set[int]:
        return set(sum(kwargs.values(), []))

    def to_modify(param1: List[int], param2: List[int]) -> int:
        """This sums the inputs it gets..."""
        pass

    generated = does(setify).generate_node(to_modify, {})
    assert generated.name == "to_modify"
    merged = generated.callable(param1=[1, 2, 3], param2=[4, 5, 6])
    assert merged == {1, 2, 3, 4, 5, 6}
    assert generated.documentation == to_modify.__doc__
def test_does_function_modifier_optionals():
    """Defaulted parameters of the decorated function become OPTIONAL inputs."""

    def sum_(param0: int, **kwargs: int) -> int:
        return sum(kwargs.values())

    def to_modify(param0: int, param1: int = 1, param2: int = 2) -> int:
        """This sums the inputs it gets..."""
        pass

    generated = does(sum_).generate_node(to_modify, {})
    assert generated.name == "to_modify"

    expected_kinds = {
        "param0": DependencyType.REQUIRED,
        "param1": DependencyType.OPTIONAL,
        "param2": DependencyType.OPTIONAL,
    }
    for parameter, kind in expected_kinds.items():
        assert generated.input_types[parameter][1] == kind

    # Omitted optionals fall back to to_modify's defaults (1 + 2).
    assert generated.callable(param0=0) == 3
    assert generated.callable(param0=0, param1=0, param2=0) == 0
    assert generated.documentation == to_modify.__doc__
def test_does_with_argument_mapping():
    """@does routes renamed arguments to the replacing function via the mapping."""

    def _sum_multiply(param0: int, param1: int, param2: int) -> int:
        return param0 + param1 * param2

    def to_modify(parama: int, paramb: int = 1, paramc: int = 2) -> int:
        """This sums the inputs it gets..."""
        pass

    mapping = {"param0": "parama", "param1": "paramb", "param2": "paramc"}
    generated = does(_sum_multiply, **mapping).generate_node(to_modify, {})
    assert generated.name == "to_modify"

    expected_kinds = {
        "parama": DependencyType.REQUIRED,
        "paramb": DependencyType.OPTIONAL,
        "paramc": DependencyType.OPTIONAL,
    }
    for parameter, kind in expected_kinds.items():
        assert generated.input_types[parameter][1] == kind

    assert generated.callable(parama=0) == 2
    assert generated.callable(parama=0, paramb=1, paramc=2) == 2
    assert generated.callable(parama=1, paramb=4) == 9
    assert generated.documentation == to_modify.__doc__
def test_model_modifier():
    """@model builds a node from config and rejects functions that are not empty stubs."""
    config = {"my_column_model_params": {"col_1": 0.5, "col_2": 0.5}}

    class LinearCombination(models.BaseModel):
        """Weighted sum of the configured columns."""

        def get_dependents(self) -> List[str]:
            return list(self.config_parameters)

        def predict(self, **columns: pd.Series) -> pd.Series:
            # Start from 0, as the built-in sum() does, and add each weighted column.
            combined = 0
            for column_name, column in columns.items():
                combined = combined + self.config_parameters[column_name] * column
            return combined

    def my_column() -> pd.Series:
        """Column that will be annotated by a model"""
        pass

    decorator = function_modifiers.model(LinearCombination, "my_column_model_params")
    decorator.validate(my_column)
    model_node = decorator.generate_node(my_column, config)

    col_1_type = model_node.input_types["col_1"][0]
    col_2_type = model_node.input_types["col_2"][0]
    assert col_1_type == col_2_type
    assert col_2_type == pd.Series
    assert model_node.type == pd.Series

    prediction = model_node.callable(col_1=pd.Series([1]), col_2=pd.Series([2]))
    pd.testing.assert_series_equal(prediction, pd.Series([1.5]))

    def bad_model(col_1: pd.Series, col_2: pd.Series) -> pd.Series:
        return col_1 * 0.5 + col_2 * 0.5

    with pytest.raises(function_modifiers.InvalidDecoratorException):
        decorator.validate(bad_model)
def test_sanitize_function_name():
    """sanitize_function_name drops a trailing ``__suffix`` and keeps plain names."""
    sanitize = function_modifiers_base.sanitize_function_name
    for raw_name in ("fn_name__v2", "fn_name"):
        assert sanitize(raw_name) == "fn_name"
def test_config_modifier_validate():
    """config.when validates normally named functions but rejects a trailing ``__``."""

    def valid_fn() -> int:
        pass

    def valid_fn__this_is_also_valid() -> int:
        pass

    def invalid_function__() -> int:
        pass

    for accepted in (valid_fn__this_is_also_valid, valid_fn):
        function_modifiers.config.when(key="value").validate(accepted)

    with pytest.raises(function_modifiers.InvalidDecoratorException):
        function_modifiers.config.when(key="value").validate(invalid_function__)
def test_config_when():
    """config.when resolves only if the config value equals the expected one."""

    def config_when_fn() -> int:
        pass

    decorator = function_modifiers.config.when(key="value")
    matching = decorator.resolve(config_when_fn, {"key": "value"})
    assert matching is not None
    mismatching = decorator.resolve(config_when_fn, {"key": "wrong_value"})
    assert mismatching is None
def test_config_when_not():
    """config.when_not resolves only if the config value differs from the given one."""

    def config_when_not_fn() -> int:
        pass

    decorator = function_modifiers.config.when_not(key="value")
    different = decorator.resolve(config_when_not_fn, {"key": "other_value"})
    assert different is not None
    same = decorator.resolve(config_when_not_fn, {"key": "value"})
    assert same is None
def test_config_when_in():
    """config.when_in resolves only if the config value is one of the listed values."""

    def config_when_in_fn() -> int:
        pass

    decorator = function_modifiers.config.when_in(key=["valid_value", "another_valid_value"])
    for accepted in ("valid_value", "another_valid_value"):
        assert decorator.resolve(config_when_in_fn, {"key": accepted}) is not None
    assert decorator.resolve(config_when_in_fn, {"key": "not_a_valid_value"}) is None
def test_config_when_not_in():
    """config.when_not_in resolves only if the config value is not in the list."""

    def config_when_not_in_fn() -> int:
        pass

    excluded = ["invalid_value", "another_invalid_value"]
    decorator = function_modifiers.config.when_not_in(key=excluded)
    for rejected in ("invalid_value", "another_invalid_value"):
        assert decorator.resolve(config_when_not_in_fn, {"key": rejected}) is None
    assert decorator.resolve(config_when_not_in_fn, {"key": "valid_value"}) is not None
def test_config_name_resolution():
    """A resolved function loses its ``__v2`` suffix."""

    def fn__v2() -> int:
        pass

    decorator = function_modifiers.config.when(key="value")
    resolved = decorator.resolve(fn__v2, {"key": "value"})
    assert resolved.__name__ == "fn"
def test_config_when_with_custom_name():
    """The ``name`` argument of config.when sets the resolved function's name."""

    def config_when_fn() -> int:
        pass

    decorator = function_modifiers.config.when(key="value", name="new_function_name")
    resolved = decorator.resolve(config_when_fn, {"key": "value"})
    assert resolved.__name__ == "new_function_name"
@pytest.mark.parametrize(
    "fields",
    [
        None,  # empty
        "string_input",  # not a dict
        ["string_input"],  # not a dict
        {},  # empty dict
        {1: "string", "field": str},  # invalid dict
        {"field": lambda x: x, "field2": int},  # invalid dict
    ],
)
def test_extract_fields_constructor_errors(fields):
    """extract_fields refuses each malformed ``fields`` argument at construction."""
    expected_error = function_modifiers.InvalidDecoratorException
    with pytest.raises(expected_error):
        function_modifiers.extract_fields(fields)
@pytest.mark.parametrize(
    "fields",
    [
        {"field": np.ndarray, "field2": str},
        {"field": dict, "field2": int, "field3": list, "field4": float, "field5": str},
    ],
)
def test_extract_fields_constructor_happy(fields):
    """Tests that we are happy with good arguments."""
    # Construction alone is the check: it must not raise.
    function_modifiers.extract_fields(fields)
@pytest.mark.parametrize(
    "return_type",
    [dict, Dict, Dict[str, str], Dict[str, Any]],
)
def test_extract_fields_validate_happy(return_type):
    """extract_fields.validate accepts plain and typing dict return annotations."""

    def return_dict() -> return_type:
        return {}

    decorator = function_modifiers.extract_fields({"test": int})
    decorator.validate(return_dict)
@pytest.mark.parametrize(
    "return_type",
    [int, list, np.ndarray, pd.DataFrame],
)
def test_extract_fields_validate_errors(return_type):
    """extract_fields.validate rejects functions not annotated as returning a dict."""

    def return_dict() -> return_type:
        return {}

    decorator = function_modifiers.extract_fields({"test": int})
    with pytest.raises(function_modifiers.InvalidDecoratorException):
        decorator.validate(return_dict)
def test_valid_extract_fields():
    """Tests whole extract_fields decorator."""
    decorator = function_modifiers.extract_fields(
        {"col_1": list, "col_2": int, "col_3": np.ndarray}
    )

    def dummy_dict_generator() -> dict:
        """dummy doc"""
        # np.array builds an array holding these values; np.ndarray([1, 2, 3, 4])
        # would instead allocate an uninitialised array of shape (1, 2, 3, 4).
        return {"col_1": [1, 2, 3, 4], "col_2": 1, "col_3": np.array([1, 2, 3, 4])}

    expanded = list(
        decorator.expand_node(node.Node.from_fn(dummy_dict_generator), {}, dummy_dict_generator)
    )
    assert len(expanded) == 4
    dict_node = expanded[0]

    # The first node is the untouched dict-producing function.
    assert dict_node == node.Node(
        name=dummy_dict_generator.__name__,
        typ=dict,
        doc_string=dummy_dict_generator.__doc__,
        callabl=dummy_dict_generator,
        tags={"module": "tests.test_function_modifiers"},
    )

    # Each extracted field depends only on the dict node and, with no
    # per-field doc given, defaults to the base function doc.
    upstream = {dummy_dict_generator.__name__: (dict, DependencyType.REQUIRED)}
    expected_fields = [("col_1", list), ("col_2", int), ("col_3", np.ndarray)]
    for field_node, (field_name, field_type) in zip(expanded[1:], expected_fields):
        assert field_node.name == field_name
        assert field_node.type == field_type
        assert field_node.documentation == "dummy doc"
        assert field_node.input_types == upstream
def test_extract_fields_fill_with():
    """A field missing from the dict is produced from ``fill_with``."""

    def dummy_dict() -> dict:
        """dummy doc"""
        # np.array builds an array holding these values; np.ndarray([1, 2, 3, 4])
        # would instead allocate an uninitialised array of shape (1, 2, 3, 4).
        return {"col_1": [1, 2, 3, 4], "col_2": 1, "col_3": np.array([1, 2, 3, 4])}

    decorator = function_modifiers.extract_fields({"col_2": int, "col_4": float}, fill_with=1.0)
    original_node, extracted_field_node, missing_field_node = decorator.expand_node(
        node.Node.from_fn(dummy_dict), {}, dummy_dict
    )
    original_dict = original_node.callable()
    extracted_field = extracted_field_node.callable(dummy_dict=original_dict)
    missing_field = missing_field_node.callable(dummy_dict=original_dict)
    assert extracted_field == 1  # "col_2" exists, so its real value is returned
    assert missing_field == 1.0  # "col_4" is absent, so fill_with is returned
def test_extract_fields_no_fill_with():
    """Extracting a missing field without ``fill_with`` raises when the node runs."""

    def dummy_dict() -> dict:
        """dummy doc"""
        # np.array builds an array holding these values; np.ndarray([1, 2, 3, 4])
        # would instead allocate an uninitialised array of shape (1, 2, 3, 4).
        return {"col_1": [1, 2, 3, 4], "col_2": 1, "col_3": np.array([1, 2, 3, 4])}

    decorator = function_modifiers.extract_fields({"col_4": int})
    expanded = list(decorator.expand_node(node.Node.from_fn(dummy_dict), {}, dummy_dict))
    missing_field_node = expanded[1]
    with pytest.raises(function_modifiers.InvalidDecoratorException):
        missing_field_node.callable(dummy_dict=dummy_dict())
def test_tags():
    """The tag decorator attaches every given key to the node's tags."""

    def dummy_tagged_function() -> int:
        """dummy doc"""
        return 1

    decorator = function_modifiers.tag(foo="bar", bar="baz")
    tagged = decorator.decorate_node(node.Node.from_fn(dummy_tagged_function))
    for expected_key in ("foo", "bar"):
        assert expected_key in tagged.tags
@pytest.mark.parametrize(
    "key",
    [
        "hamilton",  # Reserved key
        "foo@",  # Invalid identifier
        "foo bar",  # No spaces
        "foo.bar+baz",  # Invalid key, not a valid identifier
        # These two were previously written without a comma between them, so
        # implicit string concatenation collapsed them into the single key "...".
        "",  # Empty not allowed
        "...",  # Empty elements not allowed
    ],
)
def test_tags_invalid_key(key):
    """tag._key_allowed returns a falsy result for each malformed or reserved key."""
    assert not function_modifiers.tag._key_allowed(key)
@pytest.mark.parametrize(
    "key",
    [
        "bar.foo",  # Two dot-separated components
        "foo",  # Single component
        "foo.bar.baz",  # Three dot-separated components
    ],
)
def test_tags_valid_key(key):
    """tag._key_allowed returns a truthy result for each of these keys."""
    is_allowed = function_modifiers.tag._key_allowed(key)
    assert is_allowed
@pytest.mark.parametrize(
    "value",
    [None, False, [], ["foo", "bar"]],
)
def test_tags_invalid_value(value):
    """tag._value_allowed returns a falsy result for each of these values."""
    is_allowed = function_modifiers.tag._value_allowed(value)
    assert not is_allowed
def test_check_output_node_transform():
    """check_output turns one node into raw, per-validator and final nodes."""
    decorator = check_output(
        importance="warn",
        default_decorator_candidates=DUMMY_VALIDATORS_FOR_TESTING,
        dataset_length=1,
        dtype=np.int64,
    )

    def fn(input: pd.Series) -> pd.Series:
        return input

    original = node.Node.from_fn(fn)
    subdag = decorator.transform_node(original, config={}, fn=fn)
    assert len(subdag) == 4
    by_name = {member.name: member for member in subdag}
    assert sorted(by_name) == [
        "fn",
        "fn_dummy_data_validator_2",
        "fn_dummy_data_validator_3",
        "fn_raw",
    ]
    # TODO -- change when we change the naming scheme
    assert by_name["fn_raw"].input_types["input"][1] == DependencyType.REQUIRED
    final_node = by_name["fn"]
    # Three dependencies -- the two with DQ + the original
    assert len(final_node.input_types) == 3
    # The final function should take in everything but only use the raw results
    final_result = final_node.callable(
        fn_raw="test",
        fn_dummy_data_validator_2=ValidationResult(True, "", {}),
        fn_dummy_data_validator_3=ValidationResult(True, "", {}),
    )
    assert final_result == "test"
def test_check_output_custom_node_transform():
    """check_output_custom builds the same subdag shape and tags its validators."""
    decorator = check_output_custom(
        SampleDataValidator2(dataset_length=1, importance="warn"),
        SampleDataValidator3(dtype=np.int64, importance="warn"),
    )

    def fn(input: pd.Series) -> pd.Series:
        return input

    original = node.Node.from_fn(fn)
    subdag = decorator.transform_node(original, config={}, fn=fn)
    assert len(subdag) == 4
    by_name = {member.name: member for member in subdag}
    assert sorted(by_name) == [
        "fn",
        "fn_dummy_data_validator_2",
        "fn_dummy_data_validator_3",
        "fn_raw",
    ]
    # TODO -- change when we change the naming scheme
    assert by_name["fn_raw"].input_types["input"][1] == DependencyType.REQUIRED
    final_node = by_name["fn"]
    # Three dependencies -- the two with DQ + the original
    assert len(final_node.input_types) == 3

    data_validators = [
        candidate
        for candidate in by_name.values()
        if candidate.tags.get("hamilton.data_quality.contains_dq_results", False)
    ]
    assert len(data_validators) == 2  # One for each validator
    first_validator, _ = data_validators
    # Validates that all the required tags are included
    assert IS_DATA_VALIDATOR_TAG in first_validator.tags
    assert first_validator.tags[IS_DATA_VALIDATOR_TAG] is True
    assert DATA_VALIDATOR_ORIGINAL_OUTPUT_TAG in first_validator.tags
    assert first_validator.tags[DATA_VALIDATOR_ORIGINAL_OUTPUT_TAG] == "fn"

    # The final function should take in everything but only use the raw results
    final_result = final_node.callable(
        fn_raw="test",
        fn_dummy_data_validator_2=ValidationResult(True, "", {}),
        fn_dummy_data_validator_3=ValidationResult(True, "", {}),
    )
    assert final_result == "test"
def test_check_output_custom_node_transform_raises_exception_with_failure():
    """With importance="fail", failed validation results make the final node raise."""
    decorator = check_output_custom(
        SampleDataValidator2(dataset_length=1, importance="fail"),
        SampleDataValidator3(dtype=np.int64, importance="fail"),
    )

    def fn(input: pd.Series) -> pd.Series:
        return input

    original = node.Node.from_fn(fn)
    subdag = decorator.transform_node(original, config={}, fn=fn)
    assert len(subdag) == 4
    by_name = {member.name: member for member in subdag}
    final_node = by_name["fn"]
    with pytest.raises(DataValidationError):
        final_node.callable(
            fn_raw=pd.Series([1.0, 2.0, 3.0]),
            fn_dummy_data_validator_2=ValidationResult(False, "", {}),
            fn_dummy_data_validator_3=ValidationResult(False, "", {}),
        )
def test_check_output_custom_node_transform_layered():
    """Stacking two check_output_custom decorators yields five nodes in total."""
    length_check = check_output_custom(
        SampleDataValidator2(dataset_length=1, importance="warn"),
    )
    dtype_check = check_output_custom(SampleDataValidator3(dtype=np.int64, importance="warn"))

    def fn(input: pd.Series) -> pd.Series:
        return input

    original = node.Node.from_fn(fn)
    after_first = length_check.transform_dag([original], config={}, fn=fn)
    after_second = dtype_check.transform_dag(after_first, config={}, fn=fn)
    # One node for each dummy validator
    # One final node
    # One intermediate node for each of the functions (E.G. raw)
    # TODO -- ensure that the intermediate nodes don't share names
    assert len(after_second) == 5
def test_data_quality_constants_for_api_consistency():
    """Pins the data quality tag constants so their values stay stable."""
    # simple tests to test data quality constants remain the same
    expected_by_constant = [
        (IS_DATA_VALIDATOR_TAG, "hamilton.data_quality.contains_dq_results"),
        (DATA_VALIDATOR_ORIGINAL_OUTPUT_TAG, "hamilton.data_quality.source_node"),
    ]
    for actual, expected in expected_by_constant:
        assert actual == expected
def concat(upstream_parameter: str, literal_parameter: str) -> Any:
    """Concatenates {upstream_parameter} with literal_parameter"""
    # The docstring above is a runtime format template used by the
    # parameterize tests; its text must stay exactly as written.
    return "{}{}".format(upstream_parameter, literal_parameter)
def test_parametrized_full_no_replacement():
    """An empty replacement mapping leaves both parameters as required inputs."""
    decorator = function_modifiers.parameterize(replace_no_parameters={})
    [generated] = decorator.expand_node(node.Node.from_fn(concat), {}, concat)
    assert generated.callable(upstream_parameter="foo", literal_parameter="bar") == "foobar"
    expected_inputs = {
        "literal_parameter": (str, DependencyType.REQUIRED),
        "upstream_parameter": (str, DependencyType.REQUIRED),
    }
    assert generated.input_types == expected_inputs
    expected_doc = concat.__doc__.format(upstream_parameter="upstream_parameter")
    assert generated.documentation == expected_doc
def test_parametrized_full_replace_just_upstream():
    """source() swaps the upstream parameter for a differently named input."""
    decorator = function_modifiers.parameterize(
        replace_just_upstream_parameter={"upstream_parameter": source("foo_source")},
    )
    [generated] = decorator.expand_node(node.Node.from_fn(concat), {}, concat)
    expected_inputs = {
        "literal_parameter": (str, DependencyType.REQUIRED),
        "foo_source": (str, DependencyType.REQUIRED),
    }
    assert generated.input_types == expected_inputs
    assert generated.callable(foo_source="foo", literal_parameter="bar") == "foobar"
    expected_doc = concat.__doc__.format(upstream_parameter="foo_source")
    assert generated.documentation == expected_doc
def test_parametrized_full_replace_just_literal():
    """value() binds the literal parameter so it is no longer a node input."""
    decorator = function_modifiers.parameterize(
        replace_just_literal_parameter={"literal_parameter": value("bar")}
    )
    [generated] = decorator.expand_node(node.Node.from_fn(concat), {}, concat)
    expected_inputs = {"upstream_parameter": (str, DependencyType.REQUIRED)}
    assert generated.input_types == expected_inputs
    assert generated.callable(upstream_parameter="foo") == "foobar"
    expected_doc = concat.__doc__.format(upstream_parameter="upstream_parameter")
    assert generated.documentation == expected_doc
def test_parametrized_full_replace_both():
    """Replacing both parameters leaves only the renamed source as an input."""
    replacements = {
        "upstream_parameter": source("foo_source"),
        "literal_parameter": value("bar"),
    }
    decorator = function_modifiers.parameterize(replace_both_parameters=replacements)
    [generated] = decorator.expand_node(node.Node.from_fn(concat), {}, concat)
    assert generated.input_types == {"foo_source": (str, DependencyType.REQUIRED)}
    assert generated.callable(foo_source="foo") == "foobar"
    expected_doc = concat.__doc__.format(upstream_parameter="foo_source")
    assert generated.documentation == expected_doc
def test_parametrized_full_multiple_replacements():
    """Each (replacements, doc) entry yields one node carrying its own doc string."""
    # Every doc string is distinct, so a node given another entry's
    # documentation fails the comparison below.
    args = {
        "replace_no_parameters": ({}, "fn with no parameters replaced"),
        "replace_just_upstream_parameter": (
            {"upstream_parameter": source("foo_source")},
            "fn with upstream_parameter set to node foo",
        ),
        "replace_just_literal_parameter": (
            {"literal_parameter": value("bar")},
            "fn with literal_parameter set to bar",
        ),
        "replace_both_parameters": (
            {"upstream_parameter": source("foo_source"), "literal_parameter": value("bar")},
            "fn with both parameters replaced",
        ),
    }
    decorator = function_modifiers.parameterize(**args)
    expanded = decorator.expand_node(node.Node.from_fn(concat), {}, concat)
    assert len(expanded) == 4
    # test out that documentation is assigned correctly
    actual_docs = [generated.documentation for generated in expanded]
    expected_docs = [args[generated.name][1] for generated in expanded]
    assert actual_docs == expected_docs
@pytest.mark.parametrize(
    "upstream_source,expected",
    [
        ("foo", UpstreamDependency("foo")),
        (UpstreamDependency("bar"), UpstreamDependency("bar")),
    ],
)
def test_upstream(upstream_source, expected):
    """source() maps each input to the expected UpstreamDependency."""
    wrapped = source(upstream_source)
    assert wrapped == expected
@pytest.mark.parametrize(
    "literal_value,expected",
    [
        ("foo", LiteralDependency("foo")),
        (LiteralDependency("foo"), LiteralDependency("foo")),
        (1, LiteralDependency(1)),
    ],
)
def test_literal(literal_value, expected):
    """value() maps each input to the expected LiteralDependency."""
    wrapped = value(literal_value)
    assert wrapped == expected