blob: 96d90850628428b67632d2210451204d2aa423d1 [file] [log] [blame]
import inspect
import tempfile
import uuid
from itertools import permutations
import pandas as pd
import pytest
import hamilton.graph_utils
import hamilton.htypes
import tests.resources.bad_functions
import tests.resources.config_modifier
import tests.resources.cyclic_functions
import tests.resources.dummy_functions
import tests.resources.extract_column_nodes
import tests.resources.extract_columns_execution_count
import tests.resources.functions_with_generics
import tests.resources.layered_decorators
import tests.resources.multiple_decorators_together
import tests.resources.optional_dependencies
import tests.resources.parametrized_inputs
import tests.resources.parametrized_nodes
import tests.resources.typing_vs_not_typing
from hamilton import ad_hoc_utils, base, graph, node
from hamilton.execution import graph_functions
from hamilton.node import NodeType
def test_find_functions():
    """Checks that find_functions yields exactly A, B and C for the dummy_functions module.

    The expectation encodes that underscore-prefixed functions are filtered out and that nothing
    is pulled in from the module's own imports.
    """
    dummy = tests.resources.dummy_functions
    wanted = [(name, getattr(dummy, name)) for name in ("A", "B", "C")]
    found = hamilton.graph_utils.find_functions(dummy)
    assert len(found) == len(wanted)
    assert found == wanted
def test_find_functions_from_temporary_function_module():
    """Checks that find_functions handles a module built by create_temporary_module."""
    originals = [
        tests.resources.dummy_functions.A,
        tests.resources.dummy_functions.B,
        tests.resources.dummy_functions.C,
    ]
    wanted_names = ["A", "B", "C"]
    temp_module = ad_hoc_utils.create_temporary_module(*originals)
    found = hamilton.graph_utils.find_functions(temp_module)
    assert len(found) == len(wanted_names)
    found_names = [name for name, _ in found]
    assert found_names == wanted_names
    # Comparing the code objects is an easy way to say the functions are the same.
    found_code = [fn.__code__ for _, fn in found]
    assert found_code == [fn.__code__ for fn in originals]
def test_add_dependency_missing_param_type():
    """Checks that building a node errors when a parameter has no type annotation."""
    untyped_fn = tests.resources.bad_functions.A
    with pytest.raises(ValueError):
        return_type = inspect.signature(untyped_fn).return_annotation
        node.Node("A", return_type, "A doc", untyped_fn)  # expected to raise
def test_add_dependency_missing_function_type():
    """Checks that building a node errors when the function has no return type annotation."""
    untyped_fn = tests.resources.bad_functions.B
    with pytest.raises(ValueError):
        return_type = inspect.signature(untyped_fn).return_annotation
        node.Node("B", return_type, "B doc", untyped_fn)  # expected to raise
def test_add_dependency_strict_node_dependencies():
    """Checks that add_dependency wires an existing function node in as a dependency.

    B depends on A, so after the call A must be B's first dependency, while nothing depends on B.
    """
    fn_a = tests.resources.dummy_functions.A
    fn_b = tests.resources.dummy_functions.B
    sig_b = inspect.signature(fn_b)
    node_b = node.Node("B", sig_b.return_annotation, "B doc", fn_b)
    known_nodes = {
        "A": node.Node("A", inspect.signature(fn_a).return_annotation, "A doc", fn_a),
    }
    graph.add_dependency(
        node_b,
        "B",
        known_nodes,
        "A",
        sig_b.parameters["A"].annotation,
        base.SimplePythonDataFrameGraphAdapter(),
    )
    assert known_nodes["A"] == node_b.dependencies[0]
    assert node_b.depended_on_by == []
def test_typing_to_primitive_conversion():
    """Checks that a typing-annotated output can feed a parameter annotated with a primitive."""
    resources = tests.resources.typing_vs_not_typing
    downstream_sig = inspect.signature(resources.B)
    downstream = node.Node("B", downstream_sig.return_annotation, "B doc", resources.B)
    upstream = node.Node(
        "A", inspect.signature(resources.A).return_annotation, "A doc", resources.A
    )
    registry = {"A": upstream}
    graph.add_dependency(
        downstream,
        "B",
        registry,
        "A",
        downstream_sig.parameters["A"].annotation,
        base.SimplePythonDataFrameGraphAdapter(),
    )
    assert registry["A"] == downstream.dependencies[0]
    assert downstream.depended_on_by == []
