blob: 8d60e50c6b3f5eefb03bdfa25e41342ff80e7c69 [file] [log] [blame]
import inspect
import pathlib
import uuid
from itertools import permutations
import pandas as pd
import pytest
import hamilton.graph_utils
import hamilton.htypes
from hamilton import ad_hoc_utils, base, graph, node
from hamilton.execution import graph_functions
from hamilton.function_modifiers import schema
from hamilton.lifecycle import base as lifecycle_base
from hamilton.node import NodeType
import tests.resources.bad_functions
import tests.resources.compatible_input_types
import tests.resources.config_modifier
import tests.resources.cyclic_functions
import tests.resources.dummy_functions
import tests.resources.extract_column_nodes
import tests.resources.extract_columns_execution_count
import tests.resources.functions_with_generics
import tests.resources.incompatible_input_types
import tests.resources.layered_decorators
import tests.resources.multiple_decorators_together
import tests.resources.optional_dependencies
import tests.resources.parametrized_inputs
import tests.resources.parametrized_nodes
import tests.resources.test_default_args
import tests.resources.typing_vs_not_typing
def test_find_functions():
    """Only the module's own public functions are discovered.

    Underscore-prefixed helpers and names brought in via imports must be filtered out.
    """
    module = tests.resources.dummy_functions
    found = hamilton.graph_utils.find_functions(module)
    wanted = [(name, getattr(module, name)) for name in ("A", "B", "C")]
    assert len(found) == len(wanted)
    assert found == wanted
def test_find_functions_from_temporary_function_module():
    """find_functions handles modules built by ad_hoc_utils.create_temporary_module."""
    source = tests.resources.dummy_functions
    originals = [source.A, source.B, source.C]
    expected = list(zip(("A", "B", "C"), originals))
    temp_module = ad_hoc_utils.create_temporary_module(*originals)
    actual = hamilton.graph_utils.find_functions(temp_module)
    assert len(actual) == len(expected)
    assert [name for name, _ in actual] == [name for name, _ in expected]
    # Comparing code objects is an easy way to say the functions are the same
    # without relying on identity of the function objects themselves.
    assert [fn.__code__ for _, fn in actual] == [fn.__code__ for _, fn in expected]
def test_add_dependency_missing_param_type():
    """Tests that constructing a Node errors when a parameter is missing its type annotation."""
    # Resolve the signature outside the raises block: inspect.signature can itself raise
    # ValueError, which would otherwise make this test pass for the wrong reason.
    a_sig = inspect.signature(tests.resources.bad_functions.A)
    with pytest.raises(ValueError):
        node.Node("A", a_sig.return_annotation, "A doc", tests.resources.bad_functions.A)
def test_add_dependency_missing_function_type():
    """Tests that constructing a Node errors when the function is missing its return type."""
    # Resolve the signature outside the raises block: inspect.signature can itself raise
    # ValueError, which would otherwise make this test pass for the wrong reason.
    b_sig = inspect.signature(tests.resources.bad_functions.B)
    with pytest.raises(ValueError):
        node.Node("B", b_sig.return_annotation, "B doc", tests.resources.bad_functions.B)
def test_add_dependency_strict_node_dependencies():
    """Wiring B onto an existing function node A records A as B's dependency.

    B depends on A, so A is depended on by B, while nothing depends on B.
    """
    funcs = tests.resources.dummy_functions
    b_signature = inspect.signature(funcs.B)
    b_node = node.Node("B", b_signature.return_annotation, "B doc", funcs.B)
    a_node = node.Node("A", inspect.signature(funcs.A).return_annotation, "A doc", funcs.A)
    nodes = {"A": a_node}
    graph.add_dependency(
        b_node,
        "B",
        nodes,
        "A",
        b_signature.parameters["A"].annotation,
        lifecycle_base.LifecycleAdapterSet(),
    )
    assert nodes["A"] == b_node.dependencies[0]
    assert b_node.depended_on_by == []
def test_add_dependency_input_nodes_mismatch_on_types():
    """Two functions requesting the same input with incompatible types must error."""
    module = tests.resources.incompatible_input_types
    param_name = "a"
    b_param_type = inspect.signature(module.b).parameters[param_name].annotation
    c_param_type = inspect.signature(module.c).parameters[param_name].annotation
    nodes = {name: node.Node.from_fn(getattr(module, name)) for name in ("b", "c")}
    for name in ("b", "c"):
        nodes[name]._originating_functions = (getattr(module, name),)

    # Wiring b first creates the shared input node "a".
    graph.add_dependency(
        nodes["b"], "b", nodes, param_name, b_param_type, lifecycle_base.LifecycleAdapterSet()
    )
    assert param_name in nodes

    # c requests "a" with a type that is incompatible with b's request.
    with pytest.raises(ValueError):
        graph.add_dependency(
            nodes["c"], "c", nodes, param_name, c_param_type, lifecycle_base.LifecycleAdapterSet()
        )
def test_add_dependency_input_nodes_mismatch_on_types_complex():
    """Tests a more complex scenario we don't support right now with input types.

    Both ``e`` and ``f`` request the input ``d``. Wiring ``e`` first creates ``d``;
    wiring ``f`` onto that same input must then raise because the annotations are
    not treated as compatible.
    """
    e_sig = inspect.signature(tests.resources.incompatible_input_types.e)
    f_sig = inspect.signature(tests.resources.incompatible_input_types.f)
    nodes = {
        "e": node.Node.from_fn(tests.resources.incompatible_input_types.e),
        "f": node.Node.from_fn(tests.resources.incompatible_input_types.f),
    }
    nodes["e"]._originating_functions = (tests.resources.incompatible_input_types.e,)
    nodes["f"]._originating_functions = (tests.resources.incompatible_input_types.f,)
    param_name = "d"
    # this adds 'd' to nodes
    graph.add_dependency(
        nodes["e"],
        "e",
        nodes,
        param_name,
        e_sig.parameters[param_name].annotation,
        lifecycle_base.LifecycleAdapterSet(),
    )
    assert "d" in nodes
    # adding the dependency of f on d should fail because the types are incompatible.
    # Note: this must wire f's node (not e's) so the test exercises the f -> d edge.
    with pytest.raises(ValueError):
        graph.add_dependency(
            nodes["f"],
            "f",
            nodes,
            param_name,
            f_sig.parameters[param_name].annotation,
            lifecycle_base.LifecycleAdapterSet(),
        )
