blob: 63de90ea00e833ab2c4517fed29d27da6fc37aa7 [file]
import importlib
import importlib.util
import json
import sys
from typing import Any, Callable, Dict, List, Type

import pytest

from hamilton import ad_hoc_utils, base, driver, settings
from hamilton.base import DefaultAdapter
from hamilton.data_quality.base import DataValidationError, ValidationResult
from hamilton.execution import executors, grouping
from hamilton.function_modifiers import source, value
from hamilton.io.materialization import from_, to

import tests.resources.data_quality
import tests.resources.dynamic_config
import tests.resources.overrides
import tests.resources.test_for_materialization
@pytest.mark.parametrize(
    "driver_factory",
    [
        (lambda: driver.Driver({}, tests.resources.data_quality, adapter=DefaultAdapter())),
        (
            lambda: driver.Builder()
            .enable_dynamic_execution(allow_experimental_mode=True)
            .with_modules(tests.resources.data_quality)
            .with_remote_executor(executors.MultiThreadingExecutor(max_tasks=3))
            .with_local_executor(executors.MultiThreadingExecutor(max_tasks=3))
            .build()
        ),
        (
            lambda: driver.Builder()
            .enable_dynamic_execution(allow_experimental_mode=True)
            .with_remote_executor(executors.SynchronousLocalTaskExecutor())
            .with_local_executor(executors.SynchronousLocalTaskExecutor())
            .with_modules(tests.resources.data_quality)
            .build()
        ),
    ],
    ids=["basic_driver", "driver_v2_multithreading", "driver_v2_synchronous"],
)
def test_data_quality_workflow_passes(driver_factory: Callable[[], driver.Driver]):
    """Runs the data-quality module end-to-end (passing case) across driver flavors.

    Executes every available variable, then locates the single node tagged as
    containing data-quality results and asserts its validation passed.
    """
    dr = driver_factory()
    variables = dr.list_available_variables()
    outputs = dr.execute(
        [var_.name for var_ in variables], inputs={"data_quality_should_fail": False}
    )
    # Nodes carrying DQ results are identified by this tag.
    dq_node_names = [
        var_.name
        for var_ in variables
        if var_.tags.get("hamilton.data_quality.contains_dq_results", False)
    ]
    assert len(dq_node_names) == 1
    (dq_node_name,) = dq_node_names
    validation = outputs[dq_node_name]
    assert isinstance(validation, ValidationResult)
    assert validation.passes is True
@pytest.mark.parametrize(
    "driver_factory",
    [
        (lambda: driver.Driver({}, tests.resources.data_quality)),
        (
            lambda: driver.Builder()
            .enable_dynamic_execution(allow_experimental_mode=True)
            .with_modules(tests.resources.data_quality)
            .with_remote_executor(executors.MultiThreadingExecutor(max_tasks=3))
            .with_local_executor(executors.MultiThreadingExecutor(max_tasks=3))
            .build()
        ),
        (
            lambda: driver.Builder()
            .enable_dynamic_execution(allow_experimental_mode=True)
            .with_remote_executor(executors.SynchronousLocalTaskExecutor())
            .with_local_executor(executors.SynchronousLocalTaskExecutor())
            .with_modules(tests.resources.data_quality)
            .build()
        ),
    ],
    ids=["basic_driver", "driver_v2_multithreading", "driver_v2_synchronous"],
)
def test_data_quality_workflow_fails(driver_factory):
    """Runs the data-quality module end-to-end (failing case) across driver flavors.

    With ``data_quality_should_fail=True`` the execution is expected to raise
    ``DataValidationError``.
    """
    dr = driver_factory()
    requested = [var_.name for var_ in dr.list_available_variables()]
    with pytest.raises(DataValidationError):
        dr.execute(requested, inputs={"data_quality_should_fail": True})
