blob: 0d4e58ffe99aacaabc4c51990e566c89852c27eb [file] [log] [blame]
import pathlib
from typing import List, Type
import pytest
import yaml
from hamilton.function_modifiers.adapters import resolve_adapter_class
from hamilton.io.data_adapters import DataLoader, DataSaver
from hamilton.plugins.yaml_extensions import PrimitiveTypes, YAMLDataLoader, YAMLDataSaver
TEST_DATA_FOR_YAML = [
(1, "int.yaml"),
("string", "string.yaml"),
(True, "bool.yaml"),
({"key": "value"}, "test.yaml"),
([1, 2, 3], "data.yaml"),
({"nested": {"a": 1, "b": 2}}, "config.yaml"),
]
@pytest.mark.parametrize("data, file_name", TEST_DATA_FOR_YAML)
def test_yaml_loader(tmp_path: pathlib.Path, data, file_name):
path = tmp_path / pathlib.Path(file_name)
with path.open(mode="w") as f:
yaml.dump(data, f)
assert path.exists()
loader = YAMLDataLoader(path)
loaded_data = loader.load_data(type(data))
assert loaded_data[0] == data
@pytest.mark.parametrize("data, file_name", TEST_DATA_FOR_YAML)
def test_yaml_saver(tmp_path: pathlib.Path, data, file_name):
path = tmp_path / pathlib.Path(file_name)
saver = YAMLDataSaver(path)
saver.save_data(data)
assert path.exists()
with path.open("r") as f:
loaded_data = yaml.safe_load(f)
assert data == loaded_data
@pytest.mark.parametrize("data, file_name", TEST_DATA_FOR_YAML)
def test_yaml_loader_and_saver(tmp_path: pathlib.Path, data, file_name):
path = tmp_path / pathlib.Path(file_name)
saver = YAMLDataSaver(path)
saver.save_data(data)
assert path.exists()
loader = YAMLDataLoader(path)
loaded_data = loader.load_data(type(data))
assert data == loaded_data[0]
@pytest.mark.parametrize(
"type_,classes,correct_class",
[(t, [YAMLDataLoader], YAMLDataLoader) for t in PrimitiveTypes],
)
def test_resolve_correct_loader_class(
type_: Type[Type], classes: List[Type[DataLoader]], correct_class: Type[DataLoader]
):
assert resolve_adapter_class(type_, classes) == correct_class
@pytest.mark.parametrize(
"type_,classes,correct_class",
[(t, [YAMLDataSaver], YAMLDataSaver) for t in PrimitiveTypes],
)
def test_resolve_correct_saver_class(
type_: Type[Type], classes: List[Type[DataSaver]], correct_class: Type[DataLoader]
):
assert resolve_adapter_class(type_, classes) == correct_class