| from unittest import mock |
| |
| import pandas as pd |
| import pytest |
| |
| import tests.resources.cyclic_functions |
| import tests.resources.tagging |
| import tests.resources.test_default_args |
| import tests.resources.very_simple_dag |
| from hamilton import base |
| from hamilton.driver import Driver |
| |
| |
| def test_driver_validate_input_types(): |
| dr = Driver({"a": 1}) |
| results = dr.raw_execute(["a"]) |
| assert results == {"a": 1} |
| |
| |
| def test_driver_validate_runtime_input_types(): |
| dr = Driver({}, tests.resources.very_simple_dag) |
| results = dr.raw_execute(["b"], inputs={"a": 1}) |
| assert results == {"b": 1} |
| |
| |
| def test_driver_has_cycles_true(): |
| """Tests that we don't break when detecting cycles from the driver.""" |
| dr = Driver({}, tests.resources.cyclic_functions) |
| assert dr.has_cycles(["C"]) |
| |
| |
| # This is possible -- but we don't want to officially support it. Here for documentation purposes. |
| # def test_driver_cycles_execute_override(): |
| # """Tests that we short circuit a cycle by passing in overrides.""" |
| # dr = Driver({}, tests.resources.cyclic_functions, adapter=base.SimplePythonGraphAdapter(base.DictResult())) |
| # result = dr.execute(['C'], overrides={'D': 1}, inputs={'b': 2, 'c': 2}) |
| # assert result['C'] == 34 |
| |
| |
| def test_driver_cycles_execute_recursion_error(): |
| """Tests that we throw a recursion error when we try to execute over a DAG that isn't a DAG.""" |
| dr = Driver( |
| {}, |
| tests.resources.cyclic_functions, |
| adapter=base.SimplePythonGraphAdapter(base.DictResult()), |
| ) |
| with pytest.raises(RecursionError): |
| dr.execute(["C"], inputs={"b": 2, "c": 2}) |
| |
| |
| def test_driver_variables(): |
| dr = Driver({}, tests.resources.tagging) |
| tags = {var.name: var.tags for var in dr.list_available_variables()} |
| assert tags["a"] == {"module": "tests.resources.tagging", "test": "a"} |
| assert tags["b"] == {"module": "tests.resources.tagging", "test": "b_c"} |
| assert tags["c"] == {"module": "tests.resources.tagging", "test": "b_c"} |
| assert tags["d"] == {"module": "tests.resources.tagging"} |
| |
| |
| @mock.patch("hamilton.telemetry.send_event_json") |
| def test_capture_constructor_telemetry_disabled(send_event_json): |
| """Tests that we don't do anything if telemetry is disabled.""" |
| send_event_json.return_value = "" |
| Driver({}, tests.resources.tagging) # this will exercise things underneath. |
| assert send_event_json.called is False |
| |
| |
| @mock.patch("hamilton.telemetry.get_adapter_name") |
| @mock.patch("hamilton.telemetry.send_event_json") |
| @mock.patch("hamilton.telemetry.g_telemetry_enabled", True) |
| def test_capture_constructor_telemetry_error(send_event_json, get_adapter_name): |
| """Tests that we don't error if an exception occurs""" |
| get_adapter_name.side_effect = ValueError("TELEMETRY ERROR") |
| Driver({}, tests.resources.tagging) # this will exercise things underneath. |
| assert send_event_json.called is False |
| |
| |
| @mock.patch("hamilton.telemetry.send_event_json") |
| @mock.patch("hamilton.telemetry.g_telemetry_enabled", True) |
| def test_capture_constructor_telemetry_none_values(send_event_json): |
| """Tests that we don't error if there are none values""" |
| Driver({}, None, None) # this will exercise things underneath. |
| assert send_event_json.called is True |
| |
| |
| @mock.patch("hamilton.telemetry.send_event_json") |
| @mock.patch("hamilton.telemetry.g_telemetry_enabled", True) |
| def test_capture_constructor_telemetry(send_event_json): |
| """Tests that we send an event if we could. Validates deterministic parts.""" |
| Driver({}, tests.resources.very_simple_dag) |
| # assert send_event_json.called is True |
| assert len(send_event_json.call_args_list) == 1 # only called once |
| # check contents of what it was called with: |
| send_event_json_call = send_event_json.call_args_list[0] |
| actual_event_dict = send_event_json_call[0][0] |
| assert actual_event_dict["api_key"] == "phc_mZg8bkn3yvMxqvZKRlMlxjekFU5DFDdcdAsijJ2EH5e" |
| assert actual_event_dict["event"] == "os_hamilton_run_start" |
| # validate schema |
| expected_properites = { |
| "os_type", |
| "os_version", |
| "python_version", |
| "distinct_id", |
| "hamilton_version", |
| "telemetry_version", |
| "number_of_nodes", |
| "number_of_modules", |
| "number_of_config_items", |
| "decorators_used", |
| "graph_adapter_used", |
| "result_builder_used", |
| "driver_run_id", |
| "error", |
| } |
| actual_properties = actual_event_dict["properties"] |
| assert set(actual_properties.keys()) == expected_properites |
| # validate static parts |
| assert actual_properties["error"] is None |
| assert actual_properties["number_of_nodes"] == 2 # b, and input a |
| assert actual_properties["number_of_modules"] == 1 |
| assert actual_properties["number_of_config_items"] == 0 |
| assert actual_properties["number_of_config_items"] == 0 |
| assert ( |
| actual_properties["graph_adapter_used"] == "hamilton.base.SimplePythonDataFrameGraphAdapter" |
| ) |
| assert actual_properties["result_builder_used"] == "hamilton.base.PandasDataFrameResult" |
| |
| |
| @mock.patch("hamilton.telemetry.send_event_json") |
| def test_capture_execute_telemetry_disabled(send_event_json): |
| """Tests that we don't do anything if telemetry is disabled.""" |
| dr = Driver({}, tests.resources.very_simple_dag) |
| results = dr.execute(["b"], inputs={"a": 1}) |
| expected = pd.DataFrame([{"b": 1}]) |
| pd.testing.assert_frame_equal(results, expected) |
| assert send_event_json.called is False |
| |
| |
| @mock.patch("hamilton.telemetry.send_event_json") |
| @mock.patch("hamilton.telemetry.g_telemetry_enabled", True) |
| def test_capture_execute_telemetry_error(send_event_json): |
| """Tests that we don't error if an exception occurs""" |
| send_event_json.side_effect = [None, ValueError("FAKE ERROR")] |
| dr = Driver({}, tests.resources.very_simple_dag) |
| results = dr.execute(["b"], inputs={"a": 1}) |
| expected = pd.DataFrame([{"b": 1}]) |
| pd.testing.assert_frame_equal(results, expected) |
| assert send_event_json.called is True |
| assert len(send_event_json.call_args_list) == 2 |
| |
| |
| @mock.patch("hamilton.telemetry.send_event_json") |
| @mock.patch("hamilton.telemetry.g_telemetry_enabled", True) |
| def test_capture_execute_telemetry(send_event_json): |
| """Happy path with values passed.""" |
| dr = Driver({}, tests.resources.very_simple_dag) |
| results = dr.execute(["b"], inputs={"a": 1}, overrides={"b": 2}) |
| expected = pd.DataFrame([{"b": 2}]) |
| pd.testing.assert_frame_equal(results, expected) |
| assert send_event_json.called is True |
| assert len(send_event_json.call_args_list) == 2 |
| |
| |
| @mock.patch("hamilton.telemetry.send_event_json") |
| @mock.patch("hamilton.telemetry.g_telemetry_enabled", True) |
| def test_capture_execute_telemetry_none_values(send_event_json): |
| """Happy path with none values.""" |
| dr = Driver({"a": 1}, tests.resources.very_simple_dag) |
| results = dr.execute(["b"]) |
| expected = pd.DataFrame([{"b": 1}]) |
| pd.testing.assert_frame_equal(results, expected) |
| assert len(send_event_json.call_args_list) == 2 |
| |
| |
| def test__node_is_required_by_anything(): |
| """Tests that default args are correctly interpreted. |
| |
| Specifically, if it's not in the execution path then things should |
| just work. Here I'm being lazy and rather than specifically testing |
| _node_is_required_by_anything() directly, I'm doing it via |
| execute(), which calls it via validate_inputs(). |
| |
| To understand what's going on see the functions in `test_default_args`. |
| """ |
| dr = Driver({"required": 1}, tests.resources.test_default_args) |
| # D is not in the execution path, but requires defaults_to_zero |
| # so this should work. |
| results = dr.execute(["C"]) |
| pd.testing.assert_series_equal(results["C"], pd.Series([2], name="C")) |
| with pytest.raises(ValueError): |
| # D is now in the execution path, but requires defaults_to_zero |
| # this should error |
| dr.execute(["D"]) |
| |
| |
| def test_using_callables_to_execute(): |
| """Test that you can pass a function reference and it will work fine.""" |
| dr = Driver({"required": 1}, tests.resources.test_default_args) |
| results = dr.execute( |
| [tests.resources.test_default_args.C, tests.resources.test_default_args.B, "A"] |
| ) |
| pd.testing.assert_series_equal(results["C"], pd.Series([2], name="C")) |
| pd.testing.assert_series_equal(results["B"], pd.Series([1], name="B")) |
| pd.testing.assert_series_equal(results["A"], pd.Series([1], name="A")) |
| with pytest.raises(ValueError): |
| dr.execute([tests.resources.cyclic_functions.B]) |
| |
| |
| def test__create_final_vars(): |
| """Tests that the final vars are created correctly.""" |
| dr = Driver({"required": 1}, tests.resources.test_default_args) |
| actual = dr._create_final_vars( |
| ["C", tests.resources.test_default_args.B, tests.resources.test_default_args.A] |
| ) |
| expected = ["C", "B", "A"] |
| assert actual == expected |
| |
| |
| def test__create_final_vars_errors(): |
| """Tests that we catch functions pointed to in modules that aren't part of the DAG.""" |
| dr = Driver({"required": 1}, tests.resources.test_default_args) |
| with pytest.raises(ValueError): |
| dr._create_final_vars( |
| ["C", tests.resources.cyclic_functions.A, tests.resources.cyclic_functions.B] |
| ) |