def test_add_dependency_input_nodes_compatible_types():
    """Functions requesting the same input with compatible types are all accepted.

    The shared input node's type is narrowed to the tighter of the requested types.
    """
    module = tests.resources.compatible_input_types
    names = ("b", "c", "d")
    signatures = {name: inspect.signature(getattr(module, name)) for name in names}
    nodes = {name: node.Node.from_fn(getattr(module, name)) for name in names}
    for name in names:
        nodes[name]._originating_functions = (getattr(module, name),)
    param_name = "a"

    def _wire(name):
        # Attach the shared input `param_name` to the node called `name`.
        graph.add_dependency(
            nodes[name],
            name,
            nodes,
            param_name,
            signatures[name].parameters[param_name].annotation,
            lifecycle_base.LifecycleAdapterSet(),
        )

    _wire("b")  # creates the input node "a"
    assert param_name in nodes
    _wire("c")  # "a" is now shared by b and c
    # c's annotation is tighter, so the shared input's type shrinks to it.
    assert nodes[param_name].type == str
    _wire("d")
def test_add_dependency_input_nodes_compatible_types_order_check():
    """Tests that if functions request an input that we correctly accept compatible types independent of order.

    This just reorders test_add_dependency_input_nodes_compatible_types to ensure the outcome does not change.
    """
    b_sig = inspect.signature(tests.resources.compatible_input_types.b)
    c_sig = inspect.signature(tests.resources.compatible_input_types.c)
    d_sig = inspect.signature(tests.resources.compatible_input_types.d)
    nodes = {
        "b": node.Node.from_fn(tests.resources.compatible_input_types.b),
        "c": node.Node.from_fn(tests.resources.compatible_input_types.c),
        "d": node.Node.from_fn(tests.resources.compatible_input_types.d),
    }
    nodes["b"]._originating_functions = (tests.resources.compatible_input_types.b,)
    nodes["c"]._originating_functions = (tests.resources.compatible_input_types.c,)
    nodes["d"]._originating_functions = (tests.resources.compatible_input_types.d,)
    # what we want to add
    param_name = "a"
    # wiring 'c' first adds 'a' to nodes, already with the tighter type.
    graph.add_dependency(
        nodes["c"],
        "c",
        nodes,
        param_name,
        c_sig.parameters[param_name].annotation,
        lifecycle_base.LifecycleAdapterSet(),
    )
    assert "a" in nodes
    assert nodes["a"].type == str
    # this adds 'a' to 'b' as well.
    graph.add_dependency(
        nodes["b"],
        "b",
        nodes,
        param_name,
        b_sig.parameters[param_name].annotation,
        lifecycle_base.LifecycleAdapterSet(),
    )
    # test that the (already tighter) type didn't change
    assert nodes["a"].type == str
    graph.add_dependency(
        nodes["d"],
        "d",
        nodes,
        param_name,
        d_sig.parameters[param_name].annotation,
        lifecycle_base.LifecycleAdapterSet(),
    )
def test_typing_to_primitive_conversion():
    """A producer annotated with a typing type can feed a consumer using a primitive type."""
    mod = tests.resources.typing_vs_not_typing
    consumer_sig = inspect.signature(mod.B)
    consumer = node.Node("B", consumer_sig.return_annotation, "B doc", mod.B)
    producer = node.Node("A", inspect.signature(mod.A).return_annotation, "A doc", mod.A)
    nodes = {"A": producer}
    graph.add_dependency(
        consumer,
        "B",
        nodes,
        "A",
        consumer_sig.parameters["A"].annotation,
        lifecycle_base.LifecycleAdapterSet(),
    )
    assert nodes["A"] == consumer.dependencies[0]
    assert consumer.depended_on_by == []
def test_primitive_to_typing_conversion():
    """A producer annotated with a primitive type can feed a consumer using a typing type."""
    mod = tests.resources.typing_vs_not_typing
    consumer_sig = inspect.signature(mod.B2)
    consumer = node.Node("B2", consumer_sig.return_annotation, "B2 doc", mod.B2)
    producer = node.Node("A2", inspect.signature(mod.A2).return_annotation, "A2 doc", mod.A2)
    nodes = {"A2": producer}
    graph.add_dependency(
        consumer,
        "B2",
        nodes,
        "A2",
        consumer_sig.parameters["A2"].annotation,
        lifecycle_base.LifecycleAdapterSet(),
    )
    assert nodes["A2"] == consumer.dependencies[0]
    assert consumer.depended_on_by == []
def test_throwing_error_on_incompatible_types():
    """add_dependency rejects a producer whose output type doesn't match the consumer's parameter."""
    bad = tests.resources.bad_functions
    consumer_sig = inspect.signature(bad.D)
    consumer = node.Node("D", consumer_sig.return_annotation, "D doc", bad.D)
    producer = node.Node("C", inspect.signature(bad.C).return_annotation, "C doc", bad.C)
    nodes = {"C": producer}
    requested_type = consumer_sig.parameters["C"].annotation
    with pytest.raises(ValueError):
        graph.add_dependency(
            consumer,
            "D",
            nodes,
            "C",
            requested_type,
            lifecycle_base.LifecycleAdapterSet(),
        )
