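"""Tests for the experimental ``CachingGraphAdapter`` in ``hamilton.experimental.h_cache``.

Covers the built-in readers and writers (json, feather, parquet, pickle),
custom format registration, cache hits on repeated execution, and
``force_compute`` recomputation of selected nodes.
"""
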
import logging
from typing import Any

import pandas as pd
import pytest

from hamilton import base
from hamilton.driver import Driver
from hamilton.experimental import h_cache

from tests import nodes


@pytest.fixture
def df():
    return pd.DataFrame(
        {
            "a": [1, 2, 3],
            "b": [4, 5, 6],
        }
    )


def _read_write_df_test(df, filepath, read_fn, write_fn):
    """Round-trips a dataframe through the given writer/reader pair and checks equality."""
    write_fn(df, filepath, "test")
    pd.testing.assert_frame_equal(read_fn(pd.DataFrame(), filepath), df)


def test_feather(df, tmp_path):
    _read_write_df_test(df, tmp_path / "file.feather", h_cache.read_feather, h_cache.write_feather)


def test_parquet(df, tmp_path):
    _read_write_df_test(df, tmp_path / "file.parquet", h_cache.read_parquet, h_cache.write_parquet)


def test_pickle(df, tmp_path):
    _read_write_df_test(df, tmp_path / "file.pkl", h_cache.read_pickle, h_cache.write_pickle)


def test_json(tmp_path):
    data = {
        "a": 1,
        "b": 2,
    }
    filepath = tmp_path / "file.json"
    h_cache.write_json(data, filepath, "test")
    assert h_cache.read_json({}, filepath) == data


def test_unknown_format(tmp_path):
    cache_path = str(tmp_path)
    adapter = h_cache.CachingGraphAdapter(cache_path, base.DictResult())
    data = {"a": 1, "b": 2}
    filepath = tmp_path / "file.xyz"
    with pytest.raises(ValueError) as err:
        adapter._write_cache("xyz", data, filepath, "test")
    assert str(err.value) == "invalid cache format: xyz"
    with pytest.raises(ValueError) as err:
        adapter._read_cache("xyz", None, filepath)
    assert str(err.value) == "invalid cache format: xyz"


def test_init_default_readers_writers(tmp_path):
    cache_path = str(tmp_path / "data")
    adapter = h_cache.CachingGraphAdapter(cache_path, base.DictResult())
    assert adapter.writers == {
        "json": h_cache.write_json,
        "feather": h_cache.write_feather,
        "parquet": h_cache.write_parquet,
        "pickle": h_cache.write_pickle,
    }
    assert adapter.readers == {
        "json": h_cache.read_json,
        "feather": h_cache.read_feather,
        "parquet": h_cache.read_parquet,
        "pickle": h_cache.read_pickle,
    }


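# A custom reader/writer pair registering a "str" cache format below.
# Readers take the expected result type (unused here) and a filepath;
# writers take the data, a filepath, and the node name — mirroring the
# signatures of the built-in h_cache readers and writers asserted above.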
def read_str(expected_type: Any, filepath: str) -> str:
    with open(filepath, "r", encoding="utf8") as file:
        return file.read()


def write_str(data: str, filepath: str, name: str) -> None:
    with open(filepath, "w", encoding="utf8") as file:
        file.write(data)


def test_caching(tmp_path, caplog):
    caplog.set_level(logging.INFO)
    cache_path = str(tmp_path)
    adapter = h_cache.CachingGraphAdapter(
        cache_path,
        base.DictResult(),
        readers={"str": read_str},
        writers={"str": write_str},
    )
    dr = Driver(
        {"initial": "Hello, World!"},
        nodes,
        adapter=adapter,
    )
    assert dr.execute(["both"])["both"] == {
        "lower": "hello, world!",
        "upper": "HELLO, WORLD!",
    }
    assert {"lowercased", "uppercased", "both"} == {rec.message for rec in caplog.records}
    assert read_str("", str(tmp_path / "lowercased.str")) == "hello, world!"
    assert read_str("", str(tmp_path / "uppercased.str")) == "HELLO, WORLD!"
    assert (
        read_str("", str(tmp_path / "both.json"))
        == '{"lower": "hello, world!", "upper": "HELLO, WORLD!"}'
    )
    caplog.clear()
    # the second run is served entirely from the cache, so no nodes log
    assert dr.execute(["both"])["both"] == {
        "lower": "hello, world!",
        "upper": "HELLO, WORLD!",
    }
    assert not {rec.message for rec in caplog.records}
    caplog.clear()
    # now we force-compute one of the dependencies
    adapter = h_cache.CachingGraphAdapter(
        cache_path,
        base.DictResult(),
        readers={"str": read_str},
        writers={"str": write_str},
        force_compute={"lowercased"},
    )
    dr = Driver(
        {"initial": "Hello, World!"},
        nodes,
        adapter=adapter,
    )
    assert dr.execute(["both"])["both"] == {
        "lower": "hello, world!",
        "upper": "HELLO, WORLD!",
    }
    # both the forced node and the node that depends on it are computed
    assert {"lowercased", "both"} == {rec.message for rec in caplog.records}


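# Dispatch test: the nodes module (not shown here) defines DataFrame- and
# Series-valued nodes cached in different formats; the built-in readers use
# their first argument as an expected-type hint, which is why read_json and
# read_parquet below are handed an empty DataFrame or Series.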
def test_dispatch(tmp_path, caplog):
    caplog.set_level(logging.INFO)
    cache_path = str(tmp_path)
    adapter = h_cache.CachingGraphAdapter(
        cache_path,
        base.PandasDataFrameResult(),
    )
    dr = Driver(
        {},
        nodes,
        adapter=adapter,
    )
    actual = dr.execute(["combined"])
    expected = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})
    pd.testing.assert_frame_equal(actual, expected)
    assert {"json df", "json series", "parquet df", "parquet series", "combined"} == {
        rec.message for rec in caplog.records
    }
    pd.testing.assert_frame_equal(
        h_cache.read_json(pd.DataFrame(), str(tmp_path / "my_df.json")), expected[["a", "b"]]
    )
    pd.testing.assert_series_equal(
        h_cache.read_json(pd.Series(), str(tmp_path / "my_series.json")),
        pd.Series([7, 8, 9], name="my_series"),
    )
    pd.testing.assert_frame_equal(
        h_cache.read_parquet(pd.DataFrame(), str(tmp_path / "my_df2.parquet")), expected[["a", "b"]]
    )
    pd.testing.assert_series_equal(
        h_cache.read_parquet(pd.Series(), str(tmp_path / "my_series2.parquet")),
        pd.Series([7, 8, 9], name="my_series2"),
    )
    caplog.clear()
    # the second run only computes the final node; cached inputs are read back
    actual = dr.execute(["combined"])
    pd.testing.assert_frame_equal(actual, expected)
    assert {"combined"} == {rec.message for rec in caplog.records}
    caplog.clear()
    # now we force-compute one of the dependencies
    adapter = h_cache.CachingGraphAdapter(
        cache_path,
        base.PandasDataFrameResult(),
        force_compute={"my_df"},
    )
    dr = Driver(
        {},
        nodes,
        adapter=adapter,
    )
    actual = dr.execute(["combined"])
    pd.testing.assert_frame_equal(actual, expected)
    # both the forced node and the nodes that depend on it are computed
    assert {"json df", "parquet df", "combined"} == {rec.message for rec in caplog.records}