blob: 494359a52a520feadf47e94f057ebec8db681546 [file] [log] [blame]
import inspect
from itertools import permutations
import tempfile
import typing
import uuid
import pandas as pd
import pytest
import tests.resources.bad_functions
import tests.resources.config_modifier
import tests.resources.cyclic_functions
import tests.resources.dummy_functions
import tests.resources.extract_column_nodes
import tests.resources.extract_columns_execution_count
import tests.resources.functions_with_generics
import tests.resources.layered_decorators
import tests.resources.optional_dependencies
import tests.resources.parametrized_inputs
import tests.resources.parametrized_nodes
import tests.resources.typing_vs_not_typing
from hamilton import graph, base, ad_hoc_utils
from hamilton import node
from hamilton.node import NodeSource
def test_find_functions():
    """find_functions should yield only the module's own public functions.

    Underscore-prefixed helpers and anything brought in through imports are filtered out.
    """
    module = tests.resources.dummy_functions
    expected = [(name, getattr(module, name)) for name in ('A', 'B', 'C')]
    actual = graph.find_functions(module)
    assert len(actual) == len(expected)
    assert actual == expected
def test_find_functions_from_temporary_function_module():
    """find_functions should also work on a module built by ad_hoc_utils.create_temporary_module."""
    dummy = tests.resources.dummy_functions
    expected = [('A', dummy.A), ('B', dummy.B), ('C', dummy.C)]
    temp_module = ad_hoc_utils.create_temporary_module(dummy.A, dummy.B, dummy.C)
    actual = graph.find_functions(temp_module)
    assert len(actual) == len(expected)
    actual_names = [name for name, _ in actual]
    expected_names = [name for name, _ in expected]
    assert actual_names == expected_names
    # Comparing code objects is an easy way to check they are the same functions.
    assert [func.__code__ for _, func in actual] == [func.__code__ for _, func in expected]
def test_add_dependency_missing_param_type():
    """Tests that constructing a Node errors when a parameter of its function lacks a type annotation."""
    a_sig = inspect.signature(tests.resources.bad_functions.A)
    # Keep only the call under test inside pytest.raises, so an unrelated ValueError
    # raised while preparing the inputs cannot make this test pass spuriously.
    with pytest.raises(ValueError):
        node.Node('A', a_sig.return_annotation, 'A doc', tests.resources.bad_functions.A)  # should error out
def test_add_dependency_missing_function_type():
    """Tests that constructing a Node errors when its function lacks a return type annotation."""
    b_sig = inspect.signature(tests.resources.bad_functions.B)
    # Keep only the call under test inside pytest.raises, so an unrelated ValueError
    # raised while preparing the inputs cannot make this test pass spuriously.
    with pytest.raises(ValueError):
        node.Node('B', b_sig.return_annotation, 'B doc', tests.resources.bad_functions.B)  # should error out
def test_add_dependency_strict_node_dependencies():
    """add_dependency should wire an edge between two function nodes.

    Setup: B depends on A, so A ends up in B's dependencies, and nothing depends on B.
    """
    funcs = tests.resources.dummy_functions
    b_signature = inspect.signature(funcs.B)
    b_node = node.Node('B', b_signature.return_annotation, 'B doc', funcs.B)
    a_node = node.Node('A', inspect.signature(funcs.A).return_annotation, 'A doc', funcs.A)
    nodes = {'A': a_node}
    graph.add_dependency(b_node, 'B', nodes, 'A', b_signature.parameters['A'].annotation,
                         base.SimplePythonDataFrameGraphAdapter())
    assert nodes['A'] == b_node.dependencies[0]
    assert b_node.depended_on_by == []