def test_add_dependency_user_nodes():
    """An unknown parameter becomes a new user-provided node linked in both directions.

    A depends on b and c, but only b is wired here. So A should have 'b' as its dependency,
    and 'b' should be depended on by A.
    """
    funcs = tests.resources.dummy_functions
    a_signature = inspect.signature(funcs.A)
    a_node = node.Node("A", a_signature.return_annotation, "A doc", funcs.A)
    nodes = {}
    graph.add_dependency(
        a_node,
        "A",
        nodes,
        "b",
        a_signature.parameters["b"].annotation,
        lifecycle_base.LifecycleAdapterSet(),
    )
    # add_dependency created the user node "b" and registered it in `nodes`.
    created = nodes["b"]
    assert created == a_node.dependencies[0]
    assert created.depended_on_by[0] == a_node
    assert a_node.depended_on_by == []
def create_testing_nodes():
    """Hand-build the node graph represented by dummy_functions.py.

    Returns:
        dict mapping node name to ``node.Node``: function nodes A, B, C and external
        inputs b, c, with dependency / depended_on_by links wired A <- (b, c) and
        A -> (B, C).
    """
    funcs = tests.resources.dummy_functions
    a_params = inspect.signature(funcs.A).parameters

    def _function_node(name, doc):
        # Each node gets its own fresh tags dict, as in a literal per node.
        fn = getattr(funcs, name)
        return node.Node(
            name,
            inspect.signature(fn).return_annotation,
            doc,
            fn,
            tags={"module": "tests.resources.dummy_functions"},
        )

    def _external_node(name):
        # External inputs are typed by A's corresponding parameter annotation.
        return node.Node(name, a_params[name].annotation, node_source=NodeType.EXTERNAL)

    nodes = {
        "A": _function_node("A", "Function that should become part of the graph - A"),
        "B": _function_node("B", "Function that should become part of the graph - B"),
        "C": _function_node("C", ""),
        "b": _external_node("b"),
        "c": _external_node("c"),
    }
    hub = nodes["A"]
    for upstream_name in ("b", "c"):
        hub.dependencies.append(nodes[upstream_name])
    for downstream_name in ("B", "C"):
        hub.depended_on_by.append(nodes[downstream_name])
    for upstream_name in ("b", "c"):
        nodes[upstream_name].depended_on_by.append(hub)
    for downstream_name in ("B", "C"):
        nodes[downstream_name].dependencies.append(hub)
    return nodes
def test_create_function_graph_simple():
    """create_function_graph on dummy_functions reproduces the hand-built node graph."""
    hand_built = create_testing_nodes()
    built = graph.create_function_graph(tests.resources.dummy_functions, config={})
    assert built == hand_built
def test_execute():
    """Execute the hand-built graph, including basic memoization since A feeds both B and C.

    Also checks that an override for A is honoured.
    """
    graph_nodes = create_testing_nodes().values()
    inputs = {"b": 2, "c": 5}
    outcome = graph_functions.execute_subdag(nodes=graph_nodes, inputs=inputs)
    assert outcome == {"A": 7, "B": 49, "C": 14, "b": 2, "c": 5}
    overridden = graph_functions.execute_subdag(
        nodes=graph_nodes, inputs=inputs, overrides={"A": 8}
    )
    assert overridden["A"] == 8
def test_get_required_functions():
    """Upstream traversal from A and B pulls in their inputs but skips the sibling C."""
    hand_built = create_testing_nodes()
    user_inputs = {hand_built["b"], hand_built["c"]}
    wanted = {hand_built["A"], hand_built["B"]} | user_inputs  # 'C' is not required
    fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={})
    upstream, upstream_user_nodes = fg.get_upstream_nodes(["A", "B"])
    assert upstream == wanted
    assert upstream_user_nodes == user_inputs
def test_get_downstream_nodes():
    """Exercises getting the downstream subset of the graph for computation on the toy example we have constructed.

    Changing A affects A itself plus everything that consumes it (B and C); the
    external inputs b and c are upstream and must not be included.
    """
    nodes = create_testing_nodes()
    var_changes = ["A"]
    expected_nodes = {nodes["B"], nodes["C"], nodes["A"]}
    fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={})
    actual_nodes = fg.get_downstream_nodes(var_changes)
    assert actual_nodes == expected_nodes
def test_function_graph_from_multiple_sources():
    """Building from several modules yields the union of their nodes."""
    sources = (tests.resources.dummy_functions, tests.resources.parametrized_nodes)
    fg = graph.FunctionGraph.from_modules(*sources, config={})
    assert len(fg.get_nodes()) == 8
def test_end_to_end_with_parametrized_nodes():
    """A graph built from parametrized nodes executes end-to-end."""
    fg = graph.FunctionGraph.from_modules(tests.resources.parametrized_nodes, config={})
    expected = {f"parametrized_{i}": i for i in (1, 2, 3)}
    assert fg.execute(fg.get_nodes(), {}) == expected
def test_end_to_end_with_parametrized_inputs():
    """Nodes parametrized over inputs compute correct sums, including a config-provided value."""
    static = 3
    fg = graph.FunctionGraph.from_modules(
        tests.resources.parametrized_inputs, config={"static_value": static}
    )
    outputs = fg.execute(fg.get_nodes())
    one, two, three = 1, 2, 3
    assert outputs == {
        "input_1": one,
        "input_2": two,
        "input_3": three,
        "output_1": one + static,
        "output_2": two + static,
        "output_12": one + two + static,
        "output_123": one + two + three + static,
        "static_value": static,
    }
