blob: 40f89841389773865deb0881246255e43cb15aa1 [file] [log] [blame]
import collections
import random
import pytest
import tests.resources.reuse_subdag
from hamilton import ad_hoc_utils, graph
from hamilton.experimental.decorators import reuse
from hamilton.function_modifiers import base, config, value
from hamilton.function_modifiers.dependencies import source
def test_collect_function_fns():
    """Check that a plain function passed via ``load_from`` is collected and stays callable."""
    expected = random.randint(0, 100000)

    def test_fn(out: int = expected) -> int:
        return out

    collected = reuse.reuse_functions.collect_functions(load_from=[test_fn])
    # The first collected function must yield the same default-bound value as the original.
    assert collected[0]() == test_fn()
def test_collect_functions_module():
    """Check that a function wrapped in a temporary module is collected from that module."""
    expected = random.randint(0, 100000)

    def test_fn(out: int = expected) -> int:
        return out

    temporary_module = ad_hoc_utils.create_temporary_module(test_fn)
    collected = reuse.reuse_functions.collect_functions(load_from=[temporary_module])
    # The first collected function must yield the same default-bound value as the original.
    assert collected[0]() == test_fn()
def test_assign_namespaces():
    """Check that a namespaced node name is rendered as ``<namespace>.<node_name>``."""
    namespaced_name = reuse.assign_namespace(node_name="foo", namespace="bar")
    assert namespaced_name == "bar.foo"
def foo(a: int) -> int:
    """Identity function on ``a``; used as a subdag function in the tests of this module."""
    return a
@config.when_not(some_config_param=True)
def bar(b: int) -> int:
    """Identity function on ``b``; registered under ``config.when_not(some_config_param=True)``."""
    return b
@config.when(some_config_param=True)
def bar__alt() -> int:
    """Constant alternative to ``bar``; registered under ``config.when(some_config_param=True)``."""
    return 10
def test_reuse_subdag_validate_outputs_succeeds():
    """Validation passes when every mapped output name is declared in the ``MultiOutput``."""

    # NOTE(review): ``bar_result`` is declared as ``str`` although ``bar`` is annotated to
    # return ``int``; this test only requires that validation does not raise -- confirm
    # the type mismatch is intentional.
    def test() -> reuse.MultiOutput(foo_result=int, bar_result=str):
        pass

    subdag_decorator = reuse.reuse_functions(
        load_from=[foo, bar],
        namespace="baz",
        with_inputs={},
        with_config={},
        outputs={"foo": "foo_result", "bar": "bar_result"},
    )
    # Must not raise.
    subdag_decorator.validate(test)
def test_reuse_subdag_validate_output_incorrect_type():
    """Validation rejects a decorated function whose return annotation is a plain ``int``."""

    def test() -> int:
        pass

    subdag_decorator = reuse.reuse_functions(
        load_from=[foo, bar],
        namespace="baz",
        with_inputs={},
        with_config={},
        outputs={"foo": "foo_result", "bar": "bar_result"},
    )
    with pytest.raises(base.InvalidDecoratorException):
        subdag_decorator.validate(test)
def test_reuse_subdag_validate_output_fails_types_not_provided():
    """Validation rejects a ``MultiOutput`` that omits one of the mapped output names."""

    # ``bar_result`` is mapped in ``outputs`` below but has no type declared here.
    def test() -> reuse.MultiOutput(foo_result=int):
        pass

    subdag_decorator = reuse.reuse_functions(
        load_from=[foo, bar],
        namespace="baz",
        with_inputs={},
        with_config={},
        outputs={"foo": "foo_result", "bar": "bar_result"},
    )
    with pytest.raises(base.InvalidDecoratorException):
        subdag_decorator.validate(test)
def test_reuse_subdag_basic_no_parameterization():
    """Without inputs, node generation yields both the namespaced nodes and the outputs."""

    def test() -> reuse.MultiOutput(foo_result=int, bar_result=int):
        pass

    subdag_decorator = reuse.reuse_functions(
        load_from=[foo, bar],
        namespace="baz",
        with_inputs={},
        with_config={},
        outputs={"foo": "foo_result", "bar": "bar_result"},
    )
    generated = {}
    for generated_node in subdag_decorator.generate_nodes(test, {}):
        generated[generated_node.name] = generated_node
    # subdags have prefixed names
    for namespaced_name in ("baz.foo", "baz.bar"):
        assert namespaced_name in generated
    # but we expect our outputs to exist as well
    for output_name in ("bar_result", "foo_result"):
        assert output_name in generated
