blob: 0e6078bb6a07cb6c642b6d331db567534b3738af [file]
from typing import Dict, Optional

import pytest

from hamilton.caching.cache_key import create_cache_key
from hamilton.caching.fingerprinting import hash_value
from hamilton.caching.stores.memory import InMemoryMetadataStore
from hamilton.caching.stores.sqlite import SQLiteMetadataStore
from hamilton.graph_types import hash_source_code
# If you're adding a new `MetadataStore` implementation, add it to this list
# and add a matching branch to `_instantiate_metadata_store()`, which raises
# `ValueError` for any class it does not know how to construct.
# Your implementation should successfully pass all the tests.
# Implementation-specific tests should be added to `test_{implementation}.py`.
# The tests below are parametrized over these classes through the
# `metadata_store` fixture (`indirect=True`).
IMPLEMENTATIONS = [SQLiteMetadataStore, InMemoryMetadataStore]
def _instantiate_metadata_store(metadata_store_cls, tmp_path):
    """Build a metadata store instance for the requested implementation class.

    ``SQLiteMetadataStore`` is constructed with ``path=tmp_path``, while
    ``InMemoryMetadataStore`` is constructed without arguments.

    :param metadata_store_cls: the implementation class to instantiate.
    :param tmp_path: location forwarded to ``SQLiteMetadataStore`` as ``path``.
    :return: a new instance of the requested implementation.
    :raises ValueError: if ``metadata_store_cls`` is not one of the classes handled here.
    """
    # One guard clause per known implementation; anything else falls through to the error.
    if metadata_store_cls == SQLiteMetadataStore:
        return SQLiteMetadataStore(path=tmp_path)

    if metadata_store_cls == InMemoryMetadataStore:
        return InMemoryMetadataStore()

    message = f"Class `{metadata_store_cls}` isn't defined in `_instantiate_metadata_store()`"
    raise ValueError(message)
def _mock_cache_key(
    node_name: str = "foo",
    code_version: str = "FOO-1",
    dependencies_data_versions: Optional[Dict[str, str]] = None,
) -> str:
    """Utility to create a valid cache key from mock values.

    This is helpful because ``code_version`` and the ``data_version`` values found in
    ``dependencies_data_versions`` must respect a specific encoding.

    :param node_name: node name, passed unchanged to ``create_cache_key()``.
    :param code_version: raw value that is hashed with ``hash_source_code()`` before use.
    :param dependencies_data_versions: mapping of dependency name to a raw value; each value
        is hashed with ``hash_value()``. ``None`` (the default) means no dependencies.
    :return: the cache key returned by ``create_cache_key()``.
    """
    if dependencies_data_versions is None:
        dependencies_data_versions = {}

    return create_cache_key(
        node_name=node_name,
        code_version=hash_source_code(code_version),
        # Iterate over ``.items()``: iterating the dict itself yields only its keys, which
        # breaks (or silently mis-unpacks 2-character keys) for any non-empty mapping.
        dependencies_data_versions={
            name: hash_value(value) for name, value in dependencies_data_versions.items()
        },
    )
@pytest.fixture
def metadata_store(request, tmp_path):
    """Yield a metadata store built from the class supplied through indirect parametrization.

    ``request.param`` holds the implementation class; it is forwarded together with
    ``tmp_path`` to ``_instantiate_metadata_store()``.
    """
    store = _instantiate_metadata_store(request.param, tmp_path)
    yield store
@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True)
def test_initialize_empty(metadata_store):
    """A store that was only initialized reports a size of zero."""
    run_id = "test-run-id"
    metadata_store.initialize(run_id=run_id)

    size_after_init = metadata_store.size
    assert size_after_init == 0
@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True)
def test_not_empty_after_set(metadata_store):
    """The store size is positive once one entry has been set."""
    key = _mock_cache_key()
    current_run = "test-run-id"
    metadata_store.initialize(run_id=current_run)

    metadata_store.set(cache_key=key, data_version="foo-a", run_id=current_run)

    assert metadata_store.size > 0
@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True)
def test_set_doesnt_produce_duplicates(metadata_store):
    """Setting the same entry twice leaves the store size at one."""
    entry = {
        "cache_key": _mock_cache_key(),
        "data_version": "foo-a",
        "run_id": "test-run-id",
    }
    metadata_store.initialize(run_id=entry["run_id"])

    # Write the identical entry twice; the size must be 1 after each write.
    for _ in range(2):
        metadata_store.set(**entry)
        assert metadata_store.size == 1
@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True)
def test_get_miss_returns_none(metadata_store):
    """Looking up a cache key that was never set returns ``None``."""
    unknown_key = _mock_cache_key()
    metadata_store.initialize(run_id="test-run-id")

    assert metadata_store.get(cache_key=unknown_key) is None
@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True)
def test_set_and_get_with_empty_dependencies(metadata_store):
    """A data version stored under a dependency-free cache key is returned by ``get()``."""
    key = _mock_cache_key()
    expected_version = "foo-a"
    current_run = "test-run-id"
    metadata_store.initialize(run_id=current_run)

    metadata_store.set(cache_key=key, data_version=expected_version, run_id=current_run)

    assert metadata_store.get(cache_key=key) == expected_version
@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True)
def test_get_run_ids_returns_ordered_list(metadata_store):
    """``get_run_ids()`` starts empty and then lists run ids in initialization order."""
    assert metadata_store.get_run_ids() == []

    expected_run_ids = ["foo", "bar", "baz"]
    for run_id in expected_run_ids:
        metadata_store.initialize(run_id=run_id)

    assert metadata_store.get_run_ids() == expected_run_ids
@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True)
def test_get_run_results_include_cache_key_and_data_version(metadata_store):
    """``get_run()`` returns a list holding one dict with the stored key and data version."""
    key = _mock_cache_key()
    version = "foo-a"
    current_run = "test-run-id"
    metadata_store.initialize(run_id=current_run)
    metadata_store.set(cache_key=key, data_version=version, run_id=current_run)

    records = metadata_store.get_run(run_id=current_run)

    # Check the container shape first, then the content of its single record.
    assert isinstance(records, list)
    assert len(records) == 1
    record = records[0]
    assert isinstance(record, dict)
    assert record["cache_key"] == key
    assert record["data_version"] == version
@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True)
def test_get_run_returns_empty_list_if_run_started_but_no_execution_recorded(metadata_store):
    """``get_run()`` returns an empty list for an initialized run with nothing set."""
    current_run = "test-run-id"
    metadata_store.initialize(run_id=current_run)

    assert metadata_store.get_run(run_id=current_run) == []
@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True)
def test_get_run_raises_error_if_run_id_not_found(metadata_store):
    """``get_run()`` raises ``IndexError`` for a run id that was never initialized."""
    known_run_id, unknown_run_id = "test-run-id", "foo"
    metadata_store.initialize(run_id=known_run_id)

    with pytest.raises(IndexError):
        metadata_store.get_run(run_id=unknown_run_id)