def test_primitive_to_typing_conversion():
    """Checks that a primitive-annotated output can feed a parameter annotated with typing."""
    resources = tests.resources.typing_vs_not_typing
    downstream_sig = inspect.signature(resources.B2)
    downstream = node.Node("B2", downstream_sig.return_annotation, "B2 doc", resources.B2)
    upstream = node.Node(
        "A2", inspect.signature(resources.A2).return_annotation, "A2 doc", resources.A2
    )
    registry = {"A2": upstream}
    graph.add_dependency(
        downstream,
        "B2",
        registry,
        "A2",
        downstream_sig.parameters["A2"].annotation,
        base.SimplePythonDataFrameGraphAdapter(),
    )
    assert registry["A2"] == downstream.dependencies[0]
    assert downstream.depended_on_by == []
def test_throwing_error_on_incompatible_types():
    """Checks that add_dependency raises ValueError when the types of C and D's input clash."""
    bad = tests.resources.bad_functions
    consumer_sig = inspect.signature(bad.D)
    consumer = node.Node("D", consumer_sig.return_annotation, "D doc", bad.D)
    registry = {
        "C": node.Node("C", inspect.signature(bad.C).return_annotation, "C doc", bad.C),
    }
    wanted_type = consumer_sig.parameters["C"].annotation
    adapter = base.SimplePythonDataFrameGraphAdapter()
    with pytest.raises(ValueError):
        graph.add_dependency(consumer, "D", registry, "C", wanted_type, adapter)
def test_add_dependency_user_nodes():
    """Checks that add_dependency creates a user-defined node for an unknown parameter.

    A takes both b and c, but only one add_dependency call is made here (for b). Afterwards b
    must be A's first dependency, A must be listed as depending on b, and nothing depends on A.
    """
    fn_a = tests.resources.dummy_functions.A
    sig_a = inspect.signature(fn_a)
    node_a = node.Node("A", sig_a.return_annotation, "A doc", fn_a)
    registry = {}  # starts empty; add_dependency is expected to insert the node for "b"
    graph.add_dependency(
        node_a,
        "A",
        registry,
        "b",
        sig_a.parameters["b"].annotation,
        base.SimplePythonDataFrameGraphAdapter(),
    )
    created = registry["b"]
    assert created == node_a.dependencies[0]
    assert created.depended_on_by[0] == node_a
    assert node_a.depended_on_by == []
def create_testing_nodes():
    """Builds, by hand, the node dictionary that corresponds to dummy_functions.py.

    Returns a dict keyed by node name holding the function nodes A, B and C plus the external
    input nodes b and c, with the dependency edges b -> A, c -> A, A -> B and A -> C linked in
    both directions.
    """
    dummy = tests.resources.dummy_functions
    docs_by_name = {
        "A": "Function that should become part of the graph - A",
        "B": "Function that should become part of the graph - B",
        "C": "",
    }
    nodes = {}
    for name, doc in docs_by_name.items():
        fn = getattr(dummy, name)
        nodes[name] = node.Node(
            name,
            inspect.signature(fn).return_annotation,
            doc,
            fn,
            tags={"module": "tests.resources.dummy_functions"},  # fresh dict per node
        )

    a_params = inspect.signature(dummy.A).parameters
    for name in ("b", "c"):
        nodes[name] = node.Node(
            name, a_params[name].annotation, node_source=NodeType.EXTERNAL
        )

    a_node = nodes["A"]
    # External inputs feed A.
    for external in (nodes["b"], nodes["c"]):
        a_node.dependencies.append(external)
        external.depended_on_by.append(a_node)
    # A feeds B and C.
    for child in (nodes["B"], nodes["C"]):
        a_node.depended_on_by.append(child)
        child.dependencies.append(a_node)
    return nodes
def test_create_function_graph_simple():
    """Checks that create_function_graph on dummy_functions matches the hand-built nodes."""
    built = graph.create_function_graph(
        tests.resources.dummy_functions,
        config={},
        adapter=base.SimplePythonDataFrameGraphAdapter(),
    )
    assert built == create_testing_nodes()
