blob: cf7162f90e91ebd87f9c50e8c34d3f435a91d34a [file] [log] [blame]
import dataclasses
from typing import Any, Collection, Dict, List, Optional, Tuple, Type
import pytest
from hamilton import base, graph, node, registry
from hamilton.function_modifiers import load_from, save_to, value
from hamilton.io import materialization
from hamilton.io.data_adapters import DataLoader, DataSaver
from hamilton.io.materialization import (
Extract,
ExtractorFactory,
Materialize,
MaterializerFactory,
_ExtractorFactoryProtocol,
_MaterializerFactoryProtocol,
from_,
to,
)
from hamilton.lifecycle import base as lifecycle_base
import tests.resources.cyclic_functions
import tests.resources.test_default_args
# Shared in-memory sink written by MockDataSaver: storage key -> the data saved under it.
# Tests read it back to check a materializer's side effect.
global_mock_data_saver_cache: Dict[str, Any] = dict()
@dataclasses.dataclass
class MockDataSaver(DataSaver):
    """In-memory DataSaver used by the materialization tests.

    Rather than persisting anything, it records the data it receives in
    ``global_mock_data_saver_cache``. The data goes under ``storage_key`` and,
    when provided, under ``other_storage_key`` as well.
    """

    storage_key: str
    other_storage_key: Optional[str] = None

    def save_data(self, data: Any) -> Dict[str, Any]:
        """Store ``data`` under each configured key and report success."""
        target_keys = [self.storage_key]
        if self.other_storage_key is not None:
            target_keys.append(self.other_storage_key)
        for target_key in target_keys:
            global_mock_data_saver_cache[target_key] = data
        return {"saved": True}

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        """Only ``dict`` results can be saved by this mock."""
        return [dict]

    @classmethod
    def name(cls) -> str:
        """Name this saver is registered/looked up under."""
        return "mock_for_testing"
@dataclasses.dataclass
class MockDataLoader(DataLoader):
    """DataLoader stub that returns a pre-configured value.

    ``fixed_data`` is handed back verbatim from ``load_data`` together with an
    empty metadata dict. The requested type is not consulted.
    """

    fixed_data: Any

    def load_data(self, type_: Type[Type]) -> Tuple[Type, Dict[str, Any]]:
        """Return ``(fixed_data, {})``; ``type_`` is ignored."""
        metadata: Dict[str, Any] = {}
        return self.fixed_data, metadata

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        """Only ``dict`` values are advertised as loadable."""
        return [dict]

    @classmethod
    def name(cls) -> str:
        """Name this loader is registered/looked up under."""
        return "mock_for_testing"
class JoinBuilder(base.ResultMixin):
    """Result builder that shallow-merges several dict outputs into one dict.

    Outputs are merged in keyword-argument order, so on a key collision the
    later output's value wins.
    """

    @staticmethod
    def build_result(**outputs: Dict[str, Any]) -> Any:
        """Merge every keyword argument's mapping into a fresh dict."""
        merged = {}
        for partial in outputs.values():
            merged.update(partial)
        return merged

    def output_type(self) -> Type:
        """The merged result is always a ``dict``."""
        return dict

    def input_types(self) -> List[Type]:
        """Each input to be merged is expected to be a ``dict``."""
        return [dict]
def test_materialization_dynamic_property_access():
    """Looking up an adapter name on ``Materialize`` yields a materializer factory function."""
    materializer_factory = Materialize.json
    assert isinstance(materializer_factory, _MaterializerFactoryProtocol)
def test_extraction_dynamic_property_access():
    """Looking up an adapter name on ``Extract`` yields an extractor factory function."""
    extractor_factory = Extract.json
    assert isinstance(extractor_factory, _ExtractorFactoryProtocol)
def test_materializer_factory_generates_nodes_no_builder():
    """Without a result builder, the factory should emit exactly one saver node."""
    cache_key = "test_materializer_factory_generates_nodes_no_builder"
    factory = MaterializerFactory(
        "test_materializer",
        [MockDataSaver],
        dependencies=["only_node"],
        result_builder=None,
        storage_key=cache_key,
    )

    def only_node() -> dict:
        return {cache_key: "ran_correctly"}

    upstream = node.Node.from_fn(only_node)
    fn_graph = graph.FunctionGraph({upstream.name: upstream}, {})
    generated = factory.generate_nodes(fn_graph)
    # No builder was supplied, so no intermediate join node should appear.
    assert len(generated) == 1
    [saver_node] = generated

    # Invoke the node and verify both its return value and its write to the cache.
    outcome = saver_node(only_node=only_node())
    assert outcome == {"saved": True}
    assert global_mock_data_saver_cache[cache_key] == only_node()