def test_typing_to_primitive_conversion():
    """A producer annotated with a typing type may feed a consumer annotated with the primitive type."""
    funcs = tests.resources.typing_vs_not_typing
    consumer_signature = inspect.signature(funcs.B)
    consumer = node.Node('B', consumer_signature.return_annotation, 'B doc', funcs.B)
    producer = node.Node('A', inspect.signature(funcs.A).return_annotation, 'A doc', funcs.A)
    nodes = {'A': producer}
    graph.add_dependency(consumer, 'B', nodes, 'A', consumer_signature.parameters['A'].annotation,
                         base.SimplePythonDataFrameGraphAdapter())
    assert nodes['A'] == consumer.dependencies[0]
    assert consumer.depended_on_by == []
def test_primitive_to_typing_conversion():
    """A producer annotated with a primitive type may feed a consumer annotated with a typing type."""
    funcs = tests.resources.typing_vs_not_typing
    consumer_signature = inspect.signature(funcs.B2)
    consumer = node.Node('B2', consumer_signature.return_annotation, 'B2 doc', funcs.B2)
    producer = node.Node('A2', inspect.signature(funcs.A2).return_annotation, 'A2 doc', funcs.A2)
    nodes = {'A2': producer}
    graph.add_dependency(consumer, 'B2', nodes, 'A2', consumer_signature.parameters['A2'].annotation,
                         base.SimplePythonDataFrameGraphAdapter())
    assert nodes['A2'] == consumer.dependencies[0]
    assert consumer.depended_on_by == []
def test_throwing_error_on_incompatible_types():
    """add_dependency should raise ValueError when producer and consumer types are incompatible."""
    bad = tests.resources.bad_functions
    d_signature = inspect.signature(bad.D)
    d_node = node.Node('D', d_signature.return_annotation, 'D doc', bad.D)
    nodes = {'C': node.Node('C', inspect.signature(bad.C).return_annotation, 'C doc', bad.C)}
    c_annotation = d_signature.parameters['C'].annotation
    with pytest.raises(ValueError):
        graph.add_dependency(d_node, 'D', nodes, 'C', c_annotation,
                             base.SimplePythonDataFrameGraphAdapter())
def test_add_dependency_user_nodes():
    """add_dependency should create and wire a user-defined input node.

    A depends on b and c, but only one call is made here, so A should end up with 'b' as its
    sole dependency and 'b' should be depended on by A.
    """
    funcs = tests.resources.dummy_functions
    a_signature = inspect.signature(funcs.A)
    a_node = node.Node('A', a_signature.return_annotation, 'A doc', funcs.A)
    nodes = {}
    graph.add_dependency(a_node, 'A', nodes, 'b', a_signature.parameters['b'].annotation,
                         base.SimplePythonDataFrameGraphAdapter())
    # add_dependency creates the user node 'b' and registers it in `nodes`.
    user_node = nodes['b']
    assert user_node == a_node.dependencies[0]
    assert user_node.depended_on_by[0] == a_node
    assert a_node.depended_on_by == []
def test_create_function_graph_simple():
    """create_function_graph on dummy_functions should equal the hand-built node graph."""
    expected = create_testing_nodes()
    actual = graph.create_function_graph(
        tests.resources.dummy_functions,
        config={},
        adapter=base.SimplePythonDataFrameGraphAdapter(),
    )
    assert actual == expected
def create_testing_nodes():
    """Hand-build the nodes that dummy_functions.py should produce.

    Function nodes A, B and C carry a 'module' tag. b and c are external inputs taken from A's
    parameter annotations. Edges: b -> A, c -> A, A -> B, A -> C.
    """
    dummy = tests.resources.dummy_functions
    a_params = inspect.signature(dummy.A).parameters

    def function_node(name, doc):
        func = getattr(dummy, name)
        return node.Node(name,
                         inspect.signature(func).return_annotation,
                         doc,
                         func, tags={'module': 'tests.resources.dummy_functions'})

    def external_node(name):
        return node.Node(name,
                         a_params[name].annotation,
                         node_source=NodeSource.EXTERNAL)

    nodes = {
        'A': function_node('A', 'Function that should become part of the graph - A'),
        'B': function_node('B', 'Function that should become part of the graph - B'),
        'C': function_node('C', ''),
        'b': external_node('b'),
        'c': external_node('c'),
    }
    a_node = nodes['A']
    a_node.dependencies.extend([nodes['b'], nodes['c']])
    a_node.depended_on_by.extend([nodes['B'], nodes['C']])
    for input_name in ('b', 'c'):
        nodes[input_name].depended_on_by.append(a_node)
    for consumer_name in ('B', 'C'):
        nodes[consumer_name].dependencies.append(a_node)
    return nodes