def test_execute():
    """Checks execution of the hand-built graph, including an override of A.

    A is depended on by both B and C, so this also exercises basic memoization.
    """
    testing_nodes = create_testing_nodes()
    graph_adapter = base.SimplePythonDataFrameGraphAdapter()
    user_inputs = {"b": 2, "c": 5}

    computed = graph_functions.execute_subdag(
        nodes=testing_nodes.values(), inputs=user_inputs, adapter=graph_adapter
    )
    assert computed == {"A": 7, "B": 49, "C": 14, "b": 2, "c": 5}

    with_override = graph_functions.execute_subdag(
        nodes=testing_nodes.values(),
        inputs=user_inputs,
        adapter=graph_adapter,
        overrides={"A": 8},
    )
    assert with_override["A"] == 8
def test_get_required_functions():
    """Checks get_upstream_nodes for A and B on the toy graph; C must not be included."""
    reference = create_testing_nodes()
    fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={})
    upstream, user_inputs = fg.get_upstream_nodes(["A", "B"])
    # 'C' is deliberately absent: it is not upstream of A or B.
    assert upstream == {reference[name] for name in ("A", "B", "b", "c")}
    assert user_inputs == {reference["b"], reference["c"]}
def test_get_downstream_nodes():
    """Checks get_downstream_nodes for a change to A: A itself plus B and C are expected."""
    reference = create_testing_nodes()
    fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={})
    downstream = fg.get_downstream_nodes(["A"])
    assert downstream == {reference["B"], reference["C"], reference["A"]}
def test_function_graph_from_multiple_sources():
    """Checks that a graph built from two modules holds the union of their nodes (8 total)."""
    modules = (tests.resources.dummy_functions, tests.resources.parametrized_nodes)
    fg = graph.FunctionGraph.from_modules(*modules, config={})
    node_count = len(fg.get_nodes())
    assert node_count == 8  # the union of both modules' nodes
def test_end_to_end_with_parametrized_nodes():
    """Checks that a graph of parametrized nodes executes end-to-end with the expected values."""
    fg = graph.FunctionGraph.from_modules(tests.resources.parametrized_nodes, config={})
    everything = fg.get_nodes()
    outputs = fg.execute(everything, {})
    assert outputs == {"parametrized_1": 1, "parametrized_2": 2, "parametrized_3": 3}
def test_end_to_end_with_parametrized_inputs():
    """Checks that parametrized inputs combine with a config-provided static value end-to-end."""
    static = 3
    fg = graph.FunctionGraph.from_modules(
        tests.resources.parametrized_inputs, config={"static_value": static}
    )
    outputs = fg.execute(fg.get_nodes())
    wanted = {
        "input_1": 1,
        "input_2": 2,
        "input_3": 3,
        "output_1": 1 + static,
        "output_2": 2 + static,
        "output_12": 1 + 2 + static,
        "output_123": 1 + 2 + 3 + static,
        "static_value": static,
    }
    assert outputs == wanted
def test_get_required_functions_askfor_config():
    """Checks that a config key can be requested as an output.

    Asking for "a" (supplied via config) alongside "parametrized_1" must report "a" as the only
    user node, and executing that node must return the configured value.
    """
    fg = graph.FunctionGraph.from_modules(tests.resources.parametrized_nodes, config={"a": 1})
    _, user_nodes = fg.get_upstream_nodes(["a", "parametrized_1"])
    [only_user_node] = user_nodes  # unpacking fails unless there is exactly one user node
    assert only_user_node.name == "a"
    outputs = fg.execute(user_nodes)
    assert outputs == {"a": 1}
def test_end_to_end_with_column_extractor_nodes():
    """Checks that a graph whose nodes extract columns from a dataframe executes end-to-end."""
    module = tests.resources.extract_column_nodes
    fg = graph.FunctionGraph.from_modules(module, config={})
    all_nodes = fg.get_nodes()
    outputs = fg.execute(all_nodes, {}, {})
    reference_df = module.generate_df()
    for column in ("col_1", "col_2"):
        pd.testing.assert_series_equal(outputs[column], reference_df[column])
    pd.testing.assert_frame_equal(outputs["generate_df"], reference_df)
    first_doc = all_nodes[0].documentation
    assert first_doc == "Function that should be parametrized to form multiple functions"
