import inspect
from itertools import permutations
import os
import tempfile
import typing
import uuid

import pandas as pd
import pytest

import tests.resources.bad_functions
import tests.resources.config_modifier
import tests.resources.cyclic_functions
import tests.resources.dummy_functions
import tests.resources.extract_column_nodes
import tests.resources.extract_columns_execution_count
import tests.resources.functions_with_generics
import tests.resources.layered_decorators
import tests.resources.optional_dependencies
import tests.resources.parametrized_inputs
import tests.resources.parametrized_nodes
import tests.resources.test_default_args
import tests.resources.typing_vs_not_typing
from hamilton import graph, base, ad_hoc_utils
from hamilton import node
from hamilton.node import NodeSource
| |
| |
def test_find_functions():
    """Checks that module discovery skips `_` functions and ignores anything the module merely imports."""
    module = tests.resources.dummy_functions
    expected = [(name, getattr(module, name)) for name in ('A', 'B', 'C')]
    found = graph.find_functions(module)
    assert len(found) == len(expected)
    assert found == expected
| |
| |
def test_find_functions_from_temporary_function_module():
    """Checks that functions wrapped in a temporary module are discovered like those of a real module."""
    source_fns = [
        tests.resources.dummy_functions.A,
        tests.resources.dummy_functions.B,
        tests.resources.dummy_functions.C,
    ]
    func_module = ad_hoc_utils.create_temporary_module(*source_fns)
    found = graph.find_functions(func_module)
    assert len(found) == len(source_fns)
    assert [name for name, _ in found] == ['A', 'B', 'C']
    # Comparing code objects is an easy way to say the functions are the same.
    assert [fn.__code__ for _, fn in found] == [fn.__code__ for fn in source_fns]
| |
| |
def test_add_dependency_missing_param_type():
    """Tests case that we error if types are missing from a parameter."""
    bad_fn = tests.resources.bad_functions.A
    # Resolve the signature outside the raises-block: inspect.signature can itself
    # raise ValueError, which would let the test pass without Node ever erroring.
    return_type = inspect.signature(bad_fn).return_annotation
    with pytest.raises(ValueError):
        node.Node('A', return_type, 'A doc', bad_fn)  # should error out
| |
| |
def test_add_dependency_missing_function_type():
    """Tests case that we error if types are missing from a function."""
    bad_fn = tests.resources.bad_functions.B
    # Resolve the signature outside the raises-block: inspect.signature can itself
    # raise ValueError, which would let the test pass without Node ever erroring.
    return_type = inspect.signature(bad_fn).return_annotation
    with pytest.raises(ValueError):
        node.Node('B', return_type, 'B doc', bad_fn)  # should error out
| |
| |
def test_add_dependency_strict_node_dependencies():
    """Checks that a dependency between two function nodes is wired up correctly.

    B takes A as a parameter, so after the call A must be B's first dependency,
    while nothing depends on B.
    """
    fn_a = tests.resources.dummy_functions.A
    fn_b = tests.resources.dummy_functions.B
    sig_b = inspect.signature(fn_b)
    node_b = node.Node('B', sig_b.return_annotation, 'B doc', fn_b)
    known_nodes = {
        'A': node.Node('A', inspect.signature(fn_a).return_annotation, 'A doc', fn_a),
    }

    graph.add_dependency(
        node_b,
        'B',
        known_nodes,
        'A',
        sig_b.parameters['A'].annotation,
        base.SimplePythonDataFrameGraphAdapter(),
    )

    assert known_nodes['A'] == node_b.dependencies[0]
    assert node_b.depended_on_by == []
| |
| |
def test_typing_to_primitive_conversion():
    """Checks that a typing-annotated output can feed a parameter annotated with a primitive type."""
    fn_a = tests.resources.typing_vs_not_typing.A
    fn_b = tests.resources.typing_vs_not_typing.B
    sig_b = inspect.signature(fn_b)
    node_b = node.Node('B', sig_b.return_annotation, 'B doc', fn_b)
    known_nodes = {
        'A': node.Node('A', inspect.signature(fn_a).return_annotation, 'A doc', fn_a),
    }

    graph.add_dependency(
        node_b,
        'B',
        known_nodes,
        'A',
        sig_b.parameters['A'].annotation,
        base.SimplePythonDataFrameGraphAdapter(),
    )

    assert known_nodes['A'] == node_b.dependencies[0]
    assert node_b.depended_on_by == []
