blob: fd74a08f4ac06d39c632eaada2287972b78068c5 [file] [log] [blame]
from unittest import mock
import pandas as pd
import pytest
from hamilton import base, node
from hamilton.driver import (
Builder,
Driver,
InvalidExecutorException,
TaskBasedGraphExecutor,
Variable,
)
from hamilton.execution import executors
from hamilton.io.materialization import to
import tests.resources.cyclic_functions
import tests.resources.dummy_functions
import tests.resources.dynamic_parallelism.parallel_linear_basic
import tests.resources.tagging
import tests.resources.test_default_args
import tests.resources.very_simple_dag
"""This file tests driver capabilities.
Anything involving execution is tested for multiple executors/driver configuration.
Anything not involving execution is tested for just the single driver configuration.
TODO -- move any execution tests to tests the graph executor capabilities on their own.
"""
@pytest.mark.parametrize(
"driver_factory",
[
(lambda: Driver({"a": 1})),
(
lambda: Builder()
.enable_dynamic_execution(allow_experimental_mode=True)
.with_remote_executor(executors.SynchronousLocalTaskExecutor())
.with_config({"a": 1})
.build()
),
],
)
def test_driver_validate_input_types(driver_factory):
dr = driver_factory()
results = dr.raw_execute(["a"])
assert results == {"a": 1}
@pytest.mark.parametrize(
"driver_factory",
[
(lambda: Driver({}, tests.resources.very_simple_dag)),
(
lambda: Builder()
.enable_dynamic_execution(allow_experimental_mode=True)
.with_modules(tests.resources.very_simple_dag)
.with_remote_executor(executors.SynchronousLocalTaskExecutor())
.build()
),
],
)
def test_driver_validate_runtime_input_types(driver_factory):
dr = driver_factory()
results = dr.raw_execute(["b"], inputs={"a": 1})
assert results == {"b": 1}
@pytest.mark.parametrize(
"driver_factory",
[
(lambda: Driver({}, tests.resources.cyclic_functions)),
(
lambda: Builder()
.enable_dynamic_execution(allow_experimental_mode=True)
.with_modules(tests.resources.cyclic_functions)
.with_remote_executor(executors.SynchronousLocalTaskExecutor())
.build()
),
],
)
def test_driver_has_cycles_true(driver_factory):
"""Tests that we don't break when detecting cycles from the driver."""
dr = driver_factory()
assert dr.has_cycles(["C"])
# This is possible -- but we don't want to officially support it. Here for documentation purposes.
# def test_driver_cycles_execute_override():
# """Tests that we short circuit a cycle by passing in overrides."""
# dr = Driver({}, tests.resources.cyclic_functions, adapter=base.DefaultAdapter())
# result = dr.execute(['C'], overrides={'D': 1}, inputs={'b': 2, 'c': 2})
# assert result['C'] == 34
@pytest.mark.parametrize(
"driver_factory",
[
(lambda: Driver({}, tests.resources.cyclic_functions)),
# TODO -- fix erroring out when we try to run a driver with cycles
# should display a better error
# (lambda: Builder()
# .enable_parallelizable_type(allow_experimental_mode=True)
# .with_modules(tests.resources.cyclic_functions)
# .with_remote_executor(executors.SynchronousLocalTaskExecutor())
# .with_adapter(base.DefaultAdapter())
# .build())
],
)
def test_driver_cycles_execute_recursion_error(driver_factory):
"""Tests that we throw a recursion error when we try to execute over a DAG that isn't a DAG."""
dr = driver_factory()
with pytest.raises(RecursionError):
dr.execute(["C"], inputs={"b": 2, "c": 2})
def test_driver_variables_exposes_tags():
dr = Driver({}, tests.resources.tagging)
tags = {var.name: var.tags for var in dr.list_available_variables()}
assert tags["a"] == {"module": "tests.resources.tagging", "test": "a"}
assert tags["b"] == {"module": "tests.resources.tagging", "test": "b_c"}
assert tags["c"] == {"module": "tests.resources.tagging", "test": "b_c"}
assert tags["d"] == {"module": "tests.resources.tagging", "test_list": ["us", "uk"]}
@pytest.mark.parametrize(
"filter,expected",
[
(None, {"a", "b_c", "b", "c", "d"}), # no filter
({}, {"a", "b_c", "b", "c", "d"}), # empty filter
({"test": "b_c"}, {"b", "c"}),
({"test": None}, {"a", "b", "c"}),
({"module": "tests.resources.tagging"}, {"a", "b_c", "b", "c", "d"}),
({"test_list": "us"}, {"d"}),
({"test_list": "uk"}, {"d"}),
({"test_list": ["uk"]}, {"d"}),
({"module": "tests.resources.tagging", "test": "b_c"}, {"b", "c"}),
({"test_list": ["nz", "uk"]}, {"d"}),
({"test_list": ["us", "uk"]}, {"d"}),
({"test_list": ["uk", "us"]}, {"d"}),
({"test": ["b_c"]}, {"b", "c"}),
({"test": ["b_c", "foo"]}, {"b", "c"}),
],
ids=[
"filter with None passed",
"filter with empty filter",
"filter by single tag with extract decorator",
"filter with None value",
"filter with specific value",
"filter tag with list values - value 1",
"filter tag with list values - value 2",
"filter tag with list values - query is single node list",
"filter with two filter clauses",
"filter with with list values not exact OR interpretation",
"filter with with list values exact",
"filter with with list values exact order invariant",
"filter with with list values edge case one item match",
"filter with with list values OR interpretation",
],
)
def test_driver_variables_filters_tags(filter, expected):
dr = Driver({}, tests.resources.tagging)
actual = {var.name for var in dr.list_available_variables(tag_filter=filter)}
assert actual == expected
def test_driver_variables_filters_tags_error():
dr = Driver({}, tests.resources.tagging)
with pytest.raises(ValueError):
# non string value is not allowed
dr.list_available_variables(tag_filter={"test": 1234})
with pytest.raises(ValueError):
# empty list shouldn't be allowed
dr.list_available_variables(tag_filter={"test": []})
def test_driver_variables_external_input():
dr = Driver({}, tests.resources.very_simple_dag)
input_types = {var.name: var.is_external_input for var in dr.list_available_variables()}
assert input_types["a"] is True
assert input_types["b"] is False
def test_driver_variables_exposes_original_function():
dr = Driver({}, tests.resources.very_simple_dag)
originating_functions = {
var.name: var.originating_functions for var in dr.list_available_variables()
}
assert originating_functions["b"] == (tests.resources.very_simple_dag.b,)
assert originating_functions["a"] == (tests.resources.very_simple_dag.b,) # a is an input
@mock.patch("hamilton.telemetry.send_event_json")
def test_capture_constructor_telemetry_disabled(send_event_json):
"""Tests that we don't do anything if telemetry is disabled."""
send_event_json.return_value = ""
Driver({}, tests.resources.tagging) # this will exercise things underneath.
assert send_event_json.called is False
@mock.patch("hamilton.telemetry.get_adapter_name")
@mock.patch("hamilton.telemetry.send_event_json")
@mock.patch("hamilton.telemetry.g_telemetry_enabled", True)
def test_capture_constructor_telemetry_error(send_event_json, get_adapter_name):
"""Tests that we don't error if an exception occurs"""
get_adapter_name.side_effect = ValueError("TELEMETRY ERROR")
Driver({}, tests.resources.tagging) # this will exercise things underneath.
assert send_event_json.called is False
@mock.patch("hamilton.telemetry.send_event_json")
@mock.patch("hamilton.telemetry.g_telemetry_enabled", True)
def test_capture_constructor_telemetry_none_values(send_event_json):
"""Tests that we don't error if there are none values"""
Driver({}, None, None) # this will exercise things underneath.
assert send_event_json.called is True
@mock.patch("hamilton.telemetry.send_event_json")
@mock.patch("hamilton.telemetry.g_telemetry_enabled", True)
def test_capture_constructor_telemetry(send_event_json):
"""Tests that we send an event if we could. Validates deterministic parts."""
Driver({}, tests.resources.very_simple_dag)
# assert send_event_json.called is True
assert len(send_event_json.call_args_list) == 1 # only called once
# check contents of what it was called with:
send_event_json_call = send_event_json.call_args_list[0]
actual_event_dict = send_event_json_call[0][0]
assert actual_event_dict["api_key"] == "phc_mZg8bkn3yvMxqvZKRlMlxjekFU5DFDdcdAsijJ2EH5e"
assert actual_event_dict["event"] == "os_hamilton_run_start"
# validate schema
expected_properties = {
"os_type",
"os_version",
"python_version",
"distinct_id",
"hamilton_version",
"telemetry_version",
"number_of_nodes",
"number_of_modules",
"number_of_config_items",
"decorators_used",
"graph_adapter_used",
"result_builder_used",
"driver_run_id",
"error",
"graph_executor_class",
"lifecycle_adapters_used",
}
actual_properties = actual_event_dict["properties"]
assert set(actual_properties.keys()) == expected_properties
# validate static parts
assert actual_properties["error"] is None
assert actual_properties["number_of_nodes"] == 2 # b, and input a
assert actual_properties["number_of_modules"] == 1
assert actual_properties["number_of_config_items"] == 0
assert actual_properties["number_of_config_items"] == 0
assert actual_properties["graph_adapter_used"] == "deprecated -- see lifecycle_adapters_used"
assert actual_properties["result_builder_used"] == "hamilton.base.PandasDataFrameResult"
assert actual_properties["lifecycle_adapters_used"] == ["hamilton.base.PandasDataFrameResult"]
@mock.patch("hamilton.telemetry.send_event_json")
@pytest.mark.parametrize(
"driver_factory",
[
(lambda: Driver({}, tests.resources.very_simple_dag)),
(
lambda: Builder()
.enable_dynamic_execution(allow_experimental_mode=True)
.with_modules(tests.resources.very_simple_dag)
.with_adapter(base.SimplePythonGraphAdapter(base.PandasDataFrameResult()))
.with_remote_executor(executors.SynchronousLocalTaskExecutor())
.build()
),
],
)
def test_capture_execute_telemetry_disabled(send_event_json, driver_factory):
"""Tests that we don't do anything if telemetry is disabled."""
dr = driver_factory()
results = dr.execute(["b"], inputs={"a": 1})
expected = pd.DataFrame([{"b": 1}])
pd.testing.assert_frame_equal(results, expected)
assert send_event_json.called is False
@mock.patch("hamilton.telemetry.send_event_json")
@mock.patch("hamilton.telemetry.g_telemetry_enabled", True)
@pytest.mark.parametrize(
"driver_factory",
[
(lambda: Driver({}, tests.resources.very_simple_dag)),
(
lambda: Builder()
.enable_dynamic_execution(allow_experimental_mode=True)
.with_modules(tests.resources.very_simple_dag)
.with_adapter(base.SimplePythonGraphAdapter(base.PandasDataFrameResult()))
.with_remote_executor(executors.SynchronousLocalTaskExecutor())
.build()
),
],
)
def test_capture_execute_telemetry_error(send_event_json, driver_factory):
"""Tests that we don't error if an exception occurs"""
send_event_json.side_effect = [None, ValueError("FAKE ERROR"), None]
dr = driver_factory()
results = dr.execute(["b"], inputs={"a": 1})
expected = pd.DataFrame([{"b": 1}])
pd.testing.assert_frame_equal(results, expected)
assert send_event_json.called is True
assert len(send_event_json.call_args_list) == 2
@mock.patch("hamilton.telemetry.send_event_json")
@mock.patch("hamilton.telemetry.g_telemetry_enabled", True)
@pytest.mark.parametrize(
"driver_factory",
[
(lambda: Driver({}, tests.resources.very_simple_dag)),
(
lambda: Builder()
.enable_dynamic_execution(allow_experimental_mode=True)
.with_modules(tests.resources.very_simple_dag)
.with_adapter(base.SimplePythonGraphAdapter(base.PandasDataFrameResult()))
.with_remote_executor(executors.SynchronousLocalTaskExecutor())
.build()
),
],
)
def test_capture_execute_telemetry(send_event_json, driver_factory):
"""Happy path with values passed."""
dr = driver_factory()
results = dr.execute(["b"], inputs={"a": 1}, overrides={"b": 2})
expected = pd.DataFrame([{"b": 2}])
pd.testing.assert_frame_equal(results, expected)
assert send_event_json.called is True
assert len(send_event_json.call_args_list) == 2
@mock.patch("hamilton.telemetry.send_event_json")
@mock.patch("hamilton.telemetry.g_telemetry_enabled", True)
@pytest.mark.parametrize(
"driver_factory",
[
(lambda: Driver({"a": 1}, tests.resources.very_simple_dag)),
(
lambda: Builder()
.enable_dynamic_execution(allow_experimental_mode=True)
.with_modules(tests.resources.very_simple_dag)
.with_adapter(base.SimplePythonGraphAdapter(base.PandasDataFrameResult()))
.with_remote_executor(executors.SynchronousLocalTaskExecutor())
.with_config({"a": 1})
.build()
),
],
)
def test_capture_execute_telemetry_none_values(send_event_json, driver_factory):
"""Happy path with none values."""
dr = driver_factory()
results = dr.execute(["b"])
expected = pd.DataFrame([{"b": 1}])
pd.testing.assert_frame_equal(results, expected)
assert len(send_event_json.call_args_list) == 2
@pytest.mark.parametrize(
"driver_factory",
[
(
lambda: Driver(
{"required": 1},
tests.resources.test_default_args,
adapter=base.DefaultAdapter(),
)
),
(
lambda: Builder()
.enable_dynamic_execution(allow_experimental_mode=True)
.with_modules(tests.resources.test_default_args)
.with_adapter(base.DefaultAdapter())
.with_remote_executor(executors.SynchronousLocalTaskExecutor())
.with_config({"required": 1})
.build()
),
],
)
def test_node_is_required_by_anything(driver_factory):
"""Tests that default args are correctly interpreted.
Specifically, if it's not in the execution path then things should
just work. Here I'm being lazy and rather than specifically testing
_node_is_required_by_anything() directly, I'm doing it via
execute(), which calls it via validate_inputs().
To understand what's going on see the functions in `test_default_args`.
"""
dr = driver_factory()
# D is not in the execution path, but requires defaults_to_zero
# so this should work.
results = dr.execute(["C"])
assert results["C"] == 2
with pytest.raises(ValueError):
# D is now in the execution path, but requires defaults_to_zero
# this should error
dr.execute(["D"])
@pytest.mark.parametrize(
"driver_factory",
[
(
lambda: Driver(
{"required": 1},
tests.resources.test_default_args,
adapter=base.DefaultAdapter(),
)
),
(
lambda: Builder()
.enable_dynamic_execution(allow_experimental_mode=True)
.with_modules(tests.resources.test_default_args)
.with_adapter(base.DefaultAdapter())
.with_remote_executor(executors.SynchronousLocalTaskExecutor())
.with_config({"required": 1})
.build()
),
],
)
def test_using_callables_to_execute(driver_factory):
"""Test that you can pass a function reference and it will work fine."""
dr = driver_factory()
results = dr.execute(
[tests.resources.test_default_args.C, tests.resources.test_default_args.B, "A"]
)
assert results["C"] == 2
assert results["B"] == 1
assert results["A"] == 1
with pytest.raises(ValueError):
dr.execute([tests.resources.cyclic_functions.B])
def test_create_final_vars():
"""Tests that the final vars are created correctly."""
dr = Driver({"required": 1}, tests.resources.test_default_args)
D_node = dr.graph.nodes["D"]
actual = dr._create_final_vars(
[
"C",
tests.resources.test_default_args.B,
tests.resources.test_default_args.A,
Variable.from_node(D_node),
]
)
expected = ["C", "B", "A", "D"]
assert actual == expected
def test_create_final_vars_errors():
"""Tests that we catch functions pointed to in modules that aren't part of the DAG."""
dr = Driver({"required": 1}, tests.resources.test_default_args)
with pytest.raises(ValueError):
dr._create_final_vars(
["C", tests.resources.cyclic_functions.A, tests.resources.cyclic_functions.B]
)
def test_v2_driver_builder():
dr = (
Builder()
.enable_dynamic_execution(allow_experimental_mode=True)
.with_adapter(base.DefaultAdapter())
.with_modules(tests.resources.very_simple_dag)
.build()
)
assert isinstance(dr.graph_executor, TaskBasedGraphExecutor)
assert list(dr.graph_modules) == [tests.resources.very_simple_dag]
def test_executor_validates_happy_default_executor():
dr = Driver({}, tests.resources.very_simple_dag)
nodes, user_nodes = dr.graph.get_upstream_nodes(["b"])
dr.graph_executor.validate(nodes | user_nodes)
def test_executor_validates_sad_default_executor():
dr = Driver({}, tests.resources.dynamic_parallelism.parallel_linear_basic)
nodes, user_nodes = dr.graph.get_upstream_nodes(["final"])
with pytest.raises(InvalidExecutorException):
dr.graph_executor.validate(nodes | user_nodes)
def test_executor_validates_happy_parallel_executor():
dr = (
Builder()
.enable_dynamic_execution(allow_experimental_mode=True)
.with_modules(tests.resources.dynamic_parallelism.parallel_linear_basic)
.build()
)
nodes, user_nodes = dr.graph.get_upstream_nodes(["final"])
dr.graph_executor.validate(nodes | user_nodes)
def test_builder_defaults_to_dict_result():
dr = Builder().with_modules(tests.resources.dummy_functions).build()
result = dr.execute(["C"], inputs={"b": 1, "c": 1})
assert result == {"C": 4}
def test_materialize_checks_required_input(tmp_path):
dr = Builder().with_modules(tests.resources.dummy_functions).build()
with pytest.raises(ValueError):
dr.materialize(additional_vars=["C"], inputs={"c": 1})
with pytest.raises(ValueError):
dr.materialize(
to.pickle(id="1", path=f"{tmp_path}/foo.pkl", dependencies=["C"]), inputs={"c": 1}
)
def test_validate_execution_happy():
dr = Builder().with_modules(tests.resources.very_simple_dag).build()
dr.validate_execution(["b"], inputs={"a": 1})
def test_validate_execution_sad():
dr = Builder().with_modules(tests.resources.very_simple_dag).build()
with pytest.raises(ValueError):
dr.validate_execution(["b"], inputs={})
def test_validate_materialization_happy(tmp_path):
dr = Builder().with_modules(tests.resources.very_simple_dag).build()
dr.validate_materialization(
to.pickle(id="1", path=f"{tmp_path}/foo.pkl", dependencies=["b"]), inputs={"a": 1}
)
def test_validate_materialization_sad(tmp_path):
dr = Builder().with_modules(tests.resources.very_simple_dag).build()
with pytest.raises(ValueError):
dr.validate_materialization(
# c does not exist
# no inputs either
to.pickle(id="1", path=f"{tmp_path}/foo.pkl", dependencies=["c"]),
inputs={},
)
def test_variable_from_node():
# Quick test for creating variables from nodes --
# this is simple but its nice to have
def func_to_test(a: int) -> int:
"""This is a doctstring"""
return a + 1
n = node.Node.from_fn(func_to_test)
v = Variable.from_node(n)
assert v.name == n.name
assert v.type == n.type
assert v.tags == n.tags
assert v.documentation == n.documentation == "This is a doctstring"
assert v.originating_functions == n.originating_functions