def test_end_to_end_with_multiple_decorators():
    """Checks that functions carrying several decorators at once execute end-to-end.

    Verifies the computed values, the documentation attached to the generated nodes, and which
    nodes end up carrying the extra tag.
    """
    module = tests.resources.multiple_decorators_together
    fg = graph.FunctionGraph.from_modules(
        module,
        config={"param0": 3, "param1": 1, "in_value1": 42, "in_value2": "string_value"},
    )
    all_nodes = fg.get_nodes()
    # To help debug issues:
    # nodez, user_nodes = fg.get_upstream_nodes([n.name for n in nodes],
    #                                           {"param0": 3, "param1": 1,
    #                                            "in_value1": 42, "in_value2": "string_value"})
    # fg.display(
    #     nodez,
    #     user_nodes,
    #     "all_multiple_decorators",
    #     render_kwargs=None,
    #     graphviz_kwargs=None,
    # )
    outputs = fg.execute(all_nodes, {}, {})

    expected_frame = module._sum_multiply(3, 1, 2)
    expected_fields = module._sum(3, 1, 2)
    pd.testing.assert_series_equal(outputs["param1b"], expected_frame["param1b"])
    pd.testing.assert_frame_equal(outputs["to_modify"], expected_frame)
    assert outputs["total"] == expected_fields["total"]
    assert outputs["to_modify_2"] == expected_fields

    nodes_by_name = {}
    for graph_node in all_nodes:
        nodes_by_name[graph_node.name] = graph_node
    print(sorted(nodes_by_name))

    to_modify_doc = nodes_by_name["to_modify"].documentation
    assert to_modify_doc == "This is a dummy function showing extract_columns with does."
    to_modify_2_doc = nodes_by_name["to_modify_2"].documentation
    assert to_modify_2_doc == "This is a dummy function showing extract_fields with does."

    # The tag currently only applies to the outermost node layer, i.e. it is not propagated.
    assert nodes_by_name["uber_decorated_function"].tags == {
        "module": "tests.resources.multiple_decorators_together"
    }
    for tagged_name in ("out_value1", "out_value2"):
        assert nodes_by_name[tagged_name].tags == {
            "module": "tests.resources.multiple_decorators_together",
            "test_key": "test-value",
        }
def test_end_to_end_with_config_modifier():
    """Checks that the "fn_1_version" config value selects which implementation of fn runs."""
    wanted_by_version = {1: "version_1", 2: "version_2", 3: "version_3"}
    for version, wanted in wanted_by_version.items():
        fg = graph.FunctionGraph.from_modules(
            tests.resources.config_modifier, config={"fn_1_version": version}
        )
        outputs = fg.execute(fg.get_nodes(), {}, {})
        assert outputs["fn"] == wanted
def test_non_required_nodes():
    """Checks execution of nodes whose inputs are optional.

    First run: only "required" is configured and D is left out of the requested outputs, so the
    missing optional input must not break execution. Second run: "defaults_to_zero" is supplied
    too and every standard node, D included, is executed.
    """
    # ``tests.resources.test_default_args`` is not among this file's top-level imports; import it
    # here so the attribute lookup below does not depend on another module having loaded it.
    import tests.resources.test_default_args

    fg = graph.FunctionGraph.from_modules(
        tests.resources.test_default_args, config={"required": 10}
    )
    results = fg.execute(
        # D is not on the execution path, so it should not break things
        [n for n in fg.get_nodes() if n.node_role == NodeType.STANDARD and n.name != "D"],
        {},
        {},
    )
    assert results["A"] == 10

    fg = graph.FunctionGraph.from_modules(
        tests.resources.test_default_args, config={"required": 10, "defaults_to_zero": 1}
    )
    results = fg.execute(
        [n for n in fg.get_nodes() if n.node_role == NodeType.STANDARD],
        {},
        {},
    )
    assert results["A"] == 11
    assert results["D"] == 2
def test_config_can_override():
    """Checks that the config value for "new_param" is what execution returns for that key."""
    fg = graph.FunctionGraph.from_modules(
        tests.resources.config_modifier, config={"new_param": "new_value"}
    )
    outputs = fg.execute(list(fg.get_nodes()))
    assert outputs["new_param"] == "new_value"