| |
| |
def test_primitive_to_typing_conversion():
    """Checks that a primitive-annotated output can feed a parameter annotated with a typing type."""
    fn_a2 = tests.resources.typing_vs_not_typing.A2
    fn_b2 = tests.resources.typing_vs_not_typing.B2
    sig_b2 = inspect.signature(fn_b2)
    node_b2 = node.Node('B2', sig_b2.return_annotation, 'B2 doc', fn_b2)
    known_nodes = {
        'A2': node.Node('A2', inspect.signature(fn_a2).return_annotation, 'A2 doc', fn_a2),
    }

    graph.add_dependency(
        node_b2,
        'B2',
        known_nodes,
        'A2',
        sig_b2.parameters['A2'].annotation,
        base.SimplePythonDataFrameGraphAdapter(),
    )

    assert known_nodes['A2'] == node_b2.dependencies[0]
    assert node_b2.depended_on_by == []
| |
| |
def test_throwing_error_on_incompatible_types():
    """Checks that wiring D to C raises ValueError because their types are incompatible."""
    fn_c = tests.resources.bad_functions.C
    fn_d = tests.resources.bad_functions.D
    sig_d = inspect.signature(fn_d)
    node_d = node.Node('D', sig_d.return_annotation, 'D doc', fn_d)
    known_nodes = {
        'C': node.Node('C', inspect.signature(fn_c).return_annotation, 'C doc', fn_c),
    }
    c_param_type = sig_d.parameters['C'].annotation
    with pytest.raises(ValueError):
        graph.add_dependency(
            node_d, 'D', known_nodes, 'C', c_param_type,
            base.SimplePythonDataFrameGraphAdapter())
| |
| |
def test_add_dependency_user_nodes():
    """Checks that an unknown parameter becomes a user-defined node with the right links.

    A depends on b and c, but only the `b` dependency is added here. So A should
    end up with `b` as its dependency, and `b` should be depended on by A.
    """
    fn_a = tests.resources.dummy_functions.A
    sig_a = inspect.signature(fn_a)
    node_a = node.Node('A', sig_a.return_annotation, 'A doc', fn_a)
    known_nodes = {}

    graph.add_dependency(
        node_a,
        'A',
        known_nodes,
        'b',
        sig_a.parameters['b'].annotation,
        base.SimplePythonDataFrameGraphAdapter(),
    )

    # The user node is created on the fly and registered in the node dictionary.
    created = known_nodes['b']
    assert created == node_a.dependencies[0]
    assert created.depended_on_by[0] == node_a
    assert node_a.depended_on_by == []
| |
| |
def test_create_function_graph_simple():
    """Checks that the graph built from dummy_functions matches the hand-built node dictionary."""
    expected = create_testing_nodes()
    built = graph.create_function_graph(
        tests.resources.dummy_functions,
        config={},
        adapter=base.SimplePythonDataFrameGraphAdapter(),
    )
    assert built == expected
| |
| |
def create_testing_nodes():
    """Helper that hand-builds the nodes (and their links) represented in dummy_functions.py."""
    module = tests.resources.dummy_functions
    a_params = inspect.signature(module.A).parameters

    def function_node(name, doc):
        # Each node gets its own tags dict, exactly as if written out by hand.
        fn = getattr(module, name)
        return node.Node(name,
                         inspect.signature(fn).return_annotation,
                         doc,
                         fn,
                         tags={'module': 'tests.resources.dummy_functions'})

    def external_node(name):
        # Inputs of A that are supplied from outside the module.
        return node.Node(name, a_params[name].annotation, node_source=NodeSource.EXTERNAL)

    nodes = {
        'A': function_node('A', 'Function that should become part of the graph - A'),
        'B': function_node('B', 'Function that should become part of the graph - B'),
        'C': function_node('C', ''),
        'b': external_node('b'),
        'c': external_node('c'),
    }

    # A consumes the external inputs b and c ...
    for upstream in ('b', 'c'):
        nodes['A'].dependencies.append(nodes[upstream])
    # ... and is consumed by B and C.
    for downstream in ('B', 'C'):
        nodes['A'].depended_on_by.append(nodes[downstream])
    nodes['b'].depended_on_by.append(nodes['A'])
    nodes['c'].depended_on_by.append(nodes['A'])
    nodes['B'].dependencies.append(nodes['A'])
    nodes['C'].dependencies.append(nodes['A'])
    return nodes
