blob: 5c0043997d98783f0dc86d5bd7bfab2b1fb6e665 [file] [log] [blame]
from typing import Callable, Dict, List, Union
import pytest
from hamilton import node
from hamilton.execution.graph_functions import nodes_between, topologically_sort_nodes
def _create_dummy_dag(
adjacency_map: Dict[str, List[str]], dict_output: bool = False
) -> Union[List[node.Node], Dict[str, node.Node]]:
name_map = {}
for name, dependencies in adjacency_map.items():
input_types = {dep: object for dep in dependencies}
node_ = node.Node(
name=name,
typ=object,
callabl=lambda **kwargs: object(),
input_types=input_types,
)
name_map[name] = node_
nodes = []
for name, dependencies in adjacency_map.items():
node_ = name_map[name]
for dependency in dependencies:
dep = name_map[dependency]
dep.depended_on_by.append(node_)
node_.dependencies.append(name_map[dependency])
nodes.append(node_)
if dict_output:
return {node_.name: node_ for node_ in nodes}
return nodes
def _assert_topologically_sorted(nodes, sorted_nodes):
for node_ in nodes:
for dep in node_.dependencies:
assert sorted_nodes.index(dep.name) < sorted_nodes.index(node_.name)
@pytest.mark.parametrize(
"dag_input, expected_sorted_nodes",
[
({"a": [], "b": ["a"], "c": ["a"], "d": ["b", "c"], "e": ["d"]}, ["a", "b", "c", "d", "e"]),
({}, []),
({"a": []}, ["a"]),
(
{
"a": ["b", "c"],
"b": ["d", "e"],
"c": ["d", "e"],
"d": ["f"],
"e": ["f", "g"],
"f": ["h"],
"g": ["h", "i"],
"h": ["j", "k"],
"i": ["k", "l"],
"j": ["m"],
"k": ["m", "n"],
"l": ["n"],
"m": ["o"],
"n": ["o", "p"],
"o": ["q"],
"p": ["q"],
"q": [],
},
["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q"],
),
],
ids=[
"Simple DAG",
"Empty DAG",
"Single Node DAG",
"Large DAG",
],
)
def test_topologically_sort_nodes(dag_input, expected_sorted_nodes):
nodes = _create_dummy_dag(dag_input)
sorted_nodes = [item.name for item in topologically_sort_nodes(nodes)]
_assert_topologically_sorted(nodes, sorted_nodes)
def _is(name: str) -> Callable[[node.Node], bool]:
def _inner(n: node.Node) -> bool:
return n.name == name
return _inner
@pytest.mark.parametrize(
"dag_repr, expected_nodes_in_between, start_node, end_node",
[
(
{"a": [], "b": ["a"], "c": ["b"]},
{"b"},
"a",
"c",
),
(
{"a": [], "b": ["a"], "c": ["b"], "d": ["c"], "e": ["d"]},
{"b", "c", "d"},
"a",
"e",
),
(
{
"a": ["b", "c"],
"b": ["d"],
"c": ["d"],
"d": ["e"],
"e": ["f"],
"f": ["g"],
"g": ["h"],
"h": ["i", "j"],
"i": ["k"],
"j": ["k"],
"k": ["l"],
"l": ["m", "n"],
"m": ["o"],
"n": ["o"],
"o": ["p"],
"p": ["q"],
"q": [],
},
{"e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p"},
"q",
"d",
),
({"a": [], "b": [], "c": ["a", "b"], "d": "c"}, {"c"}, "a", "d"),
],
ids=["simple_base", "longer_chain", "complex_dag", "subdag_with_external_dep"],
)
def test_find_nodes_between(dag_repr, expected_nodes_in_between, start_node, end_node):
nodes = _create_dummy_dag(dag_repr, dict_output=True)
found_node, in_between = nodes_between(nodes[end_node], _is(start_node))
assert found_node.name == start_node
assert set(in_between) == {nodes[item] for item in expected_nodes_in_between}