def test_function_graph_has_cycles_true():
    """Checks cycle detection on a cyclic graph, plus the behaviour that follows from it."""
    fg = graph.FunctionGraph.from_modules(tests.resources.cyclic_functions, config={"b": 2, "c": 1})
    computed_nodes, input_nodes = [], []
    for candidate in fg.get_nodes():
        bucket = input_nodes if candidate.user_defined else computed_nodes
        bucket.append(candidate)
    assert fg.has_cycles(computed_nodes, input_nodes) is True

    upstream, upstream_inputs = fg.get_upstream_nodes(["A", "B", "C"])
    assert upstream == set(computed_nodes + input_nodes)
    assert upstream_inputs == set(input_nodes)

    # We don't want to support this behavior officially -- but this works:
    # result = fg.execute([n for n in nodes if n.name == 'B'], overrides={'A': 1, 'D': 2})
    # assert len(result) == 3
    # assert result['B'] == 3
    only_b = [candidate for candidate in computed_nodes if candidate.name == "B"]
    # Without overrides there is no way to short circuit the cycle, so recursion blows up.
    with pytest.raises(RecursionError):
        fg.execute(only_b)
def test_function_graph_has_cycles_false():
    """Checks that an acyclic graph is reported as having no cycles, via two node selections."""
    fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})

    # First way: split every node of the graph by whether it is user defined.
    computed_nodes, input_nodes = [], []
    for candidate in fg.get_nodes():
        bucket = input_nodes if candidate.user_defined else computed_nodes
        bucket.append(candidate)
    assert fg.has_cycles(computed_nodes, input_nodes) is False

    # Second way: the upstream selection, which is what the driver calls.
    upstream, upstream_inputs = fg.get_upstream_nodes(["A", "B", "C"])
    assert fg.has_cycles(upstream, upstream_inputs) is False
def test_function_graph_display():
    """Tests that display saves a file whose lines match the expected DOT source.

    B is marked as an output and every user-defined node as a user input. The file that display
    is asked to write lives inside a temporary directory that is removed when the test ends.
    """
    import os  # function-local so this test does not depend on the file's import block

    fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
    node_modifiers = {"B": {graph.VisualizationNodeModifiers.IS_OUTPUT}}
    all_nodes = set()
    for n in fg.get_nodes():
        if n.user_defined:
            node_modifiers[n.name] = {graph.VisualizationNodeModifiers.IS_USER_INPUT}
        all_nodes.add(n)
    # hack of a test -- but it works... sort the lines and match them up.
    # why? because for some reason given the same graph, the output file isn't deterministic.
    expected = sorted(
        [
            "// Dependency Graph\n",
            "digraph {\n",
            "\tA [label=A]\n",
            "\tC [label=C]\n",
            "\tB [label=B shape=rectangle]\n",
            '\tc [label="Input: c" style=dashed]\n',
            '\tb [label="Input: b" style=dashed]\n',
            "\tb -> A\n",
            "\tc -> A\n",
            "\tA -> C\n",
            "\tA -> B\n",
            "}\n",
        ]
    )
    with tempfile.TemporaryDirectory() as tmp_dir:
        # TemporaryDirectory yields a plain str, so ``tmp_dir.join("test.dot")`` would be
        # str.join: it interleaves the directory between the characters of "test.dot" and
        # produces a relative path outside the temp dir. Build the real path instead.
        path = os.path.join(tmp_dir, "test.dot")
        fg.display(all_nodes, path, {"view": False}, None, node_modifiers)
        with open(path, "r") as dot_file:
            actual = sorted(dot_file.readlines())
    assert actual == expected
def test_function_graph_display_without_saving():
    """Checks that display returns a graphviz Digraph when None is passed as the path."""
    fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
    modifiers = {"B": {graph.VisualizationNodeModifiers.IS_OUTPUT}}
    to_draw = set()
    for graph_node in fg.get_nodes():
        to_draw.add(graph_node)
        if graph_node.user_defined:
            modifiers[graph_node.name] = {graph.VisualizationNodeModifiers.IS_USER_INPUT}
    rendered = fg.display(to_draw, None, node_modifiers=modifiers)
    assert rendered is not None
    import graphviz

    assert isinstance(rendered, graphviz.Digraph)