| |
| |
def test_execute():
    """Checks graph execution, basic memoization (A feeds two functions), and overrides."""
    graph_adapter = base.SimplePythonDataFrameGraphAdapter()
    nodes = create_testing_nodes()
    inputs = {'b': 2, 'c': 5}

    computed = graph.FunctionGraph.execute_static(nodes.values(), inputs, graph_adapter)
    assert computed == {'A': 7, 'B': 49, 'C': 14, 'b': 2, 'c': 5}

    # An override replaces the computed value of A.
    overridden = graph.FunctionGraph.execute_static(
        nodes.values(), inputs, graph_adapter, overrides={'A': 8})
    assert overridden['A'] == 8
| |
| |
def test_get_required_functions():
    """Checks that the upstream subset for A and B on the toy graph excludes C."""
    nodes = create_testing_nodes()
    expected_user_nodes = {nodes[name] for name in ('b', 'c')}
    expected_nodes = {nodes[name] for name in ('A', 'B', 'b', 'c')}  # 'C' is skipped
    fg = graph.FunctionGraph(tests.resources.dummy_functions, config={})
    upstream, user_defined = fg.get_upstream_nodes(['A', 'B'])
    assert upstream == expected_nodes
    assert user_defined == expected_user_nodes
| |
| |
def test_get_impacted_nodes():
    """Checks that the impacted set for a change to A is A plus its downstream nodes B and C."""
    nodes = create_testing_nodes()
    expected_nodes = {nodes['B'], nodes['C'], nodes['A']}
    fg = graph.FunctionGraph(tests.resources.dummy_functions, config={})
    impacted = fg.get_impacted_nodes(['A'])
    assert impacted == expected_nodes
| |
| |
def test_function_graph_from_multiple_sources():
    """Checks that a graph built from two modules contains the union of their nodes."""
    fg = graph.FunctionGraph(
        tests.resources.dummy_functions,
        tests.resources.parametrized_nodes,
        config={},
    )
    node_count = len(fg.get_nodes())
    assert node_count == 8  # the union of all of them
| |
| |
def test_end_to_end_with_parametrized_nodes():
    """Checks that a simple function graph with parametrized nodes works end-to-end."""
    fg = graph.FunctionGraph(tests.resources.parametrized_nodes, config={})
    expected = {'parametrized_1': 1, 'parametrized_2': 2, 'parametrized_3': 3}
    assert fg.execute(fg.get_nodes(), {}) == expected
| |
| |
def test_end_to_end_with_parametrized_inputs():
    """Checks end-to-end execution of the parametrized_inputs module with `static_value` set via config."""
    fg = graph.FunctionGraph(tests.resources.parametrized_inputs, config={'static_value': 3})
    expected = {
        'input_1': 1,
        'input_2': 2,
        'input_3': 3,
        'output_1': 1 + 3,
        'output_2': 2 + 3,
        'output_12': 1 + 2 + 3,
        'output_123': 1 + 2 + 3 + 3,
        'static_value': 3,
    }
    results = fg.execute(fg.get_nodes())
    assert results == expected
