blob: a151c4ecc2cd1f95f31871dfb07fe5512c8f2ea7 [file] [log] [blame]
import dataclasses
import functools
import inspect
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Tuple
from hamilton import node
from hamilton.graph import FunctionGraph
from hamilton.lifecycle.base import (
BaseDoBuildResult,
BaseDoNodeExecute,
BaseDoValidateInput,
BasePostGraphConstruct,
BasePostGraphExecute,
BasePostNodeExecute,
BasePostTaskExecute,
BasePreDoAnythingHook,
BasePreGraphExecute,
BasePreNodeExecute,
BasePreTaskExecute,
BaseValidateGraph,
BaseValidateNode,
LifecycleAdapterSet,
)
from hamilton.node import Node
@dataclasses.dataclass
class HookCall:
    """Record of one tracked call to a lifecycle method."""

    # Value of the global call counter stored for this call
    sequence_number: int
    # Name of the tracker that recorded the call
    name: str
    # The callable that was invoked
    fn: Callable
    # Arguments the callable was invoked with, keyed by parameter name
    bound_kwargs: Dict[str, Any]
    # Value returned by the callable
    result: Any
class SentinelException(Exception):
    """Distinct exception type that adds no behavior beyond ``Exception``.

    NOTE(review): presumably raised on purpose by tests so it can be told apart
    from unexpected errors -- confirm against callers.
    """
# Module-level counter used to track the order in which multiple lifecycle hooks are called.
# It is global, so the order of calls can be compared across operations and across trackers.
# Tests should assert on relative order, not absolute position, because the value is shared
# and keeps growing with every tracked call.
sequence_number = 0
# This is slightly hacky, but it's a quick way to decorate the lifecycle hooks
class ExtendToTrackCalls:
    """Mixin that records every call made to the lifecycle methods of an instance.

    On construction, each lifecycle method, hook and validator that
    ``LifecycleAdapterSet`` reports for this instance is replaced (on the
    instance) by a wrapper that appends a ``HookCall`` to ``calls`` after the
    underlying method returns. Calls that raise are not recorded.
    """

    def __init__(self, name: str):
        """Wrap this instance's lifecycle methods and set up call tracking.

        :param name: Label stored on every ``HookCall`` this tracker records.
        """
        # We just pass in self, cause why not?
        adapter_set = LifecycleAdapterSet(self)
        lifecycle_steps = {
            **adapter_set.sync_methods,
            **adapter_set.async_methods,
            **adapter_set.sync_hooks,
            **adapter_set.async_hooks,
            **adapter_set.sync_validators,
        }
        # Only the step names (the keys) are needed to look up the bound methods
        for lifecycle_step in lifecycle_steps:
            # We know there's just one in this case
            setattr(self, lifecycle_step, self._wrap_fn(getattr(self, lifecycle_step)))
        self._calls = []
        self._name = name

    def _wrap_fn(self, fn: Callable):
        """Return a wrapper around ``fn`` that records each successful call.

        :param fn: Callable to wrap.
        :return: Callable with the same signature that forwards to ``fn``.
        """
        # The signature does not change between calls, so compute it once
        sig = inspect.signature(fn)

        @functools.wraps(fn)
        def wrapped(*args, **kwargs):
            global sequence_number
            sequence_number += 1
            # Capture the number now: the global may be advanced by other tracked
            # calls (nested or concurrent) before ``fn`` returns.
            call_sequence_number = sequence_number
            bound = sig.bind(*args, **kwargs)
            out = fn(*args, **kwargs)
            self._calls.append(
                HookCall(call_sequence_number, self._name, fn, bound.arguments, out)
            )
            return out

        return wrapped

    @property
    def calls(self) -> List[HookCall]:
        """Calls recorded so far, in the order they were appended."""
        return self._calls

    @property
    def name(self) -> str:
        """Label given to this tracker at construction."""
        return self._name
class TrackingPreDoAnythingHook(BasePreDoAnythingHook, ExtendToTrackCalls):
    """No-op ``pre_do_anything`` hook; call tracking comes from ``ExtendToTrackCalls``."""

    def pre_do_anything(self):
        """Do nothing and return ``None``."""
        return None
class TrackingPostGraphConstructHook(ExtendToTrackCalls, BasePostGraphConstruct):
    """No-op ``post_graph_construct`` hook; call tracking comes from ``ExtendToTrackCalls``."""

    def post_graph_construct(
        self,
        graph: FunctionGraph,
        modules: List[ModuleType],
        config: Dict[str, Any],
    ):
        """Accept the hook arguments, do nothing, and return ``None``."""
        return None
class TrackingPreGraphExecuteHook(ExtendToTrackCalls, BasePreGraphExecute):
    """No-op ``pre_graph_execute`` hook; call tracking comes from ``ExtendToTrackCalls``."""

    def pre_graph_execute(
        self,
        run_id: str,
        graph: FunctionGraph,
        final_vars: List[str],
        inputs: Dict[str, Any],
        overrides: Dict[str, Any],
    ):
        """Accept the hook arguments, do nothing, and return ``None``."""
        return None
class TrackingPreTaskExecuteHook(ExtendToTrackCalls, BasePreTaskExecute):
    """No-op ``pre_task_execute`` hook; call tracking comes from ``ExtendToTrackCalls``."""

    def pre_task_execute(
        self,
        run_id: str,
        task_id: str,
        nodes: List[node.Node],
        inputs: Dict[str, Any],
        overrides: Dict[str, Any],
    ):
        """Accept the hook arguments, do nothing, and return ``None``."""
        return None
