blob: 0c2456065efdeca1bf38381f135fde969a6be10d [file] [log] [blame]
from types import ModuleType
from typing import Any, Dict, List, Optional, Tuple, Type
import pytest
from hamilton import node
from hamilton.graph import FunctionGraph
from hamilton.lifecycle.base import (
ASYNC_HOOK,
ASYNC_METHOD,
REGISTERED_ASYNC_HOOKS,
REGISTERED_ASYNC_METHODS,
REGISTERED_SYNC_HOOKS,
REGISTERED_SYNC_METHODS,
REGISTERED_SYNC_VALIDATORS,
SYNC_HOOK,
SYNC_METHOD,
SYNC_VALIDATOR,
BaseDoNodeExecute,
BaseDoValidateInput,
BasePostGraphExecute,
BasePreDoAnythingHook,
BaseValidateGraph,
BaseValidateNode,
InvalidLifecycleAdapter,
LifecycleAdapterSet,
ValidationResult,
lifecycle,
validate_lifecycle_adapter_function,
)
from tests.lifecycle.lifecycle_adapters_for_testing import ExtendToTrackCalls
def _valid_function_empty():
pass
def _valid_function_returns_value() -> int:
return 1
def _valid_function_self(self):
pass
def _valid_function_self_kwargs(self, *, a: int, b: int) -> int:
return a + b
@pytest.mark.parametrize(
"fn,returns_value,specified_return_value",
[
(_valid_function_empty, False, None),
(_valid_function_returns_value, True, None),
(_valid_function_returns_value, True, int),
(_valid_function_self, False, None),
(_valid_function_self_kwargs, True, None),
],
)
def test_validate_lifecycle_adapter_function_success(
fn, returns_value: bool, specified_return_value: Optional[Type]
):
"""Test that the lifecycle adapter function works as expected."""
validate_lifecycle_adapter_function(
fn, returns_value=returns_value, return_type=specified_return_value
)
def _function_with_positional_args(a, b):
return a + b
def _function_with_mixed_args(self, a, b=None):
pass
def _function_with_no_return_annotation():
return 42
def _function_with_return_annotation() -> int:
return 42
# Test cases
@pytest.mark.parametrize(
"fn, returns_value",
[
(_valid_function_empty, True),
(_valid_function_returns_value, False),
(_valid_function_self, True),
(_valid_function_self_kwargs, False),
(_function_with_positional_args, True),
(_function_with_mixed_args, False),
(_function_with_no_return_annotation, True),
(_function_with_return_annotation, False),
],
)
def test__validate_lifecycle_adapter_function_failure(fn, returns_value: bool):
"""Test that the lifecycle adapter function fails as expected for invalid cases."""
with pytest.raises(InvalidLifecycleAdapter):
validate_lifecycle_adapter_function(fn, returns_value=returns_value)
def test_base_hook_decorator():
@lifecycle.base_hook("hook_for_testing")
class TestHook:
def hook_for_testing(self):
pass
assert getattr(TestHook, SYNC_HOOK) == "hook_for_testing"
assert "hook_for_testing" in REGISTERED_SYNC_HOOKS
def test_base_hook_decorator_async():
@lifecycle.base_hook("async_hook_for_testing")
class TestHookAsync:
async def async_hook_for_testing(self):
pass
assert getattr(TestHookAsync, ASYNC_HOOK) == "async_hook_for_testing"
assert "async_hook_for_testing" in REGISTERED_ASYNC_HOOKS
def test_base_method_decorator():
@lifecycle.base_method("method_for_testing")
class TestMethod:
def method_for_testing(self) -> int:
pass
assert getattr(TestMethod, SYNC_METHOD) == "method_for_testing"
assert "method_for_testing" in REGISTERED_SYNC_METHODS
def test_base_method_decorator_async():
@lifecycle.base_method("async_method_for_testing")
class TestMethodAsync:
async def async_method_for_testing(self) -> int:
pass
assert getattr(TestMethodAsync, ASYNC_METHOD) == "async_method_for_testing"
assert "async_method_for_testing" in REGISTERED_ASYNC_METHODS
def test_base_validator_decorator():
@lifecycle.base_validator("sync_validator")
class TestValidatorSync:
def sync_validator(self) -> Tuple[bool, Optional[str]]:
return True, None
assert getattr(TestValidatorSync, SYNC_VALIDATOR) == "sync_validator"
assert "sync_validator" in REGISTERED_SYNC_VALIDATORS
def test_lifecycle_adapter_set_with_multiple_hooks():
class MockHook1(BasePreDoAnythingHook, ExtendToTrackCalls):
def pre_do_anything(self):
pass
class MockHook2(BasePreDoAnythingHook, ExtendToTrackCalls):
def pre_do_anything(self):
pass
class MockHook3(BasePostGraphExecute, ExtendToTrackCalls):
def post_graph_execute(
self,
*,
run_id: str,
graph: FunctionGraph,
success: bool,
error: Optional[Exception],
results: Optional[Dict[str, Any]],
):
pass
hook_1 = MockHook1("mock_hook_1")
hook_2 = MockHook2("mock_hook_2")
hook_3 = MockHook3("mock_hook_3")
adapter_set = LifecycleAdapterSet(hook_1, hook_2, hook_3)
assert adapter_set.does_hook("pre_do_anything", is_async=False)
assert adapter_set.does_hook("post_graph_execute", is_async=False)
assert not adapter_set.does_hook("post_graph_execute", is_async=True)
assert not adapter_set.does_hook("pre_node_execute", is_async=False)
assert not adapter_set.does_hook("pre_node_execute", is_async=True)
adapter_set.call_all_lifecycle_hooks_sync("pre_do_anything")
assert len(hook_1.calls) == 1
assert len(hook_2.calls) == 1
assert len(hook_3.calls) == 0
def test_lifecycle_adapter_set_with_single_multi_hook():
class MockMultiHook(BasePreDoAnythingHook, BasePostGraphExecute, ExtendToTrackCalls):
def pre_do_anything(self):
pass
def post_graph_execute(
self,
*,
run_id: str,
graph: "FunctionGraph",
success: bool,
error: Optional[Exception],
results: Optional[Dict[str, Any]],
):
pass
multi_hook = MockMultiHook("mock_multi_hook")
adapter_set = LifecycleAdapterSet(multi_hook)
assert adapter_set.does_hook("pre_do_anything", is_async=False)
assert adapter_set.does_hook("post_graph_execute", is_async=False)
assert not adapter_set.does_hook("pre_node_execute", is_async=False)
assert not adapter_set.does_hook("pre_node_execute", is_async=True)
adapter_set.call_all_lifecycle_hooks_sync("pre_do_anything")
assert len(multi_hook.calls) == 1 # two pre do anything calls
def test_lifecycle_adapter_set_with_multiple_methods():
class MockMethod1(BaseDoNodeExecute, ExtendToTrackCalls):
def do_node_execute(
self,
*,
run_id: str,
node_: node.Node,
kwargs: Dict[str, Any],
task_id: Optional[str] = None,
) -> Any:
return 1
class MockMethod2(BaseDoValidateInput, ExtendToTrackCalls):
def do_validate_input(self, *, node_type: type, input_value: Any) -> bool:
return False
method_1 = MockMethod1("mock_method_1")
method_2 = MockMethod2("mock_method_2")
adapter_set = LifecycleAdapterSet(
method_1,
method_2,
)
assert adapter_set.does_method("do_node_execute", is_async=False)
assert adapter_set.does_method("do_validate_input", is_async=False)
assert (
adapter_set.call_lifecycle_method_sync(
"do_node_execute", run_id="a", node_=None, kwargs={}, task_id=None
)
== 1
)
assert len(method_1.calls) == 1
adapter_set.call_lifecycle_method_sync("do_validate_input", node_type=None, input_value=None)
assert len(method_2.calls) == 1
def test_lifecycle_adapter_set_with_single_multi_method():
class MockMultiMethod(BaseDoNodeExecute, BaseDoValidateInput, ExtendToTrackCalls):
def do_node_execute(
self,
*,
run_id: str,
node_: node.Node,
kwargs: Dict[str, Any],
task_id: Optional[str] = None,
) -> Any:
return 1
def do_validate_input(self, *, node_type: type, input_value: Any) -> bool:
return True
multi_hook = MockMultiMethod("mock_multi_hook")
adapter_set = LifecycleAdapterSet(multi_hook)
assert adapter_set.does_method("do_node_execute", is_async=False)
assert adapter_set.does_method("do_validate_input", is_async=False)
assert not adapter_set.does_method("do_node_execute", is_async=True)
assert (
adapter_set.call_lifecycle_method_sync(
"do_node_execute", run_id="a", node_=None, kwargs={}, task_id=None
)
== 1
)
assert len(multi_hook.calls) == 1
adapter_set.call_lifecycle_method_sync("do_validate_input", node_type=None, input_value=None)
assert len(multi_hook.calls) == 2
def test_lifecycle_adapter_set_with_multiple_validators():
null_adapter_set = LifecycleAdapterSet()
assert not null_adapter_set.does_validation("validate_node")
assert not null_adapter_set.does_validation("validate_graph")
class MockNodeValidator(BaseValidateNode, ExtendToTrackCalls):
def validate_node(self, *, created_node: node.Node):
return False, "Validating node"
class MockGraphValidator(BaseValidateGraph, ExtendToTrackCalls):
def validate_graph(
self, *, graph: "FunctionGraph", modules: List[ModuleType], config: Dict[str, Any]
) -> Tuple[bool, Optional[str]]:
return False, "Validating graph"
adapter_set = LifecycleAdapterSet(
node_validator := MockNodeValidator("mock_node_validator"),
graph_validator := MockGraphValidator("mock_graph_validator"),
)
assert adapter_set.does_validation("validate_node")
assert adapter_set.does_validation("validate_graph")
assert adapter_set.call_all_validators_sync("validate_node", created_node=None) == [
ValidationResult(success=False, error="Validating node", validator=node_validator)
]
assert adapter_set.call_all_validators_sync(
"validate_graph", graph=None, modules=[], config={}
) == [ValidationResult(success=False, error="Validating graph", validator=graph_validator)]
assert len(node_validator.calls) == 1
assert len(graph_validator.calls) == 1
def test_lifecycle_adapter_set_with_single_multi_validator():
null_adapter_set = LifecycleAdapterSet()
assert not null_adapter_set.does_validation("validate_node")
assert not null_adapter_set.does_validation("validate_graph")
class MockMultiValidator(BaseValidateNode, BaseValidateGraph, ExtendToTrackCalls):
def validate_node(self, *, created_node: node.Node) -> Tuple[bool, Optional[str]]:
return False, "Validating node"
def validate_graph(
self, *, graph: "FunctionGraph", modules: List[ModuleType], config: Dict[str, Any]
) -> Tuple[bool, Optional[str]]:
return False, "Validating graph"
adapter_set = LifecycleAdapterSet(
multi_validator := MockMultiValidator("mock_multi_validator"),
)
assert adapter_set.does_validation("validate_node")
assert adapter_set.does_validation("validate_graph")
assert adapter_set.call_all_validators_sync("validate_node", created_node=None) == [
ValidationResult(success=False, error="Validating node", validator=multi_validator)
]
assert adapter_set.call_all_validators_sync(
"validate_graph", graph=None, modules=[], config={}
) == [ValidationResult(success=False, error="Validating graph", validator=multi_validator)]
assert len(multi_validator.calls) == 2