| |
| |
def test_get_required_functions_askfor_config():
    """Checks that a requested config value comes back as the single user node and executes to that value."""
    fg = graph.FunctionGraph(tests.resources.parametrized_nodes, config={'a': 1})
    nodes, user_nodes = fg.get_upstream_nodes(['a', 'parametrized_1'])
    # Unpacking also asserts that exactly one user node was returned.
    (only_user_node,) = user_nodes
    assert only_user_node.name == 'a'
    assert fg.execute(user_nodes) == {'a': 1}
| |
| |
def test_end_to_end_with_column_extractor_nodes():
    """Checks that a simple function graph with column-extracting nodes works end-to-end."""
    fg = graph.FunctionGraph(tests.resources.extract_column_nodes, config={})
    nodes = fg.get_nodes()
    results = fg.execute(nodes, {}, {})
    df_expected = tests.resources.extract_column_nodes.generate_df()
    for column in ('col_1', 'col_2'):
        pd.testing.assert_series_equal(results[column], df_expected[column])
    pd.testing.assert_frame_equal(results['generate_df'], df_expected)
    first_doc = nodes[0].documentation
    assert first_doc == 'Function that should be parametrized to form multiple functions'
| |
| |
def test_end_to_end_with_config_modifier():
    """Checks that each `fn_1_version` config value selects the matching implementation of `fn`."""
    cases = [(1, 'version_1'), (2, 'version_2'), (3, 'version_3')]
    for version, expected_output in cases:
        config = {
            'fn_1_version': version,
        }
        fg = graph.FunctionGraph(tests.resources.config_modifier, config=config)
        results = fg.execute(fg.get_nodes(), {}, {})
        assert results['fn'] == expected_output
| |
| |
def test_non_required_nodes():
    """Checks that A is 10 when `defaults_to_zero` is left out of config and 11 when it is set to 1."""

    def run_standard_nodes(config):
        # Execute only the nodes whose source is STANDARD.
        fg = graph.FunctionGraph(tests.resources.test_default_args, config=config)
        standard = [n for n in fg.get_nodes() if n.node_source == NodeSource.STANDARD]
        return fg.execute(standard, {}, {})

    assert run_standard_nodes({'required': 10})['A'] == 10
    assert run_standard_nodes({'required': 10, 'defaults_to_zero': 1})['A'] == 11
| |
| |
def test_config_can_override():
    """Checks that a value supplied via config shows up unchanged in the execution results."""
    fg = graph.FunctionGraph(tests.resources.config_modifier, config={'new_param': 'new_value'})
    results = fg.execute(list(fg.get_nodes()))
    assert results['new_param'] == 'new_value'
| |
| |
def test_function_graph_has_cycles_true():
    """Checks that a cyclic graph is detected, and pins what the graph does around the cycle."""
    fg = graph.FunctionGraph(tests.resources.cyclic_functions, config={'b': 2, 'c': 1})
    nodes, user_nodes = [], []
    for candidate in fg.get_nodes():
        if candidate.user_defined:
            user_nodes.append(candidate)
        else:
            nodes.append(candidate)
    assert fg.has_cycles(nodes, user_nodes) is True

    required_nodes, required_user_nodes = fg.get_upstream_nodes(['A', 'B', 'C'])
    assert required_nodes == set(nodes + user_nodes)
    assert required_user_nodes == set(user_nodes)

    # Not officially supported, but executing B with overrides that break the
    # cycle works:
    #   result = fg.execute([n for n in nodes if n.name == 'B'], overrides={'A': 1, 'D': 2})
    #   assert len(result) == 3
    #   assert result['B'] == 3
    node_b_only = [n for n in nodes if n.name == 'B']
    # Without a way to short circuit, execution recurses until it errors.
    with pytest.raises(RecursionError):
        fg.execute(node_b_only)
| |
| |
def test_function_graph_has_cycles_false():
    """Checks that an acyclic graph is reported as having no cycles, checked two ways."""
    fg = graph.FunctionGraph(tests.resources.dummy_functions, config={'b': 1, 'c': 2})

    # First way: split every node of the graph by hand.
    every_node = fg.get_nodes()
    defined = [n for n in every_node if not n.user_defined]
    user_defined = [n for n in every_node if n.user_defined]
    assert fg.has_cycles(defined, user_defined) is False

    # Second way: use the upstream lookup, which is what the driver calls.
    upstream, upstream_user = fg.get_upstream_nodes(['A', 'B', 'C'])
    assert fg.has_cycles(upstream, upstream_user) is False