# Adapted from https://stackoverflow.com/questions/41858147/how-to-modify-imported-source-code-on-the-fly
# This is needed to decide whether to import annotations...
def modify_and_import(module_name, package, modification_func):
    """Import ``module_name``, transforming its source text before execution.

    :param module_name: fully-qualified name of the module to (re)import.
    :param package: anchor package passed through to ``importlib.util.find_spec``
        (used to resolve relative module names).
    :param modification_func: callable taking the module's source text (str) and
        returning the (possibly modified) source text to execute.
    :return: the freshly-built module object, also registered in ``sys.modules``.
    :raises ImportError: if the module spec or its source cannot be located.
    """
    spec = importlib.util.find_spec(module_name, package)
    # Guard against a missing module/loader rather than failing with AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not find a loadable spec for module: {module_name}")
    # Renamed from `source` to avoid shadowing the module-level `source` import.
    original_source = spec.loader.get_source(module_name)
    modified_source = modification_func(original_source)
    module = importlib.util.module_from_spec(spec)
    code_object = compile(modified_source, module.__spec__.origin, "exec")
    exec(code_object, module.__dict__)
    # Register so subsequent `import module_name` sees the modified module.
    sys.modules[module_name] = module
    return module
@pytest.mark.parametrize(
    "driver_factory,future_import_annotations",
    [
        (lambda modules: driver.Driver({"region": "US"}, *modules), False),
        (
            lambda modules: driver.Builder()
            .enable_dynamic_execution(allow_experimental_mode=True)
            .with_modules(*modules)
            .with_remote_executor(executors.MultiThreadingExecutor(max_tasks=3))
            .with_local_executor(executors.MultiThreadingExecutor(max_tasks=3))
            .with_config({"region": "US"})
            .build(),
            False,
        ),
        (
            lambda modules: driver.Builder()
            .enable_dynamic_execution(allow_experimental_mode=True)
            .with_remote_executor(executors.SynchronousLocalTaskExecutor())
            .with_local_executor(executors.SynchronousLocalTaskExecutor())
            .with_config({"region": "US"})
            .with_modules(*modules)
            .build(),
            False,
        ),
        (lambda modules: driver.Driver({"region": "US"}, *modules), True),
        (
            lambda modules: driver.Builder()
            .enable_dynamic_execution(allow_experimental_mode=True)
            .with_modules(*modules)
            .with_remote_executor(executors.MultiThreadingExecutor(max_tasks=3))
            .with_local_executor(executors.MultiThreadingExecutor(max_tasks=3))
            .with_config({"region": "US"})
            .build(),
            True,
        ),
        (
            lambda modules: driver.Builder()
            .enable_dynamic_execution(allow_experimental_mode=True)
            .with_remote_executor(executors.SynchronousLocalTaskExecutor())
            .with_local_executor(executors.SynchronousLocalTaskExecutor())
            .with_config({"region": "US"})
            .with_modules(*modules)
            .build(),
            True,
        ),
        (
            lambda modules: driver.Builder()
            .enable_dynamic_execution(allow_experimental_mode=True)
            .with_remote_executor(executors.SynchronousLocalTaskExecutor())
            .with_local_executor(executors.SynchronousLocalTaskExecutor())
            .with_grouping_strategy(grouping.GroupNodesAllAsOne())
            .with_config({"region": "US"})
            .with_modules(*modules)
            .build(),
            False,
        ),
        (
            lambda modules: driver.Builder()
            .enable_dynamic_execution(allow_experimental_mode=True)
            .with_remote_executor(executors.SynchronousLocalTaskExecutor())
            .with_local_executor(executors.SynchronousLocalTaskExecutor())
            .with_grouping_strategy(grouping.GroupNodesByLevel())
            .with_config({"region": "US"})
            .with_modules(*modules)
            .build(),
            False,
        ),
        (
            lambda modules: driver.Builder()
            .enable_dynamic_execution(allow_experimental_mode=True)
            .with_remote_executor(executors.SynchronousLocalTaskExecutor())
            .with_local_executor(executors.SynchronousLocalTaskExecutor())
            .with_grouping_strategy(grouping.GroupNodesIndividually())
            .with_config({"region": "US"})
            .with_modules(*modules)
            .build(),
            False,
        ),
    ],
    ids=[
        "basic_driver_no_future_import",
        "driver_v2_multithreading_no_future_import",
        "driver_v2_synchronous_no_future_import",
        "basic_driver_future_import",
        "driver_v2_multithreading_future_import",
        "driver_v2_synchronous_future_import",
        "driver_v2_group_all_as_one",
        "driver_v2_group_by_level",
        "driver_v2_group_nodes_individually",
    ],
)
def test_smoke_screen_module(driver_factory, future_import_annotations):
    """Runs the smoke-screen module end-to-end across driver/executor/grouping variants.

    The smoke-screen module's source is optionally rewritten to start with
    ``from __future__ import annotations`` so both annotation modes are exercised.
    Expected output values are pinned reference numbers for this module.
    """

    # Proper def instead of a lambda assigned to a name (PEP 8 / E731); the
    # parameter is named `module_source` to avoid shadowing the module-level
    # `source` imported from hamilton.function_modifiers.
    def modification_func(module_source: str) -> str:
        # This tells the smoke screen module whether to use the future import.
        if future_import_annotations:
            return "\n".join(
                ["from __future__ import annotations"] + module_source.splitlines()
            )
        return module_source

    module = modify_and_import(
        "tests.resources.smoke_screen_module", tests.resources, modification_func
    )
    dr = driver_factory([module])
    output_columns = [
        "raw_acquisition_cost",
        "pessimistic_net_acquisition_cost",
        "neutral_net_acquisition_cost",
        "optimistic_net_acquisition_cost",
        "series_with_start_date_end_date",
    ]
    res = dr.execute(
        inputs={"date_range": {"start_date": "20200101", "end_date": "20220801"}},
        final_vars=output_columns,
    )
    # Compare float means against pinned reference values with a small tolerance.
    epsilon = 0.00001
    assert abs(res["raw_acquisition_cost"].mean() - 0.393808) < epsilon
    assert abs(res["pessimistic_net_acquisition_cost"].mean() - 0.420769) < epsilon
    assert abs(res["neutral_net_acquisition_cost"].mean() - 0.405582) < epsilon
    assert abs(res["optimistic_net_acquisition_cost"].mean() - 0.399363) < epsilon
    assert res["series_with_start_date_end_date"].iloc[0] == "date_20200101_date_20220801"
