WIP
diff --git a/hamilton/graph.py b/hamilton/graph.py index f5c4516..fabbad0 100644 --- a/hamilton/graph.py +++ b/hamilton/graph.py
@@ -474,7 +474,19 @@ import dataclasses import enum from collections import defaultdict, deque -from typing import Any, Callable, Collection, Dict, List, Optional, Sequence, Set, Tuple, Type +from typing import ( + Any, + Callable, + Collection, + Dict, + Generator, + List, + Optional, + Sequence, + Set, + Tuple, + Type, +) from hamilton import base, node from hamilton.function_modifiers import base as fm_base @@ -711,7 +723,6 @@ # We update the assign group, as well as the node group map assign_group.nodes.append(current_node) node_group_map[current_node.name] = assign_group - print(f"Assigning {current_node.name} to {assign_group.id}") # now we go through dependencies and add them to the queue when they're ready for subsequent_node in current_node.depended_on_by: if ready(subsequent_node): @@ -720,7 +731,6 @@ def topologically_sort_node_groups(node_groups: List[NodeGroup]) -> List[NodeGroup]: - # Big n^2 hack for a quick workaround return node_groups # TODO -- actually sort these, or require that they come in topo-sorted order and delete this function @@ -733,6 +743,7 @@ required_outputs: List[str] dependencies: List[str] # dependencies on prior tasks. + adapter: base.HamiltonGraphAdapter = base.SimplePythonDataFrameGraphAdapter() def filter_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: # TODO -- validate that we have the inputs we need @@ -746,10 +757,53 @@ return overrides_filtered def execute_standard(self, kwargs: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]: - # TODO -- implement this - import pdb + computed = {} + inputs = kwargs + overrides = {**inputs, **overrides} + adapter = self.adapter - pdb.set_trace() + def dfs_traverse( + node_: node.Node, dependency_type: node.DependencyType = node.DependencyType.REQUIRED + ): + if node_.name in computed: + return + if node_.name in overrides: + computed[node_.name] = overrides[node_.name] + return + for n in node_.dependencies: + if n.name not in computed: + _, node_dependency_type = node_.input_types[n.name] + dfs_traverse(n, node_dependency_type) + + if node_.user_defined: + if node_.name not in inputs: + if dependency_type != node.DependencyType.OPTIONAL: + raise NotImplementedError( + f"{node_.name} was expected to be passed in but was not." + ) + return + value = inputs[node_.name] + else: + kwargs = {} # construct signature + for dependency in node_.dependencies: + if dependency.name in computed: + kwargs[dependency.name] = computed[dependency.name] + try: + value = adapter.execute_node(node_, kwargs) + except Exception: + # logger.exception(f"Node {node_.name} encountered an error") + raise + computed[node_.name] = value + + nodes_by_name = {node_.name: node_ for node_ in self.nodes} + + for final_var_node in [nodes_by_name[output] for output in self.required_outputs]: + dep_type = node.DependencyType.REQUIRED + if final_var_node.user_defined: + # from the top level, we don't know if this UserInput is required. So mark as optional. + dep_type = node.DependencyType.OPTIONAL + dfs_traverse(final_var_node, dep_type) + return computed def execute_repeat_parallel( self, kwargs: Dict[str, Any], overrides: Dict[str, Any] @@ -759,12 +813,30 @@ def execute_repeat_sequential( self, kwargs: Dict[str, Any], overrides: Dict[str, Any] ) -> Dict[str, Any]: - # TODO -- implement this - pass + # Quick hack for grabbing the single node that is a generator, and turning the whole + # thing into a generator + # with nested pieces this absolutely won't work + if len(self.required_outputs) != 1: + raise ValueError(f"Expected a single output, got {self.required_outputs}") + (generator_node,) = [ + node_ for node_ in self.nodes if node_.node_type == NodeType.EXPAND_SEQUENTIAL + ] + generator_inputs = { + k: v + for k, v in kwargs.items() + if k in set(item.name for item in generator_node.dependencies) + } + + def task_generator() -> Generator[Dict[str, Any], None, None]: + for result in generator_node.callable(**generator_inputs): + subtask_overrides = {generator_node.name: result, **overrides} + yield self.execute_standard(kwargs, subtask_overrides)[self.required_outputs[0]] + + return {self.required_outputs[0]: task_generator()} def execute_collect(self, kwargs: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]: # TODO -- implement this - pass + return self.execute_standard(kwargs, overrides) def execute(self, kwargs: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]: # TODO -- implement this @@ -914,6 +986,8 @@ if computed is None: computed = {} + nodes_requested = {self.get_nodes()[n] for n in nodes} + nodes_to_execute, _ = self.get_upstream_nodes(nodes, inputs_resolved) # Let's determine the tasks to execute node_to_task_map = {} @@ -939,11 +1013,14 @@ for dep in deps: if dep in nodes_to_execute_in_task: outputs_to_compute.add(dep) + outputs_in_task = set(nodes_requested).intersection(set(node_group.nodes)) + # if node_group.id != "steps" and node_group.id != "root": + # import pdb + # pdb.set_trace() + for output in outputs_in_task: + outputs_to_compute.add(output) if len(outputs_to_compute) == 0: # We don't need to compute anything, so we can skip this task - import pdb - - pdb.set_trace() print("no outputs to compute", node_group) continue @@ -962,28 +1039,32 @@ source_tasks = [task for task in all_tasks if len(task.dependencies) == 0] task_queue = deque(source_tasks) executed = set() - reverse_dependency_map = defaultdict(list) + reverse_dependency_map = defaultdict(set) for task in all_tasks: for dep in task.dependencies: - reverse_dependency_map[dep].append(task) + reverse_dependency_map[dep].add(task.id) results = inputs_resolved while len(task_queue) > 0: current_task = task_queue.popleft() - print(f"Executing task {current_task}") - filtered_inputs = current_task.filter_inputs(results) - filtered_overrides = current_task.filter_overrides(overrides) + print(f"Executing task {current_task.id}") + filtered_inputs = results + filtered_overrides = overrides + # filtered_inputs = current_task.filter_inputs(results) + # filtered_overrides = current_task.filter_overrides(overrides) # we just need to execute, as execution will give us all current results, # including the generators shaped as they're expected task_results = current_task.execute(filtered_inputs, filtered_overrides) executed.add(current_task.id) results = {**results, **task_results} for up_next in reverse_dependency_map[current_task.id]: + # hacky but whatever, for now + (up_next,) = [item for item in all_tasks if item.id == up_next] for dependency in up_next.dependencies: if dependency not in executed: break task_queue.append(up_next) output_nodes = {n.name for n in nodes_to_execute} - return {key: value for key, value in results if key in output_nodes} + return {key: value for key, value in results.items() if key in output_nodes} # return FunctionGraph.execute_static( # nodes=nodes,
diff --git a/tests/resources/generators/sequential_linear.py b/tests/resources/generators/sequential_linear.py index aa592a7..7f06c56 100644 --- a/tests/resources/generators/sequential_linear.py +++ b/tests/resources/generators/sequential_linear.py
@@ -9,16 +9,19 @@ # expand def steps(number_of_steps: int) -> Sequential[int]: for i in range(number_of_steps): + print("yielding step {}".format(i)) yield i # process def step_squared(steps: int) -> int: + print("squaring step {}".format(steps)) return steps**2 # join def sum_step_squared(step_squared: Collect[int]) -> int: + print("summing step squared") out = 0 for step in step_squared: out += step @@ -27,6 +30,7 @@ # final def final(sum_step_squared: int) -> int: + print("finalizing") return sum_step_squared
diff --git a/tests/test_generators.py b/tests/test_generators.py index 63ecf80..96c7976 100644 --- a/tests/test_generators.py +++ b/tests/test_generators.py
@@ -31,7 +31,7 @@ assert len(groups) == 3 # One for the precursor, one for the repeat, and one for the collect fn_graph = graph.FunctionGraph(groups, [], {}) results = fn_graph.execute(["final"]) - print(results) + assert results["final"] == sequential_linear._calc() import pdb pdb.set_trace()