def test_execute():
    """Executes the toy graph, checking results and basic memoization (A feeds two functions), then overrides."""
    adapter = base.SimplePythonDataFrameGraphAdapter()
    testing_nodes = create_testing_nodes()
    inputs = {'b': 2, 'c': 5}
    results = graph.FunctionGraph.execute_static(testing_nodes.values(), inputs, adapter)
    assert results == {'A': 7, 'B': 49, 'C': 14, 'b': 2, 'c': 5}
    overridden = graph.FunctionGraph.execute_static(testing_nodes.values(), inputs, adapter, overrides={'A': 8})
    assert overridden['A'] == 8
def test_get_required_functions():
    """get_upstream_nodes should return only what is needed for the requested outputs of the toy graph."""
    reference = create_testing_nodes()
    fg = graph.FunctionGraph(tests.resources.dummy_functions, config={})
    actual_nodes, actual_ud_nodes = fg.get_upstream_nodes(['A', 'B'])
    # 'C' is not requested, so it must not appear upstream.
    assert actual_nodes == {reference[name] for name in ('A', 'B', 'b', 'c')}
    assert actual_ud_nodes == {reference['b'], reference['c']}
def test_get_impacted_nodes():
    """Exercises getting the downstream subset of the graph for computation on the toy example we have constructed."""
    nodes = create_testing_nodes()
    var_changes = ['A']
    # A itself plus everything downstream of it (B and C). The inputs b and c are upstream of A,
    # so they are not impacted.
    expected_nodes = {nodes['A'], nodes['B'], nodes['C']}
    fg = graph.FunctionGraph(tests.resources.dummy_functions, config={})
    actual_nodes = fg.get_impacted_nodes(var_changes)
    assert actual_nodes == expected_nodes
def test_function_graph_from_multiple_sources():
    """A FunctionGraph built from several modules should contain the union of their nodes."""
    source_modules = (tests.resources.dummy_functions, tests.resources.parametrized_nodes)
    fg = graph.FunctionGraph(*source_modules, config={})
    assert len(fg.get_nodes()) == 8
def test_end_to_end_with_parametrized_nodes():
    """A simple function graph with parametrized nodes should execute end-to-end."""
    fg = graph.FunctionGraph(tests.resources.parametrized_nodes, config={})
    results = fg.execute(fg.get_nodes(), {})
    expected = {f'parametrized_{index}': index for index in (1, 2, 3)}
    assert results == expected
def test_end_to_end_with_parametrized_inputs():
    """Parametrized inputs combined with a config-provided static value should execute end-to-end."""
    static = 3
    fg = graph.FunctionGraph(tests.resources.parametrized_inputs, config={'static_value': static})
    results = fg.execute(fg.get_nodes())
    expected = {
        'input_1': 1,
        'input_2': 2,
        'input_3': 3,
        'output_1': 1 + static,
        'output_2': 2 + static,
        'output_12': 1 + 2 + static,
        'output_123': 1 + 2 + 3 + static,
        'static_value': static,
    }
    assert results == expected