def test_get_required_functions_askfor_config():
    """Tests that a config value can be requested like any other variable.

    Asking for the config key "a" (alongside a parametrized node) surfaces it as the
    only user node upstream, and executing that user node returns the config value.
    """
    fg = graph.FunctionGraph.from_modules(tests.resources.parametrized_nodes, config={"a": 1})
    nodes, user_nodes = fg.get_upstream_nodes(["a", "parametrized_1"])
    (n,) = user_nodes
    assert n.name == "a"
    results = fg.execute(user_nodes)
    assert results == {"a": 1}
def test_end_to_end_with_column_extractor_nodes():
    """A graph whose nodes extract DataFrame columns executes end-to-end."""
    resource = tests.resources.extract_column_nodes
    fg = graph.FunctionGraph.from_modules(resource, config={})
    graph_nodes = fg.get_nodes()
    outputs = fg.execute(graph_nodes, {}, {})
    expected_df = resource.generate_df()
    for column in ("col_1", "col_2"):
        pd.testing.assert_series_equal(outputs[column], expected_df[column])
    pd.testing.assert_frame_equal(outputs["generate_df"], expected_df)
    expected_doc = "Function that should be parametrized to form multiple functions"
    assert graph_nodes[0].documentation == expected_doc
def test_end_to_end_with_multiple_decorators():
    """Tests that a simple function graph with multiple decorators on a function works end-to-end"""
    fg = graph.FunctionGraph.from_modules(
        tests.resources.multiple_decorators_together,
        config={"param0": 3, "param1": 1, "in_value1": 42, "in_value2": "string_value"},
    )
    nodes = fg.get_nodes()
    # To debug failures here, render these nodes with fg.display(...) and inspect the DAG.
    results = fg.execute(nodes, {}, {})
    df_expected = tests.resources.multiple_decorators_together._sum_multiply(3, 1, 2)
    dict_expected = tests.resources.multiple_decorators_together._sum(3, 1, 2)
    pd.testing.assert_series_equal(results["param1b"], df_expected["param1b"])
    pd.testing.assert_frame_equal(results["to_modify"], df_expected)
    assert results["total"] == dict_expected["total"]
    assert results["to_modify_2"] == dict_expected
    node_dict = {n.name: n for n in nodes}
    assert (
        node_dict["to_modify"].documentation
        == "This is a dummy function showing extract_columns with does."
    )
    assert (
        node_dict["to_modify_2"].documentation
        == "This is a dummy function showing extract_fields with does."
    )
    # tag only applies right now to outer most node layer
    assert node_dict["uber_decorated_function"].tags == {
        "module": "tests.resources.multiple_decorators_together"
    }  # tags are not propagated
    assert node_dict["out_value1"].tags == {
        "module": "tests.resources.multiple_decorators_together",
        "test_key": "test-value",
    }
    assert node_dict["out_value2"].tags == {
        "module": "tests.resources.multiple_decorators_together",
        "test_key": "test-value",
    }
def test_end_to_end_with_config_modifier():
    """The config key fn_1_version selects which implementation backs node ``fn``."""
    for version in (1, 2, 3):
        fg = graph.FunctionGraph.from_modules(
            tests.resources.config_modifier, config={"fn_1_version": version}
        )
        outputs = fg.execute(fg.get_nodes(), {}, {})
        assert outputs["fn"] == f"version_{version}"
def test_non_required_nodes():
    """Nodes with optional inputs run without them; supplying them changes the results."""
    module = tests.resources.test_default_args

    def _standard_nodes(fg, skip=()):
        # Only STANDARD-role nodes, minus any explicitly skipped names.
        return [
            n for n in fg.get_nodes() if n.node_role == NodeType.STANDARD and n.name not in skip
        ]

    fg = graph.FunctionGraph.from_modules(module, config={"required": 10})
    # D is not on the execution path, so it should not break things
    outputs = fg.execute(_standard_nodes(fg, skip=("D",)), {}, {})
    assert outputs["A"] == 10

    fg = graph.FunctionGraph.from_modules(module, config={"required": 10, "defaults_to_zero": 1})
    outputs = fg.execute(_standard_nodes(fg), {}, {})
    assert outputs["A"] == 11
    assert outputs["D"] == 2
def test_config_can_override():
    """A config entry is exposed as an executable value in the results."""
    fg = graph.FunctionGraph.from_modules(
        tests.resources.config_modifier, config={"new_param": "new_value"}
    )
    outputs = fg.execute(list(fg.get_nodes()))
    assert outputs["new_param"] == "new_value"
def test_function_graph_has_cycles_true():
    """A cyclic module is reported as cyclic, and executing into the cycle raises RecursionError."""
    fg = graph.FunctionGraph.from_modules(tests.resources.cyclic_functions, config={"b": 2, "c": 1})
    every_node = fg.get_nodes()
    nodes = [n for n in every_node if not n.user_defined]
    user_nodes = [n for n in every_node if n.user_defined]
    assert fg.has_cycles(nodes, user_nodes) is True

    required_nodes, required_user_nodes = fg.get_upstream_nodes(["A", "B", "C"])
    assert required_nodes == set(nodes + user_nodes)
    assert required_user_nodes == set(user_nodes)

    # Not officially supported, but overriding nodes inside the cycle
    # (e.g. overrides={"A": 1, "D": 2}) currently lets B compute (B == 3).
    # Without such a short circuit, execution recurses until RecursionError.
    with pytest.raises(RecursionError):
        fg.execute([n for n in nodes if n.name == "B"])