| |
| |
def test_function_graph_display():
    """Tests that display saves a DOT file containing the expected nodes and edges.

    The file is written into a temporary directory that is removed when the
    test finishes.
    """
    fg = graph.FunctionGraph(tests.resources.dummy_functions, config={'b': 1, 'c': 2})
    defined_nodes = set()
    user_nodes = set()
    for n in fg.get_nodes():
        if n.user_defined:
            user_nodes.add(n)
        else:
            defined_nodes.add(n)
    # hack of a test -- but it works... sort the lines and match them up.
    # why? because for some reason given the same graph, the output file isn't deterministic.
    expected = sorted(['// Dependency Graph\n',
                       'digraph {\n',
                       '\tA [label=A]\n',
                       '\tC [label=C]\n',
                       '\tB [label=B]\n',
                       '\tc [label="UD: c"]\n',
                       '\tb [label="UD: b"]\n',
                       '\tb -> A\n',
                       '\tc -> A\n',
                       '\tA -> C\n',
                       '\tA -> B\n',
                       '}\n'])
    with tempfile.TemporaryDirectory() as tmp_dir:
        # TemporaryDirectory yields a plain str, so `tmp_dir.join(...)` would be
        # str.join and build a garbage path; join path components explicitly.
        path = os.path.join(tmp_dir, 'test.dot')
        fg.display(defined_nodes, user_nodes, path, {'view': False})
        with open(path, 'r') as dot_file:
            actual = sorted(dot_file.readlines())
    assert actual == expected
| |
| |
def test_create_graphviz_graph():
    """Checks that the graphviz digraph for the toy graph renders the expected lines."""
    fg = graph.FunctionGraph(tests.resources.dummy_functions, config={})
    nodes, user_nodes = fg.get_upstream_nodes(['A', 'B', 'C'])
    # Hack of a test, but it works: sort the lines and match them up, because
    # the output for the same graph isn't deterministic.
    expected = sorted([
        '// test-graph',
        'digraph {',
        '\tgraph [ratio=1]',
        '\tB [label=B]',
        '\tA [label=A]',
        '\tc [label=c]',
        '\tC [label=C]',
        '\tb [label=b]',
        '\tb [label="UD: b"]',
        '\tc [label="UD: c"]',
        '\tA -> B',
        '\tb -> A',
        '\tc -> A',
        '\tA -> C',
        '}',
    ])
    digraph = graph.create_graphviz_graph(
        nodes, user_nodes, 'test-graph', dict(graph_attr={'ratio': '1'}))
    actual = sorted(str(digraph).split('\n'))
    # Drop a single empty entry from the rendered lines, if one is present.
    if '' in actual:
        actual.remove('')
    assert actual == expected
| |
| |
def test_create_networkx_graph():
    """Checks that the networkx digraph for the toy graph has the expected nodes and edges."""
    fg = graph.FunctionGraph(tests.resources.dummy_functions, config={})
    nodes, user_nodes = fg.get_upstream_nodes(['A', 'B', 'C'])
    digraph = graph.create_networkx_graph(nodes, user_nodes, 'test-graph')
    assert sorted(digraph.nodes) == sorted(['c', 'B', 'C', 'b', 'A'])
    assert sorted(digraph.edges) == sorted([('c', 'A'), ('b', 'A'), ('A', 'B'), ('A', 'C')])
| |
| |
def test_end_to_end_with_layered_decorators_resolves_true():
    """Checks that with foo='bar' the layered-decorator functions resolve and produce e and f."""
    config = {'foo': 'bar', 'd': 10, 'b': 20}
    fg = graph.FunctionGraph(tests.resources.layered_decorators, config=config)
    results = fg.execute(list(fg.get_nodes()))
    assert len(results) > 0  # test config.when resolves correctly
    assert results['e'] == (20 + 10)
    assert results['f'] == (20 + 20)