def test_get_required_functions_askfor_config():
    """Tests that requesting a config-provided value as an output makes it the only required user node.

    'a' is supplied through config. Asking for 'a' alongside 'parametrized_1' should yield exactly one
    user node, 'a', and executing that node should return the config value.
    """
    fg = graph.FunctionGraph(tests.resources.parametrized_nodes, config={'a': 1})
    _, user_nodes = fg.get_upstream_nodes(['a', 'parametrized_1'])
    n, = user_nodes  # exactly one user node is expected
    assert n.name == 'a'
    results = fg.execute(user_nodes)
    assert results == {'a': 1}
def test_end_to_end_with_column_extractor_nodes():
    """A function graph whose nodes extract dataframe columns should execute end-to-end."""
    fg = graph.FunctionGraph(tests.resources.extract_column_nodes, config={})
    all_nodes = fg.get_nodes()
    results = fg.execute(all_nodes, {}, {})
    df_expected = tests.resources.extract_column_nodes.generate_df()
    for column in ('col_1', 'col_2'):
        pd.testing.assert_series_equal(results[column], df_expected[column])
    pd.testing.assert_frame_equal(results['generate_df'], df_expected)
    assert all_nodes[0].documentation == 'Function that should be parametrized to form multiple functions'
def test_end_to_end_with_config_modifier():
    """`fn` should resolve to the variant chosen by the `fn_1_version` config value."""
    for version in (1, 2, 3):
        fg = graph.FunctionGraph(tests.resources.config_modifier, config={'fn_1_version': version})
        results = fg.execute(fg.get_nodes(), {}, {})
        assert results['fn'] == f'version_{version}'
def test_non_required_nodes():
    """Tests that a defaulted input can be omitted (A == 10) or supplied through config (A == 11)."""
    # test_default_args is not imported at module level. Import it explicitly instead of relying on
    # some other import having loaded the submodule onto the `tests.resources` package.
    from tests.resources import test_default_args
    fg = graph.FunctionGraph(test_default_args, config={'required': 10})
    results = fg.execute([n for n in fg.get_nodes() if n.node_source == NodeSource.STANDARD], {}, {})
    assert results['A'] == 10
    fg = graph.FunctionGraph(test_default_args, config={'required': 10, 'defaults_to_zero': 1})
    results = fg.execute([n for n in fg.get_nodes() if n.node_source == NodeSource.STANDARD], {}, {})
    assert results['A'] == 11
def test_config_can_override():
    """A value supplied only through config should be returned by execute."""
    fg = graph.FunctionGraph(tests.resources.config_modifier, config={'new_param': 'new_value'})
    out = fg.execute(list(fg.get_nodes()))
    assert out['new_param'] == 'new_value'
def test_function_graph_has_cycles_true():
    """A cyclic graph should be detected, and executing it without a short circuit should recurse forever."""
    fg = graph.FunctionGraph(tests.resources.cyclic_functions, config={'b': 2, 'c': 1})
    all_nodes = fg.get_nodes()
    user_nodes = [n for n in all_nodes if n.user_defined]
    nodes = [n for n in all_nodes if not n.user_defined]
    assert fg.has_cycles(nodes, user_nodes) is True
    required_nodes, required_user_nodes = fg.get_upstream_nodes(['A', 'B', 'C'])
    assert required_nodes == set(nodes + user_nodes)
    assert required_user_nodes == set(user_nodes)
    # Not officially supported, but overriding enough of the cycle does work:
    # result = fg.execute([n for n in nodes if n.name == 'B'], overrides={'A': 1, 'D': 2})
    # assert len(result) == 3
    # assert result['B'] == 3
    b_only = [n for n in nodes if n.name == 'B']
    with pytest.raises(RecursionError):  # nothing short-circuits the cycle
        fg.execute(b_only)