def test_extractor_factory_generates_nodes():
    """The extractor factory should produce an ``input_data`` node typed from its consumer."""
    factory = ExtractorFactory(
        "input_data",
        loaders=[MockDataLoader],
        fixed_data=value({"loaded": True}),
    )

    def test(input_data: dict) -> dict:
        return {"loaded_value": input_data}

    consumer = node.Node.from_fn(test)
    adapter_set = lifecycle_base.LifecycleAdapterSet(base.DefaultAdapter())
    updated_nodes = graph.update_dependencies({consumer.name: consumer}, adapter_set)
    fn_graph = graph.FunctionGraph(updated_nodes, {})

    generated_by_name = {generated.name: generated for generated in factory.generate_nodes(fn_graph)}
    assert "input_data" in generated_by_name
    # Expected type matches the ``input_data: dict`` annotation on the consumer above.
    assert generated_by_name["input_data"].type == dict
def test_materializer_factory_generates_nodes_with_builder():
    """With a result builder, the factory emits a join node in addition to the saver node."""
    primary_key = "test_materializer_factory_generates_nodes_with_builder"
    secondary_key = "test_materializer_factory_generates_nodes_with_builder_second_store"
    factory = MaterializerFactory(
        "test_materializer",
        [MockDataSaver],
        dependencies=["first_node", "second_node"],
        result_builder=JoinBuilder(),
        storage_key=primary_key,
        other_storage_key=secondary_key,
    )

    def first_node() -> dict:
        return {"test_materializer_factory_generates_nodes_no_builder": "ran_correctly"}

    def second_node() -> dict:
        return {"test_materializer_factory_generates_nodes_no_builder_still": "ran_correctly"}

    fn_graph = graph.FunctionGraph(
        {
            "first_node": node.Node.from_fn(first_node),
            "second_node": node.Node.from_fn(second_node),
        },
        {},
        None,
    )
    generated = factory.generate_nodes(fn_graph)
    # Saver node plus one builder node.
    assert len(generated) == 2
    by_name = {generated_node.name: generated_node for generated_node in generated}

    # The saver node is named after the factory id; feed it the builder's output directly.
    saver = by_name.pop("test_materializer")
    outcome = saver(test_materializer_build_result={**first_node(), **second_node()})
    assert outcome == {"saved": True}
    # MockDataSaver writes the same data under both configured keys.
    assert global_mock_data_saver_cache[primary_key] == {**first_node(), **second_node()}
    assert global_mock_data_saver_cache[secondary_key] == {**first_node(), **second_node()}

    # The only remaining node is the join node, which merges the upstream dicts.
    (joiner,) = by_name.values()
    assert joiner(first_node=first_node(), second_node=second_node()) == {
        **first_node(),
        **second_node(),
    }
def test_modify_function_graph_materializers():
    """modify_graph should add one node per materializer while keeping the original nodes."""
    factory_1 = MaterializerFactory(
        "materializer_1",
        [MockDataSaver],
        dependencies=["first_node", "second_node"],
        result_builder=JoinBuilder(),
        # Storage-key suffix matches the factory number so the two caches are easy to tell apart.
        storage_key="test_modify_function_graph_1",
    )
    factory_2 = MaterializerFactory(
        "materializer_2",
        [MockDataSaver],
        dependencies=["first_node", "second_node"],
        result_builder=JoinBuilder(),
        storage_key="test_modify_function_graph_2",
    )

    def first_node() -> dict:
        return {"test_materializer_factory_generates_nodes_no_builder": "ran_correctly"}

    def second_node() -> dict:
        return {"test_materializer_factory_generates_nodes_no_builder_still": "ran_correctly"}

    base_node_0 = node.Node.from_fn(first_node)
    base_node_1 = node.Node.from_fn(second_node)
    fn_graph = graph.FunctionGraph(
        {"first_node": base_node_0, "second_node": base_node_1}, {}, None
    )
    # No extractor factories here: only materializers are added to the graph.
    fn_graph_modified = materialization.modify_graph(fn_graph, [factory_1, factory_2], [])
    assert "materializer_1" in fn_graph_modified.nodes
    assert "materializer_2" in fn_graph_modified.nodes
    assert "first_node" in fn_graph_modified.nodes
    assert "second_node" in fn_graph_modified.nodes
    # TODO -- add loaders in
def test_modify_function_graph_with_extractor_factories():
    """Extractor factories should inject loader nodes whose values flow into downstream nodes."""
    factory_1 = ExtractorFactory(
        "input_data_1", [MockDataLoader], fixed_data={"test_extractor_factory_1": "ran_correctly"}
    )
    factory_2 = ExtractorFactory(
        "input_data_2", [MockDataLoader], fixed_data={"test_extractor_factory_2": "ran_correctly"}
    )

    def first_node(input_data_1: dict) -> dict:
        return {"loaded_result": input_data_1}

    def second_node(input_data_2: dict) -> dict:
        return {"loaded_result": input_data_2}

    starting_nodes = {fn.__name__: node.Node.from_fn(fn) for fn in (first_node, second_node)}
    fn_graph = graph.FunctionGraph(starting_nodes, {}, None)
    # No materializers here: only the two extractor factories are applied.
    fn_graph_modified = materialization.modify_graph(fn_graph, [], [factory_1, factory_2])
    assert "input_data_1" in fn_graph_modified.nodes
    assert "input_data_2" in fn_graph_modified.nodes

    requested = [fn_graph_modified.nodes[name] for name in ("first_node", "second_node")]
    results = fn_graph_modified.execute(nodes=requested)
    # Loader outputs are present in the results and are passed through to their consumers.
    assert results["input_data_1"] == {"test_extractor_factory_1": "ran_correctly"}
    assert results["input_data_2"] == {"test_extractor_factory_2": "ran_correctly"}
    assert results["first_node"] == {"loaded_result": {"test_extractor_factory_1": "ran_correctly"}}
    assert results["second_node"] == {"loaded_result": {"test_extractor_factory_2": "ran_correctly"}}