| |
| |
def test_end_to_end_with_layered_decorators_resolves_false():
    """Checks that with foo='not_bar' execution returns nothing beyond the config keys."""
    config = {'foo': 'not_bar', 'd': 10, 'b': 20}
    fg = graph.FunctionGraph(tests.resources.layered_decorators, config=config)
    results = fg.execute(list(fg.get_nodes()))
    non_config_results = {}
    for name, value in results.items():
        if name not in config:
            non_config_results[name] = value
    assert non_config_results == {}
| |
| |
def test_combine_inputs_no_collision():
    """Checks that config and inputs with distinct keys are merged into one dictionary."""
    merged = graph.FunctionGraph.combine_config_and_inputs({'a': 1}, {'b': 2})
    assert merged == {'a': 1, 'b': 2}
| |
| |
def test_combine_inputs_collision():
    """Checks that a key present in both config and inputs with different values raises ValueError."""
    config, inputs = {'a': 1}, {'a': 2}
    with pytest.raises(ValueError):
        graph.FunctionGraph.combine_config_and_inputs(config, inputs)
| |
| |
def test_combine_inputs_collision_2():
    """Checks that a key present in both config and inputs raises ValueError even when the values agree."""
    config, inputs = {'a': 1}, {'a': 1}
    with pytest.raises(ValueError):
        graph.FunctionGraph.combine_config_and_inputs(config, inputs)
| |
| |
def test_extract_columns_executes_once():
    """Checks that extract_columns computes the underlying function only once.

    A bit heavy-handed as tests go, but nice to have.
    """
    resource = tests.resources.extract_columns_execution_count
    fg = graph.FunctionGraph(resource, config={})
    # A fresh id keys this run's entry in the resource module's `outputs`.
    run_id = str(uuid.uuid4())
    fg.execute(list(fg.get_nodes()), inputs={'unique_id': run_id})
    assert len(resource.outputs[run_id]) == 1  # It should only be called once
| |
| |
class X:
    """Base class used as a custom type in the type-compatibility tests."""


class Y(X):
    """Subclass of X used as a custom type in the type-compatibility tests."""


# Shared TypeVar so the parametrized cases can compare a TypeVar against itself.
custom_type = typing.TypeVar('FOOBAR')
| |
| |
@pytest.mark.parametrize('param_type,required_type,expected', [
    (custom_type, custom_type, True),
    (custom_type, typing.TypeVar('FOO'), False),
    (int, int, True),
    (int, float, False),
    (typing.List[int], typing.List, True),
    (typing.List, typing.List[float], True),
    (typing.List, list, True),
    (typing.Dict, dict, True),
    (dict, typing.Dict, True),
    # This case used to be listed twice; one copy is enough.
    (list, typing.List, True),
    (typing.List[int], typing.List[float], False),
    (typing.Dict, typing.List, False),
    (typing.Mapping, typing.Dict, True),
    (typing.Mapping, dict, True),
    (dict, typing.Mapping, False),
    (typing.Dict, typing.Mapping, False),
    (typing.Iterable, typing.List, True),
    (typing.Tuple[str, str], typing.Tuple[str, str], True),
    (typing.Tuple[str, str], typing.Tuple[str], False),
    (typing.Tuple[str, str], typing.Tuple, True),
    (typing.Tuple, typing.Tuple[str, str], True),
    (typing.Union[str, str], typing.Union[str, str], True),
    (X, X, True),
    (X, Y, True),
    (Y, X, False),
])
def test_custom_subclass_check(param_type, required_type, expected):
    """Tests custom_subclass_check across primitive, typing-generic and custom class pairs.

    Each case is (param_type, required_type, expected); note the function under
    test is called with `required_type` first and `param_type` second.
    """
    actual = graph.custom_subclass_check(required_type, param_type)
    assert actual == expected