class TrackingPreNodeExecuteHook(ExtendToTrackCalls, BasePreNodeExecute):
    """No-op ``pre_node_execute`` hook; call tracking comes from ``ExtendToTrackCalls``."""

    def pre_node_execute(
        self,
        run_id: str,
        node_: Node,
        kwargs: Dict[str, Any],
        task_id: Optional[str] = None,
    ):
        """Accept the hook arguments, do nothing, and return ``None``."""
        return None
class TrackingPostNodeExecuteHook(ExtendToTrackCalls, BasePostNodeExecute):
    """No-op ``post_node_execute`` hook; call tracking comes from ``ExtendToTrackCalls``."""

    def post_node_execute(
        self,
        run_id: str,
        node_: Node,
        kwargs: Dict[str, Any],
        success: bool,
        error: Optional[Exception],
        result: Any,
        task_id: Optional[str] = None,
    ):
        """Accept the hook arguments, do nothing, and return ``None``."""
        return None
class TrackingPostTaskExecuteHook(ExtendToTrackCalls, BasePostTaskExecute):
    """No-op ``post_task_execute`` hook; call tracking comes from ``ExtendToTrackCalls``."""

    def post_task_execute(
        self,
        run_id: str,
        task_id: str,
        nodes: List[node.Node],
        results: Optional[Dict[str, Any]],
        success: bool,
        error: Exception,
    ):
        """Accept the hook arguments, do nothing, and return ``None``."""
        return None
class TrackingPostGraphExecuteHook(ExtendToTrackCalls, BasePostGraphExecute):
    """No-op ``post_graph_execute`` hook; call tracking comes from ``ExtendToTrackCalls``."""

    def post_graph_execute(
        self,
        run_id: str,
        graph: FunctionGraph,
        success: bool,
        error: Optional[Exception],
        results: Optional[Dict[str, Any]],
    ):
        """Accept the hook arguments, do nothing, and return ``None``."""
        return None
class TrackingDoValidateInputMethod(ExtendToTrackCalls, BaseDoValidateInput):
    """Tracked ``do_validate_input`` method that always answers with a preset verdict."""

    def __init__(self, name: str, valid: bool = True):
        """Store the verdict to return.

        :param name: Label passed on to ``ExtendToTrackCalls``.
        :param valid: Value every ``do_validate_input`` call returns.
        """
        super().__init__(name)
        self._valid = valid

    def do_validate_input(self, node_type: type, input_value: Any) -> bool:
        """Return the preset verdict; both arguments are ignored."""
        verdict = self._valid
        return verdict
class TrackingDoNodeExecuteHook(ExtendToTrackCalls, BaseDoNodeExecute):
    """Tracked ``do_node_execute`` method that offsets the result of ``int`` nodes."""

    def __init__(self, name: str, additional_value: int):
        """Store the offset to apply.

        :param name: Label passed on to ``ExtendToTrackCalls``.
        :param additional_value: Amount added to the result of qualifying nodes.
        """
        super().__init__(name)
        self._additional_value = additional_value

    def do_node_execute(
        self,
        run_id: str,
        node_: node.Node,
        kwargs: Dict[str, Any],
        task_id: Optional[str] = None,
    ) -> Any:
        """Call the node with ``kwargs`` and return its (possibly offset) result.

        The offset is added only when the node's type is ``int`` and the node is
        not named ``"n_iters"``; every other node's result is returned as is.
        """
        # Decide before invoking the node, then invoke it exactly once
        apply_offset = node_.type == int and node_.name != "n_iters"
        value = node_(**kwargs)
        if apply_offset:
            return value + self._additional_value
        return value
class TrackingDoBuildResultMethod(ExtendToTrackCalls, BaseDoBuildResult):
    """Tracked ``do_build_result`` method that always returns a preset result."""

    def __init__(self, name: str, result: Any):
        """Store the result to return.

        :param name: Label passed on to ``ExtendToTrackCalls``.
        :param result: Value every ``do_build_result`` call returns.
        """
        super().__init__(name)
        self._result = result

    def do_build_result(self, outputs: Dict[str, Any]) -> Any:
        """Return the preset result; ``outputs`` is ignored."""
        preset = self._result
        return preset
class TrackingValidateNodeValidator(ExtendToTrackCalls, BaseValidateNode):
    """Tracked ``validate_node`` validator that always returns a preset outcome."""

    def __init__(self, name: str, valid: bool, message: Optional[str]):
        """Store the outcome to return.

        :param name: Label passed on to ``ExtendToTrackCalls``.
        :param valid: First element of the tuple ``validate_node`` returns.
        :param message: Second element of the tuple ``validate_node`` returns.
        """
        super().__init__(name)
        self._valid = valid
        self._message = message

    def validate_node(self, *, created_node: node.Node) -> Tuple[bool, Optional[str]]:
        """Return the preset ``(valid, message)`` pair; ``created_node`` is ignored."""
        outcome = (self._valid, self._message)
        return outcome
class TrackingValidateGraphValidator(ExtendToTrackCalls, BaseValidateGraph):
    """Tracked ``validate_graph`` validator that always returns a preset outcome."""

    def __init__(self, name: str, valid: bool, message: Optional[str]):
        """Store the outcome to return.

        :param name: Label passed on to ``ExtendToTrackCalls``.
        :param valid: First element of the tuple ``validate_graph`` returns.
        :param message: Second element of the tuple ``validate_graph`` returns.
        """
        super().__init__(name)
        self._valid = valid
        self._message = message

    def validate_graph(
        self,
        *,
        graph: "FunctionGraph",
        modules: List[ModuleType],
        config: Dict[str, Any],
    ) -> Tuple[bool, Optional[str]]:
        """Return the preset ``(valid, message)`` pair; all arguments are ignored."""
        outcome = (self._valid, self._message)
        return outcome