def test_reuse_subdag_basic_simple_parameterization():
    """Literal ``value`` inputs show up as namespaced nodes that return the literal."""

    def test() -> reuse.MultiOutput(foo_result=int, bar_result=int):
        pass

    subdag_decorator = reuse.reuse_functions(
        load_from=[foo, bar],
        namespace="baz",
        with_inputs={"a": value(1), "b": value(2)},
        with_config={},
        outputs={"foo": "foo_result", "bar": "bar_result"},
    )
    generated = {}
    for generated_node in subdag_decorator.generate_nodes(test, {}):
        generated[generated_node.name] = generated_node
    # This doesn't necessarily have to be part of the contract, but we're testing now to
    # ensure that it works
    for input_name, literal in (("baz.a", 1), ("baz.b", 2)):
        assert input_name in generated
        assert generated[input_name]() == literal
def test_reuse_subdag_basic_source_parameterization():
    """``source`` inputs show up as namespaced nodes that pass the upstream value through."""

    def test() -> reuse.MultiOutput(foo_result=int, bar_result=int):
        pass

    subdag_decorator = reuse.reuse_functions(
        load_from=[foo, bar],
        namespace="baz",
        with_inputs={"a": source("c"), "b": source("d")},
        with_config={},
        outputs={"foo": "foo_result", "bar": "bar_result"},
    )
    generated = {}
    for generated_node in subdag_decorator.generate_nodes(test, {}):
        generated[generated_node.name] = generated_node
    # These aren't entirely part of the contract, but they're required for the
    # way we're currently implementing it. See https://github.com/stitchfix/hamilton/issues/201
    assert "baz.a" in generated
    passthrough_a = generated["baz.a"]
    assert passthrough_a(c=1) == 1
    assert "baz.b" in generated
    passthrough_b = generated["baz.b"]
    assert passthrough_b(d=2) == 2
def test_reuse_subdag_handles_config_assignment():
    """With ``some_config_param=True`` in ``with_config``, ``baz.bar`` resolves to ``bar__alt``."""

    def test() -> reuse.MultiOutput(foo_result=int, bar_result=int):
        pass

    subdag_decorator = reuse.reuse_functions(
        load_from=[foo, bar, bar__alt],
        namespace="baz",
        with_inputs={"a": value(1)},
        with_config={"some_config_param": True},
        outputs={"foo": "foo_result", "bar": "bar_result"},
    )
    generated = {}
    for generated_node in subdag_decorator.generate_nodes(test, {}):
        generated[generated_node.name] = generated_node
    # 10 is the constant returned by ``bar__alt``; ``bar`` itself would need an input ``b``.
    selected_bar = generated["baz.bar"]
    assert selected_bar() == 10
def test_reuse_subdag_end_to_end():
    """Build the graph from ``tests.resources.reuse_subdag`` and check its shape and result.

    Verifies that the subdag outputs, the namespaced subdag nodes, and the shared
    top-level nodes are all present, that the un-namespaced subdag function names are
    absent, that static/source inputs are wired as expected, and that executing
    ``sum_everything`` yields the expected total.
    """
    fg = graph.FunctionGraph(tests.resources.reuse_subdag, config={"op": "subtract"})
    node_names = set(fg.nodes)

    # Group the local names of namespaced nodes by their namespace. ``partition`` splits
    # on the first "." only, so a name containing several dots cannot break this
    # bookkeeping with an unpacking error.
    names_by_namespace = collections.defaultdict(set)
    for node_name in node_names:
        namespace, separator, local_name = node_name.partition(".")
        if separator:
            names_by_namespace[namespace].add(local_name)

    # All the nodes outputted by our subdags
    subdag_outputs = {"e_1", "e_2", "e_3", "e_4", "f_1", "f_2", "f_3", "f_4"}
    assert subdag_outputs <= node_names

    # All these nodes must be in the DAG -- they're all the namespaced nodes
    for namespace in ("v1", "v2", "v3", "v4"):
        assert {"d", "e", "f"} <= names_by_namespace[namespace], namespace

    # common nodes shared by the DAG, not as part of subdags
    assert {"a", "b"} <= node_names

    # We've defined a node e, f, and d, but they're only namespaced in subdags
    assert node_names.isdisjoint({"e", "f", "d"})

    # These are all static values that are namespaced in the subDAGs
    for namespace in ("v1", "v2", "v3"):
        assert "c" in names_by_namespace[namespace], namespace

    # The following are all static values
    assert fg.nodes["v1.c"].callable() == 10
    assert fg.nodes["v2.c"].callable() == 20
    assert fg.nodes["v3.c"].callable() == 30

    # Source assigned this
    assert list(fg.nodes["v4.e"].input_types)[0] == "v4.c"
    assert list(fg.nodes["v4.c"].input_types)[0] == "b"
    assert fg.nodes["v4.c"].callable(b=1234) == 1234

    # Check that the config is assigned and overwritten correctly
    assert fg.nodes["v1.d"].callable(**{"v1.c": 10, "a": 100}) == 100 - 10
    assert fg.nodes["v3.d"].callable(**{"v3.c": 10, "a": 100}) == 100 + 10

    res = fg.execute(nodes=[fg.nodes["sum_everything"]])
    assert res["sum_everything"] == 318