| |
| |
class TestAdapter(base.SimplePythonDataFrameGraphAdapter):
    """Adapter with a fake type-equivalence rule, used by the types_match tests."""

    @staticmethod
    def check_node_type_equivalence(node_type: typing.Type, input_type: typing.Type) -> bool:
        """Fake equivalence: only a pd.Series node type paired with a list input type matches."""
        if not (node_type == pd.Series):
            return False
        return input_type == list


# Shared instance referenced by the parametrized cases of test_types_match.
adapter = TestAdapter()
| |
| |
@pytest.mark.parametrize(
    'adapter,param_type,required_type,expected',
    [
        (None, typing.TypeVar('FOO'), typing.TypeVar('BAR'), False),
        (None, custom_type, custom_type, True),
        (None, int, int, True),
        (adapter, int, float, False),
        (None, typing.Dict, typing.Any, True),
        (None, X, X, True),
        (None, X, Y, True),
        (adapter, pd.Series, pd.Series, True),
        (adapter, list, pd.Series, True),
        (adapter, dict, pd.Series, False),
    ],
)
def test_types_match(adapter, param_type, required_type, expected):
    """Checks types_match both without an adapter and with the fake-equivalence adapter."""
    outcome = graph.types_match(adapter, param_type, required_type)
    assert outcome == expected
| |
| |
def test_end_to_end_with_generics():
    """Checks end-to-end execution of the functions_with_generics module."""
    fg = graph.FunctionGraph(tests.resources.functions_with_generics, config={'b': {}, 'c': 1})
    pair = ({}, 1)
    expected = {
        'A': pair,
        'B': [pair, pair],
        'C': {'foo': [pair, pair]},
        'D': 1.0,
        'b': {},
        'c': 1,
    }
    assert fg.execute(fg.get_nodes()) == expected
| |
| |
@pytest.mark.parametrize(
    'config,inputs,overrides',
    [
        # No provided inputs at all.
        ({}, {}, {}),
        # Config only.
        ({'a': 11}, {}, {}),
        ({'b': 13, 'a': 17}, {}, {}),
        ({'b': 19, 'a': 23, 'd': 29, 'f': 31}, {}, {}),
        # Inputs only.
        ({}, {'a': 37}, {}),
        ({}, {'b': 41, 'a': 43}, {}),
        ({}, {'b': 41, 'a': 43, 'd': 47, 'f': 53}, {}),
        # Overrides only: TBD whether these should be legitimate -- can we
        # override required inputs? They work today but are not part of the
        # contract, so they stay disabled:
        # ({}, {}, {'a': 59}),
        # ({}, {}, {'a': 61, 'b': 67}),
        # ({}, {}, {'a': 71, 'b': 73, 'd': 79, 'f': 83}),
        # A mix of config and inputs.
        ({'a': 89}, {'b': 97}, {}),
        ({'a': 101}, {'b': 103, 'd': 107}, {}),
    ],
)
def test_optional_execute(config, inputs, overrides):
    """Checks execution of optionals with different assortments of configs, inputs and overrides.

    Be careful adding cases whose values conflict between the three sources.
    """
    fg = graph.FunctionGraph(tests.resources.optional_dependencies, config=config)
    # A user input node goes first so that request order cannot affect computation order.
    requested = [fg.nodes['b'], fg.nodes['g']]
    results = fg.execute(requested, inputs=inputs, overrides=overrides)

    # Later sources win, matching {**config, **inputs, **overrides}.
    provided = {}
    for source in (config, inputs, overrides):
        provided.update(source)
    do_all_args = {key + '_val': val for key, val in provided.items()}
    expected_results = tests.resources.optional_dependencies._do_all(**do_all_args)
    assert results['g'] == expected_results['g']
| |
| |
def test_optional_input_behavior():
    """Tests that requesting optional user inputs that are not provided does not break,
    and that when they are provided they are returned.
    """
    fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={})
    requested = [fg.nodes['b'], fg.nodes['a']]
    # nothing passed, so nothing returned
    result = fg.execute(requested, inputs={}, overrides={})
    assert result == {}
    # something passed, something returned via config
    fg2 = graph.FunctionGraph(tests.resources.optional_dependencies, config={'a': 10})
    # Request fg2's own nodes; previously node objects from `fg` were handed to
    # a different graph instance.
    result = fg2.execute([fg2.nodes['b'], fg2.nodes['a']], inputs={}, overrides={})
    assert result == {'a': 10}
    # something passed, something returned via inputs
    result = fg.execute(requested, inputs={'a': 10}, overrides={})
    assert result == {'a': 10}
    # something passed, something returned via overrides
    result = fg.execute(requested, inputs={}, overrides={'a': 10})
    assert result == {'a': 10}