def test_function_graph_has_cycles_false():
    """Tests that an acyclic graph is reported as having no cycles.

    Checked two ways: on all nodes split by user_defined, and on the upstream subset the
    driver would compute.
    """
    fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
    all_nodes = fg.get_nodes()
    nodes = [n for n in all_nodes if not n.user_defined]
    user_nodes = [n for n in all_nodes if n.user_defined]
    assert fg.has_cycles(nodes, user_nodes) is False
    # this is called by the driver
    nodes, user_nodes = fg.get_upstream_nodes(["A", "B", "C"])
    assert fg.has_cycles(nodes, user_nodes) is False
def test_function_graph_display(tmp_path: pathlib.Path):
    """Tests that display saves a DOT file containing the expected statements."""
    dot_file_path = tmp_path / "dag"
    fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
    node_modifiers = {"B": {graph.VisualizationNodeModifiers.IS_OUTPUT}}
    all_nodes = set()
    for n in fg.get_nodes():
        if n.user_defined:
            node_modifiers[n.name] = {graph.VisualizationNodeModifiers.IS_USER_INPUT}
        all_nodes.add(n)
    # hack of a test -- but it works... compare the lines as a set.
    # why? because for some reason given the same graph, the output file isn't deterministic.
    # For the same reason the order of input nodes is non-deterministic, so the
    # `_A_inputs [label=<<table ...>>]` line is left out and allowed as the single extra line.
    expected_set = {
        '\t\tfunction [fillcolor="#b4d8e4" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]\n',
        "\t\tgraph [fontname=helvetica label=Legend rank=same]\n",
        "\t\tinput [fontname=Helvetica margin=0.15 shape=rectangle style=dashed]\n",
        '\t\toutput [fillcolor="#FFC857" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]\n',
        "\tA -> B\n",
        "\tA -> C\n",
        '\tA [label=<<b>A</b><br /><br /><i>int</i>> fillcolor="#b4d8e4" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]\n',
        '\tB [label=<<b>B</b><br /><br /><i>int</i>> fillcolor="#FFC857" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]\n',
        '\tC [label=<<b>C</b><br /><br /><i>int</i>> fillcolor="#b4d8e4" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]\n',
        "\t_A_inputs -> A\n",
        "\tgraph [compound=true concentrate=true rankdir=LR ranksep=0.4]\n",
        "\tsubgraph cluster__legend {\n",
        "\t}\n",
        "// Dependency Graph\n",
        "digraph {\n",
        "}\n",
    }
    fg.display(
        all_nodes,
        output_file_path=str(dot_file_path),
        render_kwargs={"view": False},
        node_modifiers=node_modifiers,
    )
    # Close the file deterministically instead of leaking the handle.
    with dot_file_path.open("r") as f:
        dot_set = set(f.readlines())
    # Separate asserts so a failure shows which condition broke.
    assert dot_set.issuperset(expected_set)
    assert len(dot_set.difference(expected_set)) == 1
def test_function_graph_display_no_dot_output(tmp_path: pathlib.Path):
    """display with output_file_path=None must not produce a DOT file."""
    would_be_path = tmp_path / "dag"
    fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
    every_node = set(fg.get_nodes())
    fg.display(every_node, output_file_path=None)
    assert not would_be_path.exists()
@pytest.mark.parametrize("show_legend", [True, False])
def test_function_graph_display_legend(show_legend: bool, tmp_path: pathlib.Path):
    """The legend cluster is rendered exactly when show_legend is True."""
    dot_file_path = tmp_path / "dag"
    fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
    fg.display(
        set(fg.get_nodes()),
        output_file_path=str(dot_file_path),
        render_kwargs={"view": False},
        show_legend=show_legend,
    )
    # Close the file deterministically instead of leaking the handle.
    with dot_file_path.open("r") as f:
        dot = f.read()
    found_legend = "cluster__legend" in dot
    assert found_legend is show_legend
@pytest.mark.parametrize("orient", ["LR", "TB", "RL", "BT"])
def test_function_graph_display_orient(orient: str, tmp_path: pathlib.Path):
    """The requested orientation is emitted as the graph's rankdir."""
    dot_file_path = tmp_path / "dag"
    fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
    fg.display(
        set(fg.get_nodes()),
        output_file_path=str(dot_file_path),
        render_kwargs={"view": False},
        orient=orient,
    )
    # Close the file deterministically instead of leaking the handle.
    with dot_file_path.open("r") as f:
        dot = f.read()
    # this could break if a rankdir is given to the legend subgraph
    assert f"rankdir={orient}" in dot
@pytest.mark.parametrize("hide_inputs", [True, False])
def test_function_graph_display_inputs(hide_inputs: bool, tmp_path: pathlib.Path):
    """Input nodes (DOT lines starting with a tab and underscore) appear only when not hidden.

    The parametrize values must be plain booleans. With a single argname, pytest passes
    each list entry as-is, so ``(True,)`` / ``(False,)`` would both arrive as truthy
    tuples, and ``found_input is not hide_inputs`` would hold trivially.
    """
    dot_file_path = tmp_path / "dag"
    fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
    fg.display(
        set(fg.get_nodes()),
        output_file_path=str(dot_file_path),
        render_kwargs={"view": False},
        hide_inputs=hide_inputs,
    )
    # Close the file deterministically instead of leaking the handle.
    with dot_file_path.open("r") as f:
        dot_lines = f.readlines()
    found_input = any(line.startswith("\t_") for line in dot_lines)
    assert found_input is not hide_inputs
def test_function_graph_display_without_saving():
    """display returns a graphviz.Digraph even when no output path is given."""
    import graphviz

    fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
    modifiers = {"B": {graph.VisualizationNodeModifiers.IS_OUTPUT}}
    every_node = set()
    for n in fg.get_nodes():
        every_node.add(n)
        if n.user_defined:
            modifiers[n.name] = {graph.VisualizationNodeModifiers.IS_USER_INPUT}
    rendered = fg.display(every_node, output_file_path=None, node_modifiers=modifiers)
    assert rendered is not None
    assert isinstance(rendered, graphviz.Digraph)