def test_create_graphviz_graph():
    """Checks the DOT source produced by create_graphviz_graph for the dummy graph."""
    fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={})
    upstream, upstream_inputs = fg.get_upstream_nodes(["A", "B", "C"])
    to_draw = upstream.union(upstream_inputs)
    modifiers = {
        "b": {graph.VisualizationNodeModifiers.IS_USER_INPUT},
        "c": {graph.VisualizationNodeModifiers.IS_USER_INPUT},
        "B": {graph.VisualizationNodeModifiers.IS_OUTPUT},
    }
    # hack of a test -- but it works... sort the lines and match them up.
    # why? because for some reason given the same graph, the output isn't deterministic.
    wanted_lines = [
        "// test-graph",
        "digraph {",
        "\tgraph [ratio=1]",
        "\tB [label=B shape=rectangle]",
        "\tA [label=A]",
        "\tC [label=C]",
        '\tb [label="Input: b" style=dashed]',
        '\tc [label="Input: c" style=dashed]',
        "\tA -> B",
        "\tb -> A",
        "\tc -> A",
        "\tA -> C",
        "}",
    ]
    wanted_lines.sort()
    digraph = graph.create_graphviz_graph(
        to_draw, "test-graph", dict(graph_attr={"ratio": "1"}), modifiers, False
    )
    produced_lines = str(digraph).split("\n")
    produced_lines.sort()
    # Drop a single empty entry (left by the trailing newline) if there is one.
    if "" in produced_lines:
        produced_lines.remove("")
    assert produced_lines == wanted_lines
def test_create_networkx_graph():
    """Checks the nodes and edges of the networkx graph built for the dummy graph."""
    fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={})
    upstream, upstream_inputs = fg.get_upstream_nodes(["A", "B", "C"])
    digraph = graph.create_networkx_graph(upstream, upstream_inputs, "test-graph")
    wanted_nodes = sorted(["c", "B", "C", "b", "A"])
    wanted_edges = sorted([("c", "A"), ("b", "A"), ("A", "B"), ("A", "C")])
    assert sorted(digraph.nodes) == wanted_nodes
    assert sorted(digraph.edges) == wanted_edges
def test_end_to_end_with_layered_decorators_resolves_true():
    """Checks layered decorators when the config ("foo": "bar") makes config.when resolve."""
    settings = {"foo": "bar", "d": 10, "b": 20}
    fg = graph.FunctionGraph.from_modules(tests.resources.layered_decorators, config=settings)
    outputs = fg.execute(list(fg.get_nodes()))
    assert len(outputs) > 0  # test config.when resolves correctly
    assert outputs["e"] == (20 + 10)
    assert outputs["f"] == (20 + 20)
def test_end_to_end_with_layered_decorators_resolves_false():
    """Checks that with "foo": "not_bar" nothing beyond the config keys is produced."""
    config = {"foo": "not_bar", "d": 10, "b": 20}
    fg = graph.FunctionGraph.from_modules(tests.resources.layered_decorators, config=config)
    outputs = fg.execute(list(fg.get_nodes()))
    beyond_config = {key: value for key, value in outputs.items() if key not in config}
    assert beyond_config == {}
def test_combine_inputs_no_collision():
    """Checks that combine_config_and_inputs merges two dicts that share no keys."""
    merged = graph_functions.combine_config_and_inputs({"a": 1}, {"b": 2})
    assert merged == {"a": 1, "b": 2}
def test_combine_inputs_collision():
    """Checks that combine_config_and_inputs raises ValueError when a key collides
    and the two values differ."""
    config_side, inputs_side = {"a": 1}, {"a": 2}
    with pytest.raises(ValueError):
        graph_functions.combine_config_and_inputs(config_side, inputs_side)
def test_combine_inputs_collision_2():
    """Checks that combine_config_and_inputs raises ValueError when a key collides
    even though the two values are equal."""
    config_side, inputs_side = {"a": 1}, {"a": 1}
    with pytest.raises(ValueError):
        graph_functions.combine_config_and_inputs(config_side, inputs_side)
def test_extract_columns_executes_once():
    """Checks that a function decorated with extract_columns is only computed once.

    A fresh unique id is passed as input; the resource module's ``outputs`` entry for that id
    must hold exactly one item after execution. A bit heavy-handed, but nice to have.
    """
    module = tests.resources.extract_columns_execution_count
    fg = graph.FunctionGraph.from_modules(module, config={})
    run_id = str(uuid.uuid4())
    fg.execute(list(fg.get_nodes()), inputs={"unique_id": run_id})
    recorded = module.outputs[run_id]
    assert len(recorded) == 1  # It should only be called once