def test_function_graph_has_cycles_false():
    """Tests that an acyclic graph is not reported as having cycles."""
    fg = graph.FunctionGraph(tests.resources.dummy_functions, config={'b': 1, 'c': 2})
    all_nodes = fg.get_nodes()
    # checks it two ways: first with every node in the graph...
    nodes = [n for n in all_nodes if not n.user_defined]
    user_nodes = [n for n in all_nodes if n.user_defined]
    assert fg.has_cycles(nodes, user_nodes) is False
    # ...then with the upstream subset, which is how the driver calls it.
    nodes, user_nodes = fg.get_upstream_nodes(['A', 'B', 'C'])
    assert fg.has_cycles(nodes, user_nodes) is False
def test_function_graph_display():
    """Tests that display writes the expected DOT source to the requested path."""
    import os

    fg = graph.FunctionGraph(tests.resources.dummy_functions, config={'b': 1, 'c': 2})
    defined_nodes = set()
    user_nodes = set()
    for n in fg.get_nodes():
        if n.user_defined:
            user_nodes.add(n)
        else:
            defined_nodes.add(n)
    # Hack of a test, but it works: sort the lines and match them up, because the output file
    # is not deterministic in line order for the same graph.
    expected = sorted(['// Dependency Graph\n',
                       'digraph {\n',
                       '\tA [label=A]\n',
                       '\tC [label=C]\n',
                       '\tB [label=B]\n',
                       '\tc [label="UD: c"]\n',
                       '\tb [label="UD: b"]\n',
                       '\tb -> A\n',
                       '\tc -> A\n',
                       '\tA -> C\n',
                       '\tA -> B\n',
                       '}\n'])
    with tempfile.TemporaryDirectory() as tmp_dir:
        # TemporaryDirectory yields a plain str. `tmp_dir.join(...)` would be str.join and would
        # interleave tmp_dir between the characters of 'test.dot', so build the path with os.path.
        path = os.path.join(tmp_dir, 'test.dot')
        fg.display(defined_nodes, user_nodes, path, {'view': False})
        with open(path, 'r') as dot_file:
            actual = sorted(dot_file.readlines())
    assert actual == expected
def test_create_graphviz_graph():
    """create_graphviz_graph should produce the expected DOT source."""
    fg = graph.FunctionGraph(tests.resources.dummy_functions, config={})
    nodes, user_nodes = fg.get_upstream_nodes(['A', 'B', 'C'])
    # Line order is not deterministic for the same graph, so compare sorted lines.
    expected_lines = [
        '// test-graph',
        'digraph {',
        '\tgraph [ratio=1]',
        '\tB [label=B]',
        '\tA [label=A]',
        '\tc [label=c]',
        '\tC [label=C]',
        '\tb [label=b]',
        '\tb [label="UD: b"]',
        '\tc [label="UD: c"]',
        '\tA -> B',
        '\tb -> A',
        '\tc -> A',
        '\tA -> C',
        '}',
    ]
    expected = sorted(expected_lines)
    digraph = graph.create_graphviz_graph(nodes, user_nodes, 'test-graph', dict(graph_attr={'ratio': '1'}))
    actual = sorted(str(digraph).split('\n'))
    # Drop one empty line, e.g. the one produced by the trailing newline.
    if '' in actual:
        actual.remove('')
    assert actual == expected
def test_create_networkx_graph():
    """create_networkx_graph should contain every upstream node and every dependency edge."""
    fg = graph.FunctionGraph(tests.resources.dummy_functions, config={})
    nodes, user_nodes = fg.get_upstream_nodes(['A', 'B', 'C'])
    digraph = graph.create_networkx_graph(nodes, user_nodes, 'test-graph')
    assert sorted(digraph.nodes) == sorted(['A', 'B', 'C', 'b', 'c'])
    assert sorted(digraph.edges) == sorted([('A', 'B'), ('A', 'C'), ('b', 'A'), ('c', 'A')])
def test_end_to_end_with_layered_decorators_resolves_true():
    """With foo == 'bar', the layered-decorator functions resolve and compute e and f."""
    config = {'foo': 'bar', 'd': 10, 'b': 20}
    fg = graph.FunctionGraph(tests.resources.layered_decorators, config=config)
    out = fg.execute(list(fg.get_nodes()))
    assert len(out) > 0  # config.when resolved, so something was computed
    assert out['e'] == (20 + 10)
    assert out['f'] == (20 + 20)