@pytest.mark.parametrize("display_fields", [True, False])
def test_function_graph_display_fields(display_fields: bool, tmp_path: pathlib.Path):
    """Schema fields (and their cluster) are rendered only when display_fields is True.

    The parametrize values must be plain booleans. With a single argname, pytest passes
    each list entry as-is, so ``(False,)`` would be a truthy tuple and the False branch
    would never run.
    """
    dot_file_path = tmp_path / "dag"

    @schema.output(("foo", "int"), ("bar", "float"), ("baz", "str"))
    def df_with_schema() -> pd.DataFrame:
        pass

    mod = ad_hoc_utils.create_temporary_module(df_with_schema)
    fg = graph.FunctionGraph.from_modules(mod, config={})
    fg.display(
        set(fg.get_nodes()),
        output_file_path=str(dot_file_path),
        render_kwargs={"view": False},
        display_fields=display_fields,
    )
    # Close the file deterministically instead of leaking the handle.
    with dot_file_path.open("r") as f:
        dot_lines = f.readlines()
    if display_fields:
        assert any("foo" in line for line in dot_lines)
        assert any("bar" in line for line in dot_lines)
        assert any("baz" in line for line in dot_lines)
        assert any("cluster" in line for line in dot_lines)
    else:
        assert not any("foo" in line for line in dot_lines)
        assert not any("bar" in line for line in dot_lines)
        assert not any("baz" in line for line in dot_lines)
        # The legend is always emitted as `subgraph cluster__legend` (shown by default),
        # so ignore it; no *other* cluster may appear without field display.
        assert not any(
            "cluster" in line and "cluster__legend" not in line for line in dot_lines
        )
def test_function_graph_display_fields_shared_schema(tmp_path: pathlib.Path):
    """Fields are rendered for every node even when two nodes share an identical schema.

    This guards an edge case where duplicate fields ended up getting dropped.
    """
    dot_file_path = tmp_path / "dag"
    SCHEMA = (("foo", "int"), ("bar", "float"), ("baz", "str"))

    @schema.output(*SCHEMA)
    def df_1_with_schema() -> pd.DataFrame:
        pass

    @schema.output(*SCHEMA)
    def df_2_with_schema() -> pd.DataFrame:
        pass

    mod = ad_hoc_utils.create_temporary_module(df_1_with_schema, df_2_with_schema)
    fg = graph.FunctionGraph.from_modules(mod, config={})
    fg.display(
        set(fg.get_nodes()),
        output_file_path=str(dot_file_path),
        render_kwargs={"view": False},
        display_fields=True,
    )
    # Close the file deterministically instead of leaking the handle.
    with dot_file_path.open("r") as f:
        dot_lines = f.readlines()

    def _get_occurrences(var: str):
        return [item for item in dot_lines if var in item]

    # We just need to make sure each field shows up once per node, i.e. twice
    assert len(_get_occurrences("foo=")) == 2
    assert len(_get_occurrences("bar=")) == 2
    assert len(_get_occurrences("baz=")) == 2
def test_create_graphviz_graph():
    """create_graphviz_graph emits the expected DOT statements for dummy_functions.

    The comparison is set-based because, for the same graph, statement order (and the
    order of inputs inside the HTML table) is not deterministic.
    """
    fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={})
    upstream, upstream_user_nodes = fg.get_upstream_nodes(["A", "B", "C"])
    every_node = upstream.union(upstream_user_nodes)
    modifiers = {
        "b": {graph.VisualizationNodeModifiers.IS_USER_INPUT},
        "c": {graph.VisualizationNodeModifiers.IS_USER_INPUT},
        "B": {graph.VisualizationNodeModifiers.IS_OUTPUT},
    }
    expected_lines = {
        "// Dependency Graph",
        "",
        "digraph {",
        "\tgraph [compound=true concentrate=true rankdir=LR ranksep=0.4 ratio=1]",
        '\tB [label=<<b>B</b><br /><br /><i>int</i>> fillcolor="#FFC857" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]',
        '\tC [label=<<b>C</b><br /><br /><i>int</i>> fillcolor="#b4d8e4" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]',
        '\tA [label=<<b>A</b><br /><br /><i>int</i>> fillcolor="#b4d8e4" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]',
        "\tA -> B",
        "\tA -> C",
        '\t_A_inputs [label=<<table border="0"><tr><td>b</td><td>int</td></tr><tr><td>b</td><td>int</td></tr></table>> fontname=Helvetica margin=0.15 shape=rectangle style=dashed]',
        "\t_A_inputs -> A",
        "\tsubgraph cluster__legend {",
        "\t\tgraph [fontname=helvetica label=Legend rank=same]",
        "\t\tinput [fontname=Helvetica margin=0.15 shape=rectangle style=dashed]",
        '\t\tfunction [fillcolor="#b4d8e4" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]',
        '\t\toutput [fillcolor="#FFC857" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]',
        "\t}",
        "}",
    }
    rendered = graph.create_graphviz_graph(
        every_node,
        "Dependency Graph\n",
        graphviz_kwargs={"graph_attr": {"ratio": "1"}},
        node_modifiers=modifiers,
        strictly_display_only_nodes_passed_in=False,
    )
    # The HTML input table isn't ordered deterministically; collapse "c" onto "b".
    normalized = str(rendered).replace("<td>c</td>", "<td>b</td>")
    assert set(normalized.split("\n")) == expected_lines
def test_create_networkx_graph():
    """create_networkx_graph mirrors dummy_functions' nodes and edges."""
    fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={})
    upstream, upstream_user_nodes = fg.get_upstream_nodes(["A", "B", "C"])
    nx_graph = graph.create_networkx_graph(upstream, upstream_user_nodes, "test-graph")
    assert sorted(nx_graph.nodes) == sorted(["A", "B", "C", "b", "c"])
    assert sorted(nx_graph.edges) == sorted([("A", "B"), ("A", "C"), ("b", "A"), ("c", "A")])