def test_end_to_end_with_generics():
    """Checks that functions annotated with generic types execute end-to-end."""
    fg = graph.FunctionGraph.from_modules(
        tests.resources.functions_with_generics, config={"b": {}, "c": 1}
    )
    outputs = fg.execute(fg.get_nodes())
    wanted = {
        "A": ({}, 1),
        "B": [({}, 1), ({}, 1)],
        "C": {"foo": [({}, 1), ({}, 1)]},
        "D": 1.0,
        "b": {},
        "c": 1,
    }
    assert outputs == wanted
@pytest.mark.parametrize(
    "config,inputs,overrides",
    [
        # testing with no provided inputs
        ({}, {}, {}),
        # testing with just configs
        ({"a": 11}, {}, {}),
        ({"b": 13, "a": 17}, {}, {}),
        ({"b": 19, "a": 23, "d": 29, "f": 31}, {}, {}),
        # Testing with just inputs
        ({}, {"a": 37}, {}),
        ({}, {"b": 41, "a": 43}, {}),
        ({}, {"b": 41, "a": 43, "d": 47, "f": 53}, {}),
        # Testing with just overrides
        # TBD whether these should be legitimate -- can we override required inputs?
        # Test works now but not part of the contract
        # ({}, {}, {'a': 59}),
        # ({}, {}, {'a': 61, 'b': 67}),
        # ({}, {}, {'a': 71, 'b': 73, 'd': 79, 'f': 83}),
        # testing with a mix
        ({"a": 89}, {"b": 97}, {}),
        ({"a": 101}, {"b": 103, "d": 107}, {}),
    ],
)
def test_optional_execute(config, inputs, overrides):
    """Checks execution of optional dependencies under various config/input/override mixes.

    The expected value of g comes from the resource module's ``_do_all`` helper, called with
    every provided value under the key ``<name>_val``. Be careful adding cases whose values
    conflict between config, inputs and overrides.
    """
    fg = graph.FunctionGraph.from_modules(tests.resources.optional_dependencies, config=config)
    # A user input node goes first to ensure that request order does not affect computation.
    requested = [fg.nodes["b"], fg.nodes["g"]]
    results = fg.execute(requested, inputs=inputs, overrides=overrides)

    provided = {}
    for source in (config, inputs, overrides):  # later sources win, as with {**a, **b, **c}
        provided.update(source)
    do_all_args = {f"{key}_val": val for key, val in provided.items()}
    expected_results = tests.resources.optional_dependencies._do_all(**do_all_args)
    assert results["g"] == expected_results["g"]
def test_optional_input_behavior():
    """Tests that requesting optional user inputs that are not provided does not break, and that
    they are returned when they are provided (via config, inputs, or overrides).
    """
    fg = graph.FunctionGraph.from_modules(tests.resources.optional_dependencies, config={})
    # nothing passed, so nothing returned
    result = fg.execute([fg.nodes["b"], fg.nodes["a"]], inputs={}, overrides={})
    assert result == {}
    # something passed, something returned via config
    fg2 = graph.FunctionGraph.from_modules(tests.resources.optional_dependencies, config={"a": 10})
    # Request fg2's own nodes: handing fg's node objects to fg2 would execute against nodes that
    # belong to a graph built from a different config.
    result = fg2.execute([fg2.nodes["b"], fg2.nodes["a"]], inputs={}, overrides={})
    assert result == {"a": 10}
    # something passed, something returned via inputs
    result = fg.execute([fg.nodes["b"], fg.nodes["a"]], inputs={"a": 10}, overrides={})
    assert result == {"a": 10}
    # something passed, something returned via overrides
    result = fg.execute([fg.nodes["b"], fg.nodes["a"]], inputs={}, overrides={"a": 10})
    assert result == {"a": 10}
@pytest.mark.parametrize("node_order", list(permutations("fhi")))
def test_user_input_breaks_if_required_missing(node_order):
    """Checks that execution fails, whatever the request order, when required `h` is missing."""
    fg = graph.FunctionGraph.from_modules(tests.resources.optional_dependencies, config={})
    requested = []
    for name in node_order:
        requested.append(fg.nodes[name])
    with pytest.raises(NotImplementedError):
        fg.execute(requested, inputs={}, overrides={})