def test_end_to_end_with_layered_decorators_resolves_false():
    """With foo != 'bar', nothing beyond the config values themselves is produced."""
    config = {'foo': 'not_bar', 'd': 10, 'b': 20}
    fg = graph.FunctionGraph(tests.resources.layered_decorators, config=config)
    out = fg.execute(list(fg.get_nodes()))
    computed = {key: val for key, val in out.items() if key not in config}
    assert computed == {}
def test_combine_inputs_no_collision():
    """Tests FunctionGraph.combine_config_and_inputs merges config and inputs when no keys collide."""
    combined = graph.FunctionGraph.combine_config_and_inputs({'a': 1}, {'b': 2})
    assert combined == {'a': 1, 'b': 2}
def test_combine_inputs_collision():
    """Tests FunctionGraph.combine_config_and_inputs raises ValueError when a key collides
    with a different value."""
    with pytest.raises(ValueError):
        graph.FunctionGraph.combine_config_and_inputs({'a': 1}, {'a': 2})
def test_combine_inputs_collision_2():
    """Tests FunctionGraph.combine_config_and_inputs raises ValueError on a key collision
    even when both values are equal."""
    with pytest.raises(ValueError):
        graph.FunctionGraph.combine_config_and_inputs({'a': 1}, {'a': 1})
def test_extract_columns_executes_once():
    """extract_columns should compute the underlying function only once.

    A somewhat heavy-handed test, but a useful guarantee to pin down.
    """
    counting_module = tests.resources.extract_columns_execution_count
    fg = graph.FunctionGraph(counting_module, config={})
    unique_id = str(uuid.uuid4())
    fg.execute(list(fg.get_nodes()), inputs={'unique_id': unique_id})
    calls = counting_module.outputs[unique_id]
    assert len(calls) == 1  # the function should only be called once
class X:
    """Base class used as a fixture in the subclass / type-matching test cases."""


class Y(X):
    """Subclass of X used as a fixture in the subclass / type-matching test cases."""


custom_type = typing.TypeVar('FOOBAR')
@pytest.mark.parametrize('param_type,required_type,expected', [
    (custom_type, custom_type, True),
    (custom_type, typing.TypeVar('FOO'), False),
    (int, int, True),
    (int, float, False),
    (typing.List[int], typing.List, True),
    (typing.List, typing.List[float], True),
    (typing.List, list, True),
    (typing.Dict, dict, True),
    (dict, typing.Dict, True),
    (list, typing.List, True),
    (typing.List[int], typing.List[float], False),
    (typing.Dict, typing.List, False),
    (typing.Mapping, typing.Dict, True),
    (typing.Mapping, dict, True),
    (dict, typing.Mapping, False),
    (typing.Dict, typing.Mapping, False),
    (typing.Iterable, typing.List, True),
    (typing.Tuple[str, str], typing.Tuple[str, str], True),
    (typing.Tuple[str, str], typing.Tuple[str], False),
    (typing.Tuple[str, str], typing.Tuple, True),
    (typing.Tuple, typing.Tuple[str, str], True),
    (typing.Union[str, str], typing.Union[str, str], True),
    (X, X, True),
    (X, Y, True),
    (Y, X, False),
])
def test_custom_subclass_check(param_type, required_type, expected):
    """Tests graph.custom_subclass_check.

    Note the argument order of the call under test: required_type is passed first, param_type second.
    """
    actual = graph.custom_subclass_check(required_type, param_type)
    assert actual == expected
