import dataclasses
from collections import Counter
from typing import Any, Collection, Dict, List, Tuple, Type
import pandas as pd
import pytest
from hamilton import ad_hoc_utils, base, driver, graph, node
from hamilton.function_modifiers import base as fm_base
from hamilton.function_modifiers import save_to, source, value
from hamilton.function_modifiers.adapters import (
LoadFromDecorator,
SaveToDecorator,
load_from,
resolve_adapter_class,
resolve_kwargs,
)
from hamilton.function_modifiers.base import DefaultNodeCreator
from hamilton.io.data_adapters import DataLoader, DataSaver
from hamilton.registry import LOADER_REGISTRY
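# Tests for the load_from / save_to adapter decorators: kwarg resolution,
# adapter class resolution, node generation, and a few end-to-end runs.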
def test_default_adapters_are_available():
assert len(LOADER_REGISTRY) > 0
def test_default_adapters_are_registered_once():
assert "json" in LOADER_REGISTRY
count_unique = {
key: Counter([value.__class__.__qualname__ for value in values])
for key, values in LOADER_REGISTRY.items()
}
for key, value_ in count_unique.items():
for impl, count in value_.items():
assert count == 1, (
f"Adapter registered multiple times for {impl}. This should not"
f" happen, as items should just be registered once."
)
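# A mock loader with both required and default parameters, used to exercise
# kwarg resolution and node generation in LoadFromDecorator.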
@dataclasses.dataclass
class MockDataLoader(DataLoader):
required_param: int
required_param_2: int
required_param_3: str
default_param: int = 4
default_param_2: int = 5
default_param_3: str = "6"
@classmethod
def applicable_types(cls) -> Collection[Type]:
return [int]
def load_data(self, type_: Type[int]) -> Tuple[int, Dict[str, Any]]:
return ..., {"required_param": self.required_param, "default_param": self.default_param}
@classmethod
def name(cls) -> str:
return "mock"
def test_load_from_decorator():
def fn(data: int) -> int:
return data
decorator = LoadFromDecorator(
[MockDataLoader],
required_param=value("1"),
required_param_2=value("2"),
required_param_3=value("3"),
)
nodes_raw = DefaultNodeCreator().generate_nodes(fn, {})
nodes = decorator.transform_dag(nodes_raw, {}, fn)
assert len(nodes) == 3
nodes_by_name = {node_.name: node_ for node_ in nodes}
assert len(nodes_by_name) == 3
assert "fn" in nodes_by_name
assert nodes_by_name["fn.load_data.data"].tags == {
"hamilton.data_loader.source": "mock",
"hamilton.data_loader": True,
"hamilton.data_loader.has_metadata": True,
"hamilton.data_loader.node": "data",
"hamilton.data_loader.classname": MockDataLoader.__qualname__,
}
assert nodes_by_name["fn.select_data.data"].tags == {
"hamilton.data_loader.source": "mock",
"hamilton.data_loader": True,
"hamilton.data_loader.has_metadata": False,
"hamilton.data_loader.node": "data",
"hamilton.data_loader.classname": MockDataLoader.__qualname__,
}
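# resolve_kwargs should split source() references (upstream dependencies) from
# value() literals.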
def test_load_from_decorator_resolve_kwargs():
kwargs = dict(
required_param=source("1"),
required_param_2=value(2),
required_param_3=value("3"),
default_param=source("4"),
default_param_2=value(5),
)
dependency_kwargs, literal_kwargs = resolve_kwargs(kwargs)
assert dependency_kwargs == {"required_param": "1", "default_param": "4"}
assert literal_kwargs == {"required_param_2": 2, "required_param_3": "3", "default_param_2": 5}
def test_load_from_decorator_validate_succeeds():
decorator = LoadFromDecorator(
[MockDataLoader],
required_param=source("1"),
required_param_2=value(2),
required_param_3=value("3"),
)
def fn(injected_data: int) -> int:
return injected_data
decorator.validate(fn)
def test_load_from_decorator_validate_succeeds_with_inject():
decorator = LoadFromDecorator(
[MockDataLoader],
inject_="injected_data",
required_param=source("1"),
required_param_2=value(2),
required_param_3=value("3"),
)
def fn(injected_data: int, dependent_data: int) -> int:
return injected_data + dependent_data
decorator.validate(fn)
def test_load_from_decorator_validate_fails_dont_know_which_param_to_inject():
decorator = LoadFromDecorator(
[MockDataLoader],
required_param=source("1"),
required_param_2=value(2),
required_param_3=value("3"),
)
def fn(injected_data: int, other_possible_injected_data: int) -> int:
return injected_data + other_possible_injected_data
with pytest.raises(fm_base.InvalidDecoratorException):
decorator.validate(fn)
def test_load_from_decorator_validate_fails_inject_not_in_fn():
decorator = LoadFromDecorator(
[MockDataLoader],
inject_="injected_data",
required_param=source("1"),
required_param_2=value(2),
required_param_3=value("3"),
)
def fn(dependent_data: int) -> int:
return dependent_data
with pytest.raises(fm_base.InvalidDecoratorException):
decorator.validate(fn)
def test_load_from_decorator_validate_fails_inject_missing_param():
decorator = LoadFromDecorator(
[MockDataLoader],
required_param=source("1"),
required_param_2=value(2),
        # Intentionally left out so that a required parameter is missing
# required_param_3=value("3"),
)
def fn(data: int) -> int:
return data
with pytest.raises(fm_base.InvalidDecoratorException):
decorator.validate(fn)
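# Dummy loaders that all register under the name "dummy" but apply to different
# types; used to test how resolve_adapter_class picks a loader by type.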
@dataclasses.dataclass
class StringDataLoader(DataLoader):
def load_data(self, type_: Type) -> Tuple[str, Dict[str, Any]]:
return "foo", {"loader": "string_data_loader"}
@classmethod
def name(cls) -> str:
return "dummy"
@classmethod
def applicable_types(cls) -> Collection[Type]:
return [str]
@dataclasses.dataclass
class IntDataLoader(DataLoader):
def load_data(self, type_: Type) -> Tuple[int, Dict[str, Any]]:
return 1, {"loader": "int_data_loader"}
@classmethod
def name(cls) -> str:
return "dummy"
@classmethod
def applicable_types(cls) -> Collection[Type]:
return [int]
@dataclasses.dataclass
class IntDataLoader2(DataLoader):
@classmethod
def applicable_types(cls) -> Collection[Type]:
return [int]
def load_data(self, type_: Type) -> Tuple[int, Dict[str, Any]]:
return 2, {"loader": "int_data_loader_2"}
@classmethod
def name(cls) -> str:
return "dummy"
def test_validate_fails_incorrect_type():
decorator = LoadFromDecorator(
[StringDataLoader, IntDataLoader],
)
def fn_str_inject(injected_data: str) -> str:
return injected_data
def fn_int_inject(injected_data: int) -> int:
return injected_data
def fn_bool_inject(injected_data: bool) -> bool:
return injected_data
    # This is valid as there is one parameter and it's a type that the decorator supports
decorator.validate(fn_str_inject)
    # This is valid as there is one parameter and it's a type that the decorator supports
decorator.validate(fn_int_inject)
# This is invalid as there is one parameter and it is not a type that the decorator supports
with pytest.raises(fm_base.InvalidDecoratorException):
decorator.validate(fn_bool_inject)
def test_validate_selects_correct_type():
decorator = LoadFromDecorator(
[StringDataLoader, IntDataLoader],
)
def fn_str_inject(injected_data: str) -> str:
return injected_data
def fn_int_inject(injected_data: int) -> int:
return injected_data
def fn_bool_inject(injected_data: bool) -> bool:
return injected_data
    # This is valid as there is one parameter and it's a type that the decorator supports
decorator.validate(fn_str_inject)
    # This is valid as there is one parameter and it's a type that the decorator supports
decorator.validate(fn_int_inject)
# This is invalid as there is one parameter and it is not a type that the decorator supports
with pytest.raises(fm_base.InvalidDecoratorException):
decorator.validate(fn_bool_inject)
# Note that this tests an internal API, but we would like to test this to ensure
# class selection is correct
@pytest.mark.parametrize(
"type_,classes,correct_class",
[
(str, [StringDataLoader, IntDataLoader, IntDataLoader2], StringDataLoader),
(int, [StringDataLoader, IntDataLoader, IntDataLoader2], IntDataLoader2),
(int, [IntDataLoader2, IntDataLoader], IntDataLoader),
(int, [IntDataLoader, IntDataLoader2], IntDataLoader2),
(int, [StringDataLoader], None),
(str, [IntDataLoader], None),
(dict, [IntDataLoader], None),
],
)
def test_resolve_correct_loader_class(
type_: Type[Type], classes: List[Type[DataLoader]], correct_class: Type[DataLoader]
):
assert resolve_adapter_class(type_, classes) == correct_class
def test_decorator_validate():
decorator = LoadFromDecorator(
[StringDataLoader, IntDataLoader, IntDataLoader2],
)
def fn_str_inject(injected_data: str) -> str:
return injected_data
def fn_int_inject(injected_data: int) -> int:
return injected_data
def fn_bool_inject(injected_data: bool) -> bool:
return injected_data
    # This is valid as there is one parameter and it's a type that the decorator supports
decorator.validate(fn_str_inject)
decorator.validate(fn_int_inject)
# This is invalid as there is one parameter and it is not a type that the decorator supports
with pytest.raises(fm_base.InvalidDecoratorException):
decorator.validate(fn_bool_inject)
# End-to-end tests are probably cleanest here.
# We've done a bunch of tests of internal structures for other decorators,
# but that leaves the tests brittle.
# We don't test the driver; we just use the function_graph to test the nodes.
def test_load_from_decorator_end_to_end():
@LoadFromDecorator(
[StringDataLoader, IntDataLoader, IntDataLoader2],
)
def fn_str_inject(injected_data: str) -> str:
return injected_data
config = {}
adapter = base.DefaultAdapter()
fg = graph.FunctionGraph.from_modules(
ad_hoc_utils.create_temporary_module(fn_str_inject), config=config, adapter=adapter
)
result = fg.execute(inputs={}, nodes=fg.nodes.values())
assert result["fn_str_inject"] == "foo"
assert result["fn_str_inject.load_data.injected_data"] == (
"foo",
{"loader": "string_data_loader"},
)
# End-to-end tests are probably cleanest here.
# We've done a bunch of tests of internal structures for other decorators,
# but that leaves the tests brittle.
# We don't test the driver; we just use the function_graph to test the nodes.
def test_load_from_decorator_end_to_end_with_multiple():
@LoadFromDecorator(
[StringDataLoader, IntDataLoader, IntDataLoader2],
inject_="injected_data_1",
)
@LoadFromDecorator(
[StringDataLoader, IntDataLoader, IntDataLoader2],
inject_="injected_data_2",
)
def fn_str_inject(injected_data_1: str, injected_data_2: int) -> str:
return "".join([injected_data_1] * injected_data_2)
config = {}
adapter = base.DefaultAdapter()
fg = graph.FunctionGraph.from_modules(
ad_hoc_utils.create_temporary_module(fn_str_inject), config=config, adapter=adapter
)
result = fg.execute(inputs={}, nodes=fg.nodes.values())
assert result["fn_str_inject"] == "foofoo"
assert result["fn_str_inject.load_data.injected_data_1"] == (
"foo",
{"loader": "string_data_loader"},
)
assert result["fn_str_inject.load_data.injected_data_2"] == (
2,
{"loader": "int_data_loader_2"},
)
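# The JSON path can be supplied either as a literal value(...) or as a source(...)
# reference to an upstream node or input.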
@pytest.mark.parametrize(
"source_",
[
value("tests/resources/data/test_load_from_data.json"),
source("test_data"),
],
)
def test_load_from_decorator_json_file(source_):
@load_from.json(path=source_)
def raw_json_data(data: Dict[str, Any]) -> Dict[str, Any]:
return data
def number_employees(raw_json_data: Dict[str, Any]) -> int:
return len(raw_json_data["employees"])
def sum_age(raw_json_data: Dict[str, Any]) -> float:
return sum([employee["age"] for employee in raw_json_data["employees"]])
def mean_age(sum_age: float, number_employees: int) -> float:
return sum_age / number_employees
config = {}
dr = driver.Driver(
config,
ad_hoc_utils.create_temporary_module(raw_json_data, number_employees, sum_age, mean_age),
adapter=base.DefaultAdapter(),
)
result = dr.execute(
["mean_age"], inputs={"test_data": "tests/resources/data/test_load_from_data.json"}
)
assert result["mean_age"] - 32.33333 < 0.0001
def test_loader_fails_for_missing_attribute():
with pytest.raises(AttributeError):
load_from.not_a_loader(param=value("foo"))
def test_pandas_extensions_end_to_end(tmp_path_factory):
output_path = str(tmp_path_factory.mktemp("test_pandas_extensions_end_to_end") / "output.csv")
input_path = "tests/resources/data/test_load_from_data.csv"
@save_to.csv(path=source("output_path"), output_name_="save_df")
@load_from.csv(path=source("input_path"))
def df(data: pd.DataFrame) -> pd.DataFrame:
return data
config = {}
dr = driver.Driver(
config,
ad_hoc_utils.create_temporary_module(df),
adapter=base.DefaultAdapter(),
)
# run once to check that loading is correct
result = dr.execute(
["df", "save_df"],
inputs={"input_path": input_path, "output_path": output_path},
)
assert result["df"].shape == (3, 5)
assert result["df"].loc[0, "firstName"] == "John"
    # run again to check that saving was correct
result_just_read = dr.execute(
["df"],
inputs={"input_path": output_path},
)
# This is just reading the same file we wrote out, so it should be the same
pd.testing.assert_frame_equal(result["df"], result_just_read["df"])
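# A saver that records the value it receives in the provided sets, so tests can
# verify that the save node actually ran.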
@dataclasses.dataclass
class MarkingSaver(DataSaver):
markers: set
more_markers: set
def save_data(self, data: int) -> Dict[str, Any]:
self.markers.add(data)
self.more_markers.add(data)
return {}
@classmethod
def applicable_types(cls) -> Collection[Type]:
return [int]
@classmethod
def name(cls) -> str:
return "marker"
def test_save_to_decorator():
def fn() -> int:
return 1
marking_set = set()
marking_set_2 = set()
decorator = SaveToDecorator(
[MarkingSaver],
output_name_="save_fn",
markers=value(marking_set),
more_markers=source("more_markers"),
)
node_ = node.Node.from_fn(fn)
nodes = decorator.transform_node(node_, {}, fn)
assert len(nodes) == 2
nodes_by_name = {node_.name: node_ for node_ in nodes}
assert "save_fn" in nodes_by_name
assert "fn" in nodes_by_name
save_fn_node = nodes_by_name["save_fn"]
assert sorted(save_fn_node.input_types.keys()) == ["fn", "more_markers"]
assert save_fn_node(**{"fn": 1, "more_markers": marking_set_2}) == {}
assert save_fn_node.tags == {
"hamilton.data_saver": True,
"hamilton.data_saver.sink": "marker",
"hamilton.data_saver.classname": MarkingSaver.__qualname__,
}
# Check that the markers are updated, ensuring that the save_fn is called
assert marking_set_2 == {1}
assert marking_set == {1}
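# A loader whose only parameter has a default, used to check that the decorator
# works without any kwargs.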
@dataclasses.dataclass
class OptionalParamDataLoader(DataLoader):
param: int = 1
def load_data(self, type_: Type[int]) -> Tuple[int, Dict[str, Any]]:
return self.param, {}
@classmethod
def applicable_types(cls) -> Collection[Type]:
return [int]
@classmethod
def name(cls) -> str:
return "optional"
def test_adapters_optional_params():
@LoadFromDecorator([OptionalParamDataLoader])
def foo(param: int) -> int:
return param
fg = graph.create_function_graph(
ad_hoc_utils.create_temporary_module(foo),
config={},
adapter=base.DefaultAdapter(),
)
assert len(fg) == 3
assert "foo" in fg
def test_save_to_with_input_from_other_fn():
# This tests that we can refer to another node in save_to
def output_path() -> str:
return "output.json"
@save_to.json(path=source("output_path"), output_name_="save_fn")
def fn() -> dict:
return {"a": 1}
fg = graph.create_function_graph(
ad_hoc_utils.create_temporary_module(output_path, fn),
config={},
adapter=base.DefaultAdapter(),
)
assert len(fg) == 3
def test_load_from_with_input_from_other_fn():
# This tests that we can refer to another node in load_from
def input_path() -> str:
return "input.json"
@load_from.json(path=source("input_path"))
def fn(data: dict) -> dict:
return data
fg = graph.create_function_graph(
ad_hoc_utils.create_temporary_module(input_path, fn),
config={},
adapter=base.DefaultAdapter(),
)
assert len(fg) == 4
def test_load_from_with_multiple_inputs():
    # This tests that we can inject multiple loaded inputs by stacking load_from decorators
@load_from.json(
path=value("input_1.json"),
inject_="data1",
)
@load_from.json(
path=value("input_2.json"),
inject_="data2",
)
def fn(data1: dict, data2: dict) -> dict:
return {**data1, **data2}
fg = graph.create_function_graph(
ad_hoc_utils.create_temporary_module(fn),
config={},
adapter=base.DefaultAdapter(),
)
    # One loader node and one select node for each injected parameter, plus the transform function
assert len(fg) == 5