@pytest.mark.parametrize("node_order", list(permutations("fhi")))
def test_user_input_does_not_break_if_required_provided(node_order):
    """Checks that execution works in any request order when required `h` is provided.

    `f` is optional and not provided, so it does not show up in the result.
    """
    fg = graph.FunctionGraph.from_modules(tests.resources.optional_dependencies, config={"h": 10})
    requested = []
    for name in node_order:
        requested.append(fg.nodes[name])
    outputs = fg.execute(requested, inputs={}, overrides={})
    assert outputs == {"h": 10, "i": 17}
def test_optional_donot_drop_none():
    """Enshrines the current behaviour that `None` values are kept rather than dropped.

    A `None` supplied via config and a `None` produced by a function must both appear in the
    results, while an optional input that was never provided is omitted.
    """
    module = tests.resources.optional_dependencies

    # None supplied through config is still passed along and returned.
    configured = graph.FunctionGraph.from_modules(module, config={"h": None})
    outputs = configured.execute(
        [configured.nodes["h"], configured.nodes["i"]], inputs={}, overrides={}
    )
    assert outputs == {"h": None, "i": 17}

    unconfigured = graph.FunctionGraph.from_modules(module, config={})
    requested = [unconfigured.nodes[name] for name in ("j", "none_result", "f")]
    outputs = unconfigured.execute(requested, inputs={}, overrides={})
    assert outputs == {"j": None, "none_result": None}  # f omitted because it's optional.
def test_optional_get_required_compile_time():
    """Checks that get_upstream_nodes without runtime inputs returns everything upstream of g.

    TODO -- change this to be testing a different function (compile time) than runtime.
    """
    fg = graph.FunctionGraph.from_modules(tests.resources.optional_dependencies, config={})
    upstream, user_inputs = fg.get_upstream_nodes(["g"])
    assert len(upstream) == 7  # 7 nodes upstream in total
    assert len(user_inputs) == 4  # 4 of them are user inputs
def test_optional_get_required_runtime():
    """Checks get_upstream_nodes for g when an empty runtime_inputs dict is passed."""
    fg = graph.FunctionGraph.from_modules(tests.resources.optional_dependencies, config={})
    # Nothing is provided at runtime.
    upstream, user_inputs = fg.get_upstream_nodes(["g"], runtime_inputs={})
    assert len(upstream) == 3  # only 3 nodes are reported upstream
    assert len(user_inputs) == 0  # and none of them is a user input
def test_optional_get_required_runtime_with_provided():
    """Checks get_upstream_nodes for g when `b` is provided through runtime_inputs."""
    fg = graph.FunctionGraph.from_modules(tests.resources.optional_dependencies, config={})
    runtime_inputs = {"b": 109}  # only the optional input b is provided
    upstream, user_inputs = fg.get_upstream_nodes(["g"], runtime_inputs=runtime_inputs)
    assert len(upstream) == 4  # one more upstream node than with no runtime inputs
    assert len(user_inputs) == 1  # the provided input is the only user input
def test_in_driver_function_definitions():
    """Checks that a DAG can include a function defined inline, e.g. in a notebook context."""

    def my_function(A: int, b: int, c: int) -> int:
        """Function for input below"""
        return A + b + c

    inline_module = ad_hoc_utils.create_temporary_module(my_function)
    fg = graph.FunctionGraph.from_modules(
        tests.resources.dummy_functions, inline_module, config={"b": 3, "c": 1}
    )
    wanted_names = ("my_function", "A")
    requested = [graph_node for graph_node in fg.get_nodes() if graph_node.name in wanted_names]
    outputs = fg.execute(requested)
    assert outputs == {"A": 4, "b": 3, "c": 1, "my_function": 8}
def test_update_dependencies():
    """Checks that update_dependencies returns nodes with the same edges as the originals."""
    originals = create_testing_nodes()
    updated = graph.update_dependencies(originals, base.DefaultAdapter())
    for name, updated_node in updated.items():
        original_node = originals[name]
        assert updated_node.dependencies == original_node.dependencies
        assert updated_node.depended_on_by == original_node.depended_on_by