class TestAdapter(base.SimplePythonDataFrameGraphAdapter):
    """Adapter with a fake type-equivalence hook, used by test_types_match.

    Through this hook, only a `list` input is considered equivalent to a `pd.Series` node type.
    """

    # The Test* name matches pytest's class-collection pattern; this is a helper, not a test class.
    __test__ = False

    @staticmethod
    def check_node_type_equivalence(node_type: typing.Type, input_type: typing.Type) -> bool:
        # fake equivalence function
        return node_type == pd.Series and input_type == list


adapter = TestAdapter()
@pytest.mark.parametrize(
    'adapter,param_type,required_type,expected',
    [
        (None, typing.TypeVar('FOO'), typing.TypeVar('BAR'), False),
        (None, custom_type, custom_type, True),
        (None, int, int, True),
        (adapter, int, float, False),
        (None, typing.Dict, typing.Any, True),
        (None, X, X, True),
        (None, X, Y, True),
        (adapter, pd.Series, pd.Series, True),
        (adapter, list, pd.Series, True),
        (adapter, dict, pd.Series, False),
    ],
)
def test_types_match(adapter, param_type, required_type, expected):
    """graph.types_match(adapter, param_type, required_type) should return `expected`."""
    assert graph.types_match(adapter, param_type, required_type) == expected
def test_end_to_end_with_generics():
    """Functions annotated with generic typing types should execute end-to-end."""
    fg = graph.FunctionGraph(tests.resources.functions_with_generics, config={'b': {}, 'c': 1})
    results = fg.execute(fg.get_nodes())
    pair = ({}, 1)
    expected = {
        'A': pair,
        'B': [pair, pair],
        'C': {'foo': [pair, pair]},
        'D': 1.0,
        'b': {},
        'c': 1,
    }
    assert results == expected
@pytest.mark.parametrize(
    'config,inputs,overrides',
    [
        # no provided inputs at all
        ({}, {}, {}),
        # configs only
        ({'a': 11}, {}, {}),
        ({'b': 13, 'a': 17}, {}, {}),
        ({'b': 19, 'a': 23, 'd': 29, 'f': 31}, {}, {}),
        # inputs only
        ({}, {'a': 37}, {}),
        ({}, {'b': 41, 'a': 43}, {}),
        ({}, {'b': 41, 'a': 43, 'd': 47, 'f': 53}, {}),
        # Overrides only -- TBD whether these should be legitimate (can we override required inputs?).
        # They pass today but are not part of the contract:
        # ({}, {}, {'a': 59}),
        # ({}, {}, {'a': 61, 'b': 67}),
        # ({}, {}, {'a': 71, 'b': 73, 'd': 79, 'f': 83}),
        # a mix of sources
        ({'a': 89}, {'b': 97}, {}),
        ({'a': 101}, {'b': 103, 'd': 107}, {}),
    ],
)
def test_optional_execute(config, inputs, overrides):
    """Executes optional dependencies with various mixes of config, inputs and overrides.

    Be careful when adding cases whose values conflict with one another.
    """
    fg = graph.FunctionGraph(tests.resources.optional_dependencies, config=config)
    # A user input node is listed first to show that request order does not change computation order.
    requested = [fg.nodes['b'], fg.nodes['g']]
    results = fg.execute(requested, inputs=inputs, overrides=overrides)
    provided = {**config, **inputs, **overrides}
    do_all_args = {f'{key}_val': val for key, val in provided.items()}
    expected_results = tests.resources.optional_dependencies._do_all(**do_all_args)
    assert results['g'] == expected_results['g']