def test_end_to_end_with_layered_decorators_resolves_true():
    """With foo == "bar", the layered decorators resolve and produce e and f."""
    settings = {"foo": "bar", "d": 10, "b": 20}
    fg = graph.FunctionGraph.from_modules(tests.resources.layered_decorators, config=settings)
    outputs = fg.execute(list(fg.get_nodes()))
    # A non-empty result shows the config.when condition resolved.
    assert len(outputs) > 0
    assert outputs["e"] == 20 + 10
    assert outputs["f"] == 20 + 20
def test_end_to_end_with_layered_decorators_resolves_false():
    """With foo != "bar", execution yields nothing beyond the config values themselves."""
    config = {"foo": "not_bar", "d": 10, "b": 20}
    fg = graph.FunctionGraph.from_modules(tests.resources.layered_decorators, config=config)
    outputs = fg.execute(list(fg.get_nodes()))
    computed = {key: val for key, val in outputs.items() if key not in config}
    assert computed == {}
def test_combine_inputs_no_collision():
    """Tests graph_functions.combine_config_and_inputs when config and inputs share no keys."""
    combined = graph_functions.combine_config_and_inputs({"a": 1}, {"b": 2})
    assert combined == {"a": 1, "b": 2}
def test_combine_inputs_collision():
    """Tests that ``combine_config_and_inputs`` raises ``ValueError`` when a key appears in both
    config and inputs with different values."""
    with pytest.raises(ValueError):
        graph_functions.combine_config_and_inputs({"a": 1}, {"a": 2})
def test_combine_inputs_collision_2():
    """Tests that ``combine_config_and_inputs`` raises ``ValueError`` on a key collision even when
    the colliding values are equal."""
    with pytest.raises(ValueError):
        graph_functions.combine_config_and_inputs({"a": 1}, {"a": 1})
def test_extract_columns_executes_once():
    """Ensures the function behind extract_columns is computed only once.

    This is a somewhat heavy-handed test, but it's a nice guarantee to have pinned down.
    """
    function_graph = graph.FunctionGraph.from_modules(
        tests.resources.extract_columns_execution_count, config={}
    )
    run_id = str(uuid.uuid4())
    function_graph.execute(list(function_graph.get_nodes()), inputs={"unique_id": run_id})
    # The resource module keeps per-run records under the id we supplied; a single entry
    # means the underlying function was called exactly once.
    recorded = tests.resources.extract_columns_execution_count.outputs[run_id]
    assert len(recorded) == 1
def test_end_to_end_with_generics():
    """Executes a DAG built from functions annotated with generic types and checks every output."""
    function_graph = graph.FunctionGraph.from_modules(
        tests.resources.functions_with_generics, config={"b": {}, "c": 1}
    )
    outputs = function_graph.execute(function_graph.get_nodes())
    pair = ({}, 1)
    expected = {
        "A": pair,
        "B": [pair, pair],
        "C": {"foo": [pair, pair]},
        "D": 1.0,
        "b": {},
        "c": 1,
    }
    assert outputs == expected
@pytest.mark.parametrize(
    "config,inputs,overrides",
    [
        # nothing supplied at all
        ({}, {}, {}),
        # values supplied only through config
        ({"a": 11}, {}, {}),
        ({"b": 13, "a": 17}, {}, {}),
        ({"b": 19, "a": 23, "d": 29, "f": 31}, {}, {}),
        # values supplied only through inputs
        ({}, {"a": 37}, {}),
        ({}, {"b": 41, "a": 43}, {}),
        ({}, {"b": 41, "a": 43, "d": 47, "f": 53}, {}),
        # Overrides-only cases: it is still undecided whether overriding required inputs
        # should be legitimate. These pass today but are not part of the contract.
        # ({}, {}, {'a': 59}),
        # ({}, {}, {'a': 61, 'b': 67}),
        # ({}, {}, {'a': 71, 'b': 73, 'd': 79, 'f': 83}),
        # config and inputs combined
        ({"a": 89}, {"b": 97}, {}),
        ({"a": 101}, {"b": 103, "d": 107}, {}),
    ],
)
def test_optional_execute(config, inputs, overrides):
    """Checks optional-dependency execution across mixes of config, inputs and overrides.

    Each case is compared against the reference ``_do_all`` helper, so take care not to add
    cases whose values conflict with each other.
    """
    function_graph = graph.FunctionGraph.from_modules(
        tests.resources.optional_dependencies, config=config
    )
    # The user-input node "b" is requested before "g" to show that request order does not
    # dictate computation order.
    requested = [function_graph.nodes[name] for name in ("b", "g")]
    results = function_graph.execute(requested, inputs=inputs, overrides=overrides)
    supplied = {**config, **inputs, **overrides}
    reference = tests.resources.optional_dependencies._do_all(
        **{f"{name}_val": value for name, value in supplied.items()}
    )
    assert results["g"] == reference["g"]
def test_optional_input_behavior():
    """Tests that requesting optional user inputs that were not provided does not break, and that
    provided ones (via config, inputs, or overrides) are returned.
    """

    def _request(function_graph):
        # Build a fresh request list from the given graph's *own* nodes.
        return [function_graph.nodes["b"], function_graph.nodes["a"]]

    fg = graph.FunctionGraph.from_modules(tests.resources.optional_dependencies, config={})
    # Nothing passed, so nothing returned.
    result = fg.execute(_request(fg), inputs={}, overrides={})
    assert result == {}
    # Something passed via config, so it is returned. The request must use fg2's nodes.
    # Previously it used fg's nodes, which mixed two graphs.
    fg2 = graph.FunctionGraph.from_modules(tests.resources.optional_dependencies, config={"a": 10})
    result = fg2.execute(_request(fg2), inputs={}, overrides={})
    assert result == {"a": 10}
    # Something passed via inputs, so it is returned.
    result = fg.execute(_request(fg), inputs={"a": 10}, overrides={})
    assert result == {"a": 10}
    # Something passed via overrides, so it is returned.
    result = fg.execute(_request(fg), inputs={}, overrides={"a": 10})
    assert result == {"a": 10}