def test_modify_function_graph_with_extractor_factories_override():
    """Tests that an extractor used as an override gets run, and the node it replaces does not."""
    factory = ExtractorFactory(
        "value_to_override", [MockDataLoader], fixed_data={"overwritten_result": True}
    )
    was_called = False

    def value_to_override() -> dict:
        nonlocal was_called
        was_called = True
        return {"overwritten_result": False}

    overridden = node.Node.from_fn(value_to_override)
    fn_graph = graph.FunctionGraph({overridden.name: overridden}, {}, None)
    fn_graph_modified = materialization.modify_graph(fn_graph, [], [factory])
    assert "value_to_override" in fn_graph_modified.nodes
    # The replacement is a loader node, so unlike the zero-argument original it has inputs.
    replacement = fn_graph_modified.nodes[overridden.name]
    assert len(replacement.input_types) > 0

    results = fn_graph_modified.execute(nodes=[fn_graph_modified.nodes["value_to_override"]])
    assert results["value_to_override"] == {"overwritten_result": True}
    # The user-defined function must never have been executed.
    assert was_called is False
def test_sanitize_materializer_dependencies_happy():
    """sanitize_dependencies returns a new factory with dependencies normalized to names.

    Function dependencies from the allowed module become their names, string
    dependencies pass through, and the remaining fields are carried over.
    """
    factory = MaterializerFactory(
        "materializer_1",
        [MockDataSaver],
        dependencies=[
            tests.resources.test_default_args.A,
            tests.resources.test_default_args.B,
            "C",
        ],
        result_builder=JoinBuilder(),
        storage_key="test_modify_function_graph_2",
    )
    allowed_modules = {tests.resources.test_default_args.__name__}
    sanitized = factory.sanitize_dependencies(allowed_modules)

    assert sanitized.id == factory.id
    assert sanitized.savers == factory.savers
    assert sanitized.result_builder == factory.result_builder
    assert sanitized.dependencies == ["A", "B", "C"]
    # A fresh object is returned; the input factory is not mutated in place.
    assert sanitized is not factory
def test_sanitize_materializer_dependencies_error():
    """Tests that sanitize_dependencies raises ValueError on a disallowed dependency.

    ``tests.resources.cyclic_functions.A`` belongs to a module that is not in the
    allowed set passed below. That mismatch is what is expected to trigger the error.
    """
    factory_1 = MaterializerFactory(
        "materializer_1",
        [MockDataSaver],
        dependencies=["B", tests.resources.cyclic_functions.A],
        result_builder=JoinBuilder(),
        storage_key="test_modify_function_graph_2",
    )
    allowed_modules = {tests.resources.test_default_args.__name__}
    # Only the call under test sits inside the raises-block, so a ValueError
    # from setup code cannot make this test pass spuriously.
    with pytest.raises(ValueError):
        factory_1.sanitize_dependencies(allowed_modules)
def test_dynamic_properties_can_be_registered_after_import_for_saver():
    """A saver registered after import should be reachable via Materialize, to, and save_to."""

    @dataclasses.dataclass
    class CustomDataSaver(DataSaver):
        def save_data(self, data: Any) -> Dict[str, Any]:
            # Matches the DataSaver contract (data in, metadata dict out).
            # This test never calls it; only registration and lookup are exercised.
            return {}

        @classmethod
        def applicable_types(cls) -> Collection[Type]:
            return [dict]

        @classmethod
        def name(cls) -> str:
            return "testing_unique_key_saver"

    registry.register_adapter(CustomDataSaver)
    materialize_property = Materialize.testing_unique_key_saver
    to_property = to.testing_unique_key_saver
    save_to_property = save_to.testing_unique_key_saver
    assert materialize_property is not None
    assert to_property is not None
    assert save_to_property is not None
def test_dynamic_properties_can_be_registered_after_import_for_loader():
    """A loader registered after import should be reachable via Extract, from_, and load_from."""

    @dataclasses.dataclass
    class CustomDataLoader(DataLoader):
        def load_data(self, type_: Type[int]) -> Tuple[Type[int], Dict[str, Any]]:
            return int, {}

        @classmethod
        def applicable_types(cls) -> Collection[Type]:
            return [int]

        @classmethod
        def name(cls) -> str:
            return "testing_unique_key_loader"

    registry.register_adapter(CustomDataLoader)
    # Look the newly registered name up on each accessor, then check every lookup succeeded.
    looked_up = [
        Extract.testing_unique_key_loader,
        from_.testing_unique_key_loader,
        load_from.testing_unique_key_loader,
    ]
    for accessor_result in looked_up:
        assert accessor_result is not None