def test_optional_input_behavior():
    """Tests that requesting optional user inputs that were not provided does not break, and that
    provided ones are returned whether they come from config, inputs or overrides.
    """
    fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={})
    # nothing passed, so nothing returned
    result = fg.execute([fg.nodes['b'], fg.nodes['a']], inputs={}, overrides={})
    assert result == {}
    # something passed, something returned via config
    fg2 = graph.FunctionGraph(tests.resources.optional_dependencies, config={'a': 10})
    # Execute fg2 with its own nodes so the test does not mix nodes from two different graphs.
    result = fg2.execute([fg2.nodes['b'], fg2.nodes['a']], inputs={}, overrides={})
    assert result == {'a': 10}
    # something passed, something returned via inputs
    result = fg.execute([fg.nodes['b'], fg.nodes['a']], inputs={'a': 10}, overrides={})
    assert result == {'a': 10}
    # something passed, something returned via overrides
    result = fg.execute([fg.nodes['b'], fg.nodes['a']], inputs={}, overrides={'a': 10})
    assert result == {'a': 10}
@pytest.mark.parametrize('node_order', list(permutations('fhi')))
def test_user_input_breaks_if_required_missing(node_order):
    """Execution must fail whatever the request order, because the required input `h` is missing."""
    fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={})
    requested = [fg.nodes[name] for name in node_order]
    with pytest.raises(NotImplementedError):
        fg.execute(requested, inputs={}, overrides={})
@pytest.mark.parametrize('node_order', list(permutations('fhi')))
def test_user_input_does_not_break_if_required_provided(node_order):
    """Execution should succeed in any order once the required `h` is provided (`f` is optional)."""
    fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={'h': 10})
    requested = [fg.nodes[name] for name in node_order]
    assert fg.execute(requested, inputs={}, overrides={}) == {'h': 10, 'i': 17}
def test_optional_donot_drop_none():
    """`None` results must not be dropped; they are passed through to downstream functions.

    This test enshrines the current behavior.
    """
    fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={'h': None})
    # A None provided for 'h' is still passed on rather than removed.
    results = fg.execute([fg.nodes[name] for name in ('h', 'i')], inputs={}, overrides={})
    assert results == {'h': None, 'i': 17}
    fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={})
    requested = [fg.nodes[name] for name in ('j', 'none_result', 'f')]
    results = fg.execute(requested, inputs={}, overrides={})
    assert results == {'j': None, 'none_result': None}  # f is omitted because it is optional
def test_optional_get_required_compile_time():
    """Tests that getting required nodes with optionals at compile time returns everything.

    TODO -- change this to test a different function (compile time) than runtime.
    """
    fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={})
    all_upstream, user_required = fg.get_upstream_nodes(['g'])
    assert len(all_upstream) == 7  # 7 total nodes upstream of 'g'
    assert len(user_required) == 4  # 4 of them are user-input nodes
def test_optional_get_required_runtime():
    """Tests the upstream nodes of 'g' when no runtime inputs are provided."""
    fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={})
    all_upstream, user_required = fg.get_upstream_nodes(['g'], runtime_inputs={})  # nothing provided
    assert len(all_upstream) == 3  # only 3 upstream nodes when no optional inputs are provided
    assert len(user_required) == 0  # no user-input nodes are required
def test_optional_get_required_runtime_with_provided():
    """Tests the upstream nodes of 'g' when the optional input 'b' is provided at runtime."""
    fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={})
    all_upstream, user_required = fg.get_upstream_nodes(['g'], runtime_inputs={'b': 109})  # only 'b' provided
    assert len(all_upstream) == 4  # the 3 baseline upstream nodes plus 'b'
    assert len(user_required) == 1  # 'b' is now a required user-input node
def test_in_driver_function_definitions():
    """A DAG can include functions defined in the driver itself, e.g. in a notebook."""

    def my_function(A: int, b: int, c: int) -> int:
        """Function for input below"""
        return A + b + c

    driver_module = ad_hoc_utils.create_temporary_module(my_function)
    fg = graph.FunctionGraph(tests.resources.dummy_functions, driver_module, config={'b': 3, 'c': 1})
    wanted = {'my_function', 'A'}
    results = fg.execute([n for n in fg.get_nodes() if n.name in wanted])
    assert results == {'A': 4, 'b': 3, 'c': 1, 'my_function': 8}