| |
| |
@pytest.mark.parametrize('node_order', list(permutations('fhi')))
def test_user_input_breaks_if_required_missing(node_order):
    """Checks that execution fails, in any request order, because required `h` is not passed in."""
    fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={})
    requested = []
    for name in node_order:
        requested.append(fg.nodes[name])
    with pytest.raises(NotImplementedError):
        fg.execute(requested, inputs={}, overrides={})
| |
| |
@pytest.mark.parametrize('node_order', list(permutations('fhi')))
def test_user_input_does_not_break_if_required_provided(node_order):
    """Checks that any request order works when required `h` is provided and optional `f` is not."""
    fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={'h': 10})
    requested = []
    for name in node_order:
        requested.append(fg.nodes[name])
    outcome = fg.execute(requested, inputs={}, overrides={})
    assert outcome == {'h': 10, 'i': 17}
| |
| |
def test_optional_donot_drop_none():
    """Checks that `None` values are passed through to functions rather than dropped.

    This is here to enshrine the current behavior.
    """
    resource = tests.resources.optional_dependencies

    # A None supplied via config is still handed to the function and reported back.
    fg_with_none = graph.FunctionGraph(resource, config={'h': None})
    requested = [fg_with_none.nodes['h'], fg_with_none.nodes['i']]
    assert fg_with_none.execute(requested, inputs={}, overrides={}) == {'h': None, 'i': 17}

    fg_empty = graph.FunctionGraph(resource, config={})
    requested = [fg_empty.nodes['j'], fg_empty.nodes['none_result'], fg_empty.nodes['f']]
    outcome = fg_empty.execute(requested, inputs={}, overrides={})
    assert outcome == {'j': None, 'none_result': None}  # f omitted because it's optional.
| |
| |
def test_optional_get_required_compile_time():
    """Checks that getting required nodes with optionals at compile time returns everything.

    TODO -- change this to be testing a different function (compile time) than runtime.
    """
    fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={})
    upstream, user_required = fg.get_upstream_nodes(['g'])
    assert len(upstream) == 7  # total nodes upstream of g
    assert len(user_required) == 4  # of which these are user inputs
| |
| |
def test_optional_get_required_runtime():
    """Checks the upstream set for `g` when empty runtime inputs are passed."""
    fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={})
    upstream, user_required = fg.get_upstream_nodes(['g'], runtime_inputs={})  # nothing provided
    assert len(upstream) == 3  # upstream nodes left when nothing is provided
    assert len(user_required) == 0  # no user inputs required
| |
| |
def test_optional_get_required_runtime_with_provided():
    """Checks the upstream set for `g` when `b` is supplied as a runtime input."""
    fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={})
    upstream, user_required = fg.get_upstream_nodes(['g'], runtime_inputs={'b': 109})  # only b provided
    assert len(upstream) == 4  # one more than with no runtime inputs
    assert len(user_required) == 1  # a single user input node
| |
| |
def test_in_driver_function_definitions():
    """Checks that a DAG can include a function defined in the driver, e.g. in a notebook context."""

    def my_function(A: int, b: int, c: int) -> int:
        """Function for input below"""
        return A + b + c

    f_module = ad_hoc_utils.create_temporary_module(my_function)
    fg = graph.FunctionGraph(tests.resources.dummy_functions, f_module, config={'b': 3, 'c': 1})
    wanted_names = ['my_function', 'A']
    requested = [n for n in fg.get_nodes() if n.name in wanted_names]
    assert fg.execute(requested) == {'A': 4, 'b': 3, 'c': 1, 'my_function': 8}