# Config for the dynamic-config tests below: maps each derived column name to the
# pair of upstream columns whose values are summed to produce it.
_dynamic_config = {
    "columns_to_sum_map": {
        "ab": ["a", "b"],
        "cd": ["c", "d"],
        "ed": ["e", "d"],
        "ae": ["a", "e"],
        # You can create some crazy self-referential stuff
        # (entries below reference derived columns defined above).
        "abcd": ["ab", "cd"],
        "abcdcd": ["abcd", "cd"],
    },
    # Enables Hamilton's power-user mode for these tests.
    settings.ENABLE_POWER_USER_MODE: True,
}
@pytest.mark.parametrize(
    "driver_factory",
    [
        (
            lambda: driver.Driver(
                _dynamic_config, tests.resources.dynamic_config, adapter=DefaultAdapter()
            )
        ),
        (
            lambda: driver.Builder()
            .enable_dynamic_execution(allow_experimental_mode=True)
            .with_modules(tests.resources.dynamic_config)
            .with_config(_dynamic_config)
            .with_remote_executor(executors.MultiThreadingExecutor(max_tasks=3))
            .with_local_executor(executors.MultiThreadingExecutor(max_tasks=3))
            .build()
        ),
        (
            lambda: driver.Builder()
            .enable_dynamic_execution(allow_experimental_mode=True)
            .with_modules(tests.resources.dynamic_config)
            .with_config(_dynamic_config)
            .with_remote_executor(executors.SynchronousLocalTaskExecutor())
            .with_local_executor(executors.SynchronousLocalTaskExecutor())
            .build()
        ),
    ],
    ids=["basic_driver", "driver_v2_multithreading", "driver_v2_synchronous"],
)
def test_end_to_end_with_dynamic_config(driver_factory):
    """Executes every column declared in ``_dynamic_config`` and checks first values."""
    dr = driver_factory()
    requested_columns = list(_dynamic_config["columns_to_sum_map"].keys())
    results = dr.execute(final_vars=requested_columns)
    # Expected first value per derived column (same values as the original asserts).
    expected_first_values = {"ab": 2, "cd": 2, "ed": 2, "ae": 2, "abcd": 4, "abcdcd": 6}
    for column, expected in expected_first_values.items():
        assert results[column][0] == expected
