blob: a55cb0c453169e2c2b57ace95ac5fc758d772107 [file] [log] [blame]
import inspect
import pathlib
import shelve
import pytest
from hamilton import graph_types, node
from hamilton.lifecycle.default import CacheAdapter
def _callable_to_node(callable) -> node.Node:
return node.Node(
name=callable.__name__,
typ=inspect.signature(callable).return_annotation,
callabl=callable,
)
@pytest.fixture()
def hook(tmp_path: pathlib.Path):
return CacheAdapter(cache_path=str((tmp_path / "cache.db").resolve()))
@pytest.fixture()
def node_a():
"""Default function implementation"""
def A(external_input: int) -> int:
return external_input % 7
return _callable_to_node(A)
@pytest.fixture()
def node_a_body():
"""The function A() has modulo 5 instead of 7"""
def A(external_input: int) -> int:
return external_input % 5
return _callable_to_node(A)
@pytest.fixture()
def node_a_docstring():
"""The function A() has a docstring"""
def A(external_input: int) -> int:
"""This one has a docstring"""
return external_input % 7
return _callable_to_node(A)
def test_set_result(hook: CacheAdapter, node_a: node.Node):
"""Hook sets value and assert value in cache"""
node_hash = graph_types.hash_source_code(node_a.callable, strip=True)
node_kwargs = dict(external_input=7)
cache_key = CacheAdapter.create_key(node_hash, node_kwargs)
result = 2
hook.run_before_graph_execution(graph=graph_types.HamiltonGraph([])) # needed to open cache
hook.cache_vars = [node_a.name]
# used_nodes_hash would be set by run_to_execute() hook
hook.used_nodes_hash[node_a.name] = node_hash
hook.run_after_node_execution(
node_name=node_a.name,
node_kwargs=node_kwargs,
result=result,
)
# run_to_execute_node() hook would get cache
assert hook.cache.get(key=cache_key) == result
def test_get_result(hook: CacheAdapter, node_a: node.Node):
"""Hooks get value and assert cache hit"""
node_hash = graph_types.hash_source_code(node_a.callable, strip=True)
node_kwargs = dict(external_input=7)
result = 2
cache_key = CacheAdapter.create_key(node_hash, node_kwargs)
hook.run_before_graph_execution(graph=graph_types.HamiltonGraph([])) # needed to open cache
hook.cache_vars = [node_a.name]
# run_after_node_execution() would set cache
hook.cache[cache_key] = result
retrieved = hook.run_to_execute_node(
node_name=node_a.name,
node_kwargs=node_kwargs,
node_callable=node_a.callable,
)
assert retrieved == result
def test_append_nodes_history(
hook: CacheAdapter,
node_a: node.Node,
node_a_body: node.Node,
):
"""Assert the CacheHook.nodes_history is growing;
doesn't check for commit to cache
"""
node_name = "A"
node_kwargs = dict(external_input=7)
hook.cache_vars = [node_a.name]
hook.run_before_graph_execution(graph=graph_types.HamiltonGraph([])) # needed to open cache
# node version 1
hook.used_nodes_hash[node_name] = graph_types.hash_source_code(node_a.callable, strip=True)
hook.run_to_execute_node(
node_name=node_name,
node_kwargs=node_kwargs,
node_callable=node_a.callable,
)
# check history
assert len(hook.nodes_history.get(node_name, [])) == 1
# node version 2
hook.used_nodes_hash[node_name] = graph_types.hash_source_code(node_a_body.callable, strip=True)
hook.run_to_execute_node(
node_name=node_name,
node_kwargs=node_kwargs,
node_callable=node_a_body.callable,
)
assert len(hook.nodes_history.get(node_name, [])) == 2
def test_commit_nodes_history(hook: CacheAdapter):
"""Commit node history to cache"""
hook.nodes_history = dict(A=["hash_1", "hash_2"])
hook.run_before_graph_execution(graph=graph_types.HamiltonGraph([])) # needed to open cache
hook.run_after_graph_execution()
# need to reopen the hook cache
with shelve.open(hook.cache_path) as cache:
assert cache.get(CacheAdapter.nodes_history_key) == hook.nodes_history