@pytest.mark.parametrize("node_order", list(permutations("fhi")))
def test_user_input_breaks_if_required_missing(node_order):
    """Whatever the request order, execution fails because the required input `h` is missing."""
    function_graph = graph.FunctionGraph.from_modules(
        tests.resources.optional_dependencies, config={}
    )
    ordered_nodes = [function_graph.nodes[name] for name in node_order]
    with pytest.raises(NotImplementedError):
        function_graph.execute(ordered_nodes, inputs={}, overrides={})
@pytest.mark.parametrize("node_order", list(permutations("fhi")))
def test_user_input_does_not_break_if_required_provided(node_order):
    """Whatever the request order, execution succeeds once the required `h` is supplied; `f` is optional."""
    function_graph = graph.FunctionGraph.from_modules(
        tests.resources.optional_dependencies, config={"h": 10}
    )
    ordered_nodes = [function_graph.nodes[name] for name in node_order]
    outcome = function_graph.execute(ordered_nodes, inputs={}, overrides={})
    assert outcome == {"h": 10, "i": 17}
def test_optional_donot_drop_none():
    """`None` values are passed through to functions rather than dropped.

    This test pins down the current behavior.
    """
    # A config value of None for `h` is still handed on, not removed.
    graph_h_none = graph.FunctionGraph.from_modules(
        tests.resources.optional_dependencies, config={"h": None}
    )
    outputs = graph_h_none.execute(
        [graph_h_none.nodes["h"], graph_h_none.nodes["i"]], inputs={}, overrides={}
    )
    assert outputs == {"h": None, "i": 17}

    graph_empty = graph.FunctionGraph.from_modules(
        tests.resources.optional_dependencies, config={}
    )
    requested = [graph_empty.nodes[name] for name in ("j", "none_result", "f")]
    outputs = graph_empty.execute(requested, inputs={}, overrides={})
    # `f` is absent from the result because it is an optional input that was not provided.
    assert outputs == {"j": None, "none_result": None}
def test_optional_get_required_compile_time():
    """Tests that, at compile time (no ``runtime_inputs``), getting required nodes with optionals
    returns everything upstream of ``g``.

    TODO -- change this to be testing a different function (compile time) than runtime.
    """
    fg = graph.FunctionGraph.from_modules(tests.resources.optional_dependencies, config={})
    all_upstream, user_required = fg.get_upstream_nodes(["g"])
    assert len(all_upstream) == 7  # 7 nodes in total are returned as upstream
    assert len(user_required) == 4  # 4 of them are user-input nodes
def test_optional_get_required_runtime():
    """Tests that, at runtime with an empty ``runtime_inputs``, unprovided optional inputs are not
    counted: fewer upstream nodes are returned and no user inputs are required."""
    fg = graph.FunctionGraph.from_modules(tests.resources.optional_dependencies, config={})
    all_upstream, user_required = fg.get_upstream_nodes(["g"], runtime_inputs={})
    assert len(all_upstream) == 3  # only 3 nodes upstream when nothing is provided
    assert len(user_required) == 0  # nothing required
def test_optional_get_required_runtime_with_provided():
    """Tests that supplying the optional input ``b`` at runtime adds it to the upstream and
    user-input node sets."""
    fg = graph.FunctionGraph.from_modules(tests.resources.optional_dependencies, config={})
    all_upstream, user_required = fg.get_upstream_nodes(["g"], runtime_inputs={"b": 109})
    assert len(all_upstream) == 4  # 4 nodes upstream once `b` is provided
    assert len(user_required) == 1  # one user-input node counted (presumably `b`, the one provided)
def test_in_driver_function_definitions():
    """A function defined inline (e.g. in a notebook or driver) can be combined with a module into one DAG."""

    # The name and parameter names matter: Hamilton derives node names from them.
    def my_function(A: int, b: int, c: int) -> int:
        """Function for input below"""
        return A + b + c

    temporary_module = ad_hoc_utils.create_temporary_module(my_function)
    function_graph = graph.FunctionGraph.from_modules(
        tests.resources.dummy_functions, temporary_module, config={"b": 3, "c": 1}
    )
    wanted_names = {"my_function", "A"}
    selected = [n for n in function_graph.get_nodes() if n.name in wanted_names]
    outputs = function_graph.execute(selected)
    assert outputs == {"A": 4, "b": 3, "c": 1, "my_function": 8}
def test_update_dependencies():
    """Tests that ``graph.update_dependencies`` preserves dependency wiring.

    For every node it returns, ``dependencies`` and ``depended_on_by`` must compare equal to those
    of the same-named node in the input mapping.
    """
    # create_testing_nodes is defined elsewhere in this file; it is indexed by node name below,
    # so it presumably returns a dict of name -> Node (confirm against its definition).
    nodes = create_testing_nodes()
    new_nodes = graph.update_dependencies(
        nodes, lifecycle_base.LifecycleAdapterSet(base.DefaultAdapter())
    )
    for node_name, node_ in new_nodes.items():
        # Compare against the original node with the same name.
        assert node_.dependencies == nodes[node_name].dependencies
        assert node_.depended_on_by == nodes[node_name].depended_on_by