class JoinBuilder(base.ResultMixin):
    """Result builder that merges multiple dict outputs into a single dict.

    Used as the ``combine=`` argument in the materialization tests below.
    """

    @staticmethod
    def build_result(**outputs: Any) -> Any:
        """Merge all output dicts into one; later outputs overwrite clashing keys."""
        out = {}
        for output in outputs.values():
            out.update(output)
        return out

    def output_type(self) -> Type:
        """The combined result is a plain dict."""
        return dict

    def input_types(self) -> List[Type]:
        """Each output to be combined is expected to be a dict."""
        return [dict]
def test_materialize_driver_call_end_to_end(tmp_path_factory):
    """Materializes two JSON outputs via ``dr.materialize`` and verifies files + results."""
    dr = driver.Driver({}, tests.resources.test_for_materialization)
    out_path_1 = tmp_path_factory.mktemp("home") / "test_materialize_driver_call_end_to_end_1.json"
    out_path_2 = tmp_path_factory.mktemp("home") / "test_materialize_driver_call_end_to_end_2.json"
    materialization_result, execution_result = dr.materialize(
        to.json(id="materializer_1", dependencies=["json_to_save_1"], path=out_path_1),
        to.json(
            id="materializer_2",
            dependencies=["json_to_save_1", "json_to_save_2"],
            path=out_path_2,
            combine=JoinBuilder(),
        ),
        additional_vars=["json_to_save_1", "json_to_save_2"],
    )
    expected_1 = tests.resources.test_for_materialization.json_to_save_1()
    expected_2 = tests.resources.test_for_materialization.json_to_save_2()
    assert execution_result == {"json_to_save_1": expected_1, "json_to_save_2": expected_2}
    for materializer_id in ("materializer_1", "materializer_2"):
        assert materializer_id in materialization_result
    with open(out_path_1) as f_1, open(out_path_2) as f_2:
        assert json.load(f_1) == expected_1
        # The combined file is the JoinBuilder merge: save_1's keys overwrite save_2's.
        assert json.load(f_2) == {**expected_2, **expected_1}
def test_materialize_and_loaders_end_to_end(tmp_path_factory):
    """Loads input JSON via ``from_.json``, transforms it, and saves via ``to.json``."""

    def processed_data(input_data: dict) -> dict:
        # Copy so the loaded input is not mutated, then mark it processed.
        result = dict(input_data)
        result["processed"] = True
        return result

    in_path = tmp_path_factory.mktemp("home") / "unprocessed_data.json"
    out_path = tmp_path_factory.mktemp("home") / "processed_data.json"
    with open(in_path, "w") as f:
        json.dump({"processed": False}, f)
    temp_module = ad_hoc_utils.create_temporary_module(processed_data)
    dr = driver.Driver({}, temp_module)
    materialization_result, execution_result = dr.materialize(
        from_.json(target="input_data", path=value(in_path)),
        to.json(
            id="materializer",
            dependencies=["processed_data"],
            path=source("output_path"),
            combine=JoinBuilder(),
        ),
        additional_vars=["processed_data"],
        inputs={"output_path": str(out_path)},
    )
    assert execution_result["processed_data"] == {"processed": True}
    assert "materializer" in materialization_result
    with open(out_path) as f:
        assert json.load(f) == {"processed": True}
def test_driver_validate_with_overrides():
    """Executing 'c' with an override for 'b' should yield the expected value."""
    builder = driver.Builder().with_modules(tests.resources.overrides)
    dr = builder.with_adapter(base.DefaultAdapter()).build()
    result = dr.execute(["c"], overrides={"b": 1})
    assert result["c"] == 2
def test_driver_validate_with_overrides_2():
    """Executing 'd' with an override for 'b' should yield the expected value."""
    builder = driver.Builder().with_modules(tests.resources.overrides)
    dr = builder.with_adapter(base.DefaultAdapter()).build()
    result = dr.execute(["d"], overrides={"b": 1})
    assert result["d"] == 3