| import asyncio |
| from unittest import mock |
| |
| import pandas as pd |
| import pytest |
| |
| from hamilton import ad_hoc_utils, async_driver, base |
| from hamilton.lifecycle.base import ( |
| BasePostGraphConstruct, |
| BasePostGraphConstructAsync, |
| BasePostGraphExecute, |
| BasePostGraphExecuteAsync, |
| BasePostNodeExecute, |
| BasePostNodeExecuteAsync, |
| BasePreGraphExecute, |
| BasePreGraphExecuteAsync, |
| BasePreNodeExecute, |
| BasePreNodeExecuteAsync, |
| ) |
| |
| from .resources import simple_async_module |
| |
| |
async def async_identity(n: int) -> int:
    """Hand back ``n`` untouched after a short asynchronous pause.

    The ``asyncio.sleep`` call suspends the coroutine once, so the value is only
    available after the event loop has had a chance to run.

    :param n: The value to echo back.
    :return: ``n``, unchanged.
    """
    pause_seconds = 0.01
    await asyncio.sleep(pause_seconds)
    return n
| |
| |
@pytest.mark.asyncio
async def test_await_dict_of_coroutines():
    """``await_dict_of_tasks`` is given a dict whose values are un-awaited coroutines."""
    pending = {}
    for key in range(0, 10):
        pending[key] = async_identity(key)

    resolved = await async_driver.await_dict_of_tasks(pending)

    # Build the reference mapping by awaiting a fresh coroutine for every key.
    expected = {}
    for key in range(0, 10):
        expected[key] = await async_identity(key)
    assert resolved == expected
| |
| |
@pytest.mark.asyncio
async def test_await_dict_of_tasks():
    """``await_dict_of_tasks`` is given a dict whose values are already-scheduled tasks."""
    scheduled = {}
    for key in range(0, 10):
        scheduled[key] = asyncio.create_task(async_identity(key))

    resolved = await async_driver.await_dict_of_tasks(scheduled)

    # Build the reference mapping by awaiting a fresh coroutine for every key.
    expected = {}
    for key in range(0, 10):
        expected[key] = await async_identity(key)
    assert resolved == expected
| |
| |
| # The following are not parameterized as we need to use the event loop -- fixtures will complicate this |
@pytest.mark.asyncio
async def test_process_value_raw():
    """``process_value`` is called with a plain (non-awaitable) integer."""
    outcome = await async_driver.process_value(1)
    assert outcome == 1
| |
| |
@pytest.mark.asyncio
async def test_process_value_coroutine():
    """``process_value`` is called with an un-awaited coroutine object."""
    coroutine = async_identity(1)
    outcome = await async_driver.process_value(coroutine)
    assert outcome == 1
| |
| |
@pytest.mark.asyncio
async def test_process_value_task():
    """``process_value`` is called with a task that is already scheduled on the loop."""
    task = asyncio.create_task(async_identity(1))
    outcome = await async_driver.process_value(task)
    assert outcome == 1
| |
| |
@pytest.mark.asyncio
async def test_driver_end_to_end():
    """Request every variable of ``simple_async_module`` except ``return_df`` via ``raw_execute``."""
    dr = async_driver.AsyncDriver({}, simple_async_module)

    requested = []
    for var in dr.list_available_variables():
        if var.name != "return_df":
            requested.append(var.name)

    outputs = await dr.raw_execute(final_vars=requested, inputs={"external_input": 1})

    # convert "a" and "b" to dicts for comparison
    for key in ("a", "b"):
        outputs[key] = outputs[key].to_dict()

    expected = {
        "a": pd.Series([1, 2, 3]).to_dict(),
        "another_async_func": 8,
        "async_func_with_param": 4,
        "b": pd.Series([4, 5, 6]).to_dict(),
        "external_input": 1,
        "non_async_func_with_decorator": {"result_1": 9, "result_2": 5},
        "result_1": 9,
        "result_2": 5,
        "result_3": 1,
        "result_4": 2,
        "return_dict": {"result_3": 1, "result_4": 2},
        "simple_async_func": 2,
        "simple_non_async_func": 7,
    }
    assert outputs == expected
| |
| |
@pytest.mark.asyncio
@mock.patch("hamilton.telemetry.send_event_json")
@mock.patch("hamilton.telemetry.g_telemetry_enabled", True)
async def test_driver_end_to_end_telemetry(send_event_json):
    """Execute the graph with telemetry patched on and count the events sent.

    The variable listing is done with telemetry patched off, and the test then
    expects exactly two recorded ``send_event_json`` calls.

    :param send_event_json: Mock injected by the ``send_event_json`` patch decorator.
    """
    dr = async_driver.AsyncDriver({}, simple_async_module, result_builder=base.DictResult())

    with mock.patch("hamilton.telemetry.g_telemetry_enabled", False):
        # don't count this telemetry tracking invocation
        requested = []
        for var in dr.list_available_variables():
            if var.name != "return_df":
                requested.append(var.name)

    outputs = await dr.execute(final_vars=requested, inputs={"external_input": 1})
    for key in ("a", "b"):
        outputs[key] = outputs[key].to_dict()

    expected = {
        "a": pd.Series([1, 2, 3]).to_dict(),
        "another_async_func": 8,
        "async_func_with_param": 4,
        "b": pd.Series([4, 5, 6]).to_dict(),
        "external_input": 1,
        "non_async_func_with_decorator": {"result_1": 9, "result_2": 5},
        "result_1": 9,
        "result_2": 5,
        "result_3": 1,
        "result_4": 2,
        "return_dict": {"result_3": 1, "result_4": 2},
        "simple_async_func": 2,
        "simple_non_async_func": 7,
    }
    assert outputs == expected

    # To ensure the last telemetry invocation finishes executing, collect every
    # task on the loop other than this test's own task and wait for all of them.
    every_task = asyncio.all_tasks()
    this_task = asyncio.current_task()
    others = []
    for task in every_task:
        if task != this_task:
            others.append(task)
    await asyncio.gather(*others)

    assert send_event_json.called
    recorded_calls = send_event_json.call_args_list
    assert len(recorded_calls) == 2
| |
| |
@pytest.mark.asyncio
async def test_async_driver_end_to_end_async_lifecycle_methods():
    """Run the driver with an adapter whose lifecycle hooks are all coroutines.

    Every hook records its name and keyword arguments; the test checks that each
    of the five hook names was recorded at least once and that the results match.
    """
    tracked_calls = []

    class AsyncTrackingAdapter(
        BasePostGraphConstructAsync,
        BasePreGraphExecuteAsync,
        BasePostGraphExecuteAsync,
        BasePreNodeExecuteAsync,
        BasePostNodeExecuteAsync,
    ):
        def __init__(self, calls: list, pause_time: float = 0.01):
            self.calls = calls
            self.pause_time = pause_time

        async def _record(self, hook_name, payload):
            # Sleep first so each hook suspends before its call is logged.
            await asyncio.sleep(self.pause_time)
            self.calls.append((hook_name, payload))

        async def post_graph_construct(self, **kwargs):
            await self._record("post_graph_construct", kwargs)

        async def pre_graph_execute(self, **kwargs):
            await self._record("pre_graph_execute", kwargs)

        async def pre_node_execute(self, **kwargs):
            await self._record("pre_node_execute", kwargs)

        async def post_node_execute(self, **kwargs):
            await self._record("post_node_execute", kwargs)

        async def post_graph_execute(self, **kwargs):
            await self._record("post_graph_execute", kwargs)

    adapter = AsyncTrackingAdapter(tracked_calls)

    uninitialized = async_driver.AsyncDriver(
        {}, simple_async_module, result_builder=base.DictResult(), adapters=[adapter]
    )
    dr = await uninitialized.ainit()

    requested = []
    for var in dr.list_available_variables():
        if var.name != "return_df":
            requested.append(var.name)
    outputs = await dr.execute(final_vars=requested, inputs={"external_input": 1})

    seen_hooks = {hook_name for hook_name, _ in tracked_calls}
    assert seen_hooks == {
        "pre_graph_execute",
        "post_graph_execute",
        "pre_node_execute",
        "post_node_execute",
        "post_graph_construct",
    }

    for key in ("a", "b"):
        outputs[key] = outputs[key].to_dict()
    expected = {
        "a": pd.Series([1, 2, 3]).to_dict(),
        "another_async_func": 8,
        "async_func_with_param": 4,
        "b": pd.Series([4, 5, 6]).to_dict(),
        "external_input": 1,
        "non_async_func_with_decorator": {"result_1": 9, "result_2": 5},
        "result_1": 9,
        "result_2": 5,
        "result_3": 1,
        "result_4": 2,
        "return_dict": {"result_3": 1, "result_4": 2},
        "simple_async_func": 2,
        "simple_non_async_func": 7,
    }
    assert outputs == expected
| |
| |
@pytest.mark.asyncio
async def test_async_driver_end_to_end_sync_lifecycle_methods():
    """Run the async driver with an adapter whose lifecycle hooks are plain functions.

    Every hook records its name and keyword arguments; the test checks that each
    of the five hook names was recorded at least once and that the results match.
    """
    tracked_calls = []

    class SyncTrackingAdapter(
        BasePostGraphConstruct,
        BasePreGraphExecute,
        BasePostGraphExecute,
        BasePreNodeExecute,
        BasePostNodeExecute,
    ):
        def __init__(self, calls: list, pause_time: float = 0.01):
            self.calls = calls
            # Stored only; none of the synchronous hooks below read it.
            self.pause_time = pause_time

        def _record(self, hook_name, payload):
            self.calls.append((hook_name, payload))

        def post_graph_construct(self, **kwargs):
            self._record("post_graph_construct", kwargs)

        def pre_graph_execute(self, **kwargs):
            self._record("pre_graph_execute", kwargs)

        def pre_node_execute(self, **kwargs):
            self._record("pre_node_execute", kwargs)

        def post_node_execute(self, **kwargs):
            self._record("post_node_execute", kwargs)

        def post_graph_execute(self, **kwargs):
            self._record("post_graph_execute", kwargs)

    adapter = SyncTrackingAdapter(tracked_calls)

    uninitialized = async_driver.AsyncDriver(
        {}, simple_async_module, result_builder=base.DictResult(), adapters=[adapter]
    )
    dr = await uninitialized.ainit()

    requested = []
    for var in dr.list_available_variables():
        if var.name != "return_df":
            requested.append(var.name)
    outputs = await dr.execute(final_vars=requested, inputs={"external_input": 1})

    seen_hooks = {hook_name for hook_name, _ in tracked_calls}
    assert seen_hooks == {
        "pre_graph_execute",
        "post_graph_execute",
        "pre_node_execute",
        "post_node_execute",
        "post_graph_construct",
    }

    for key in ("a", "b"):
        outputs[key] = outputs[key].to_dict()
    expected = {
        "a": pd.Series([1, 2, 3]).to_dict(),
        "another_async_func": 8,
        "async_func_with_param": 4,
        "b": pd.Series([4, 5, 6]).to_dict(),
        "external_input": 1,
        "non_async_func_with_decorator": {"result_1": 9, "result_2": 5},
        "result_1": 9,
        "result_2": 5,
        "result_3": 1,
        "result_4": 2,
        "return_dict": {"result_3": 1, "result_4": 2},
        "simple_async_func": 2,
        "simple_non_async_func": 7,
    }
    assert outputs == expected
| |
| |
@pytest.mark.asyncio
async def test_async_driver_overrides():
    """Check that an overridden node, and the node feeding it, are never executed.

    ``never_run`` and ``overridden`` both log their own name when called; with
    ``overridden`` supplied as an override, that log is expected to stay empty.
    """
    # This specifically tests that raw execute does not return
    # and does not compute the non-returned items
    executed_nodes = []

    async def never_run() -> int:
        executed_nodes.append("never_run")
        await asyncio.sleep(0.01)
        return 1

    async def overridden(never_run: int) -> int:
        executed_nodes.append("overridden")
        await asyncio.sleep(0.01)
        return never_run + 1

    async def will_be_run(overridden: int) -> int:
        await asyncio.sleep(0.01)
        return 1 + overridden

    temp_module = ad_hoc_utils.create_temporary_module(never_run, will_be_run, overridden)
    builder = async_driver.Builder().with_modules(temp_module)
    dr = await builder.build()

    # Testing execute as it's a part of the contract and will remain so
    # raw_execute is closer to the error, but don't want to test it instead
    outputs = await dr.execute(final_vars=["will_be_run"], overrides={"overridden": 3})

    assert executed_nodes == []
    assert outputs == {"will_be_run": 4}
| |
| |
@pytest.mark.asyncio
async def test_async_builder_result_builder_custom():
    """A ``PandasDataFrameResult`` passed via ``with_adapters`` yields a DataFrame from ``execute``."""

    def foo() -> int:
        return 1

    temp_module = ad_hoc_utils.create_temporary_module(foo)
    result_builder = base.PandasDataFrameResult()
    builder = async_driver.Builder().with_adapters(result_builder).with_modules(temp_module)
    dr = await builder.build()

    output = await dr.execute(final_vars=["foo"])
    assert isinstance(output, pd.DataFrame)
| |
| |
@pytest.mark.asyncio
async def test_async_builder_result_builder_default():
    """With no result builder configured, ``execute`` is expected to return a plain dict."""

    def foo() -> int:
        return 1

    temp_module = ad_hoc_utils.create_temporary_module(foo)
    builder = async_driver.Builder().with_modules(temp_module)
    dr = await builder.build()

    output = await dr.execute(final_vars=["foo"])
    assert output == {"foo": 1}
| |
| |
@pytest.mark.asyncio
async def test_async_builder_without_init():
    """``build_without_init`` must not fire the async post-construct hook; ``build`` must."""

    def foo() -> int:
        return 1

    class ConstructionFlagHook(BasePostGraphConstructAsync):
        """Flips ``ran`` to ``True`` once its post-graph-construct hook has been awaited."""

        def __init__(self):
            self.ran = False

        async def post_graph_construct(self, **kwargs):
            await asyncio.sleep(0.01)
            self.ran = True

    temp_module = ad_hoc_utils.create_temporary_module(foo)
    hook = ConstructionFlagHook()

    # non-async: the driver is built synchronously and its result is discarded
    async_driver.Builder().with_modules(temp_module).with_adapters(hook).build_without_init()
    assert not hook.ran

    # async build with the same hook instance
    initializing_builder = async_driver.Builder().with_modules(temp_module).with_adapters(hook)
    await initializing_builder.build()
    assert hook.ran
| |
| |
@pytest.mark.asyncio
async def test_async_builder_allow_module_overrides():
    """Two modules defining ``foo`` fail to build unless ``allow_module_overrides`` is set.

    Both ``build`` and ``build_without_init`` are exercised; with overrides allowed,
    the test expects the value of ``foo`` from the module listed last.
    """
    duplicate_message = "Cannot define function foo more than once."

    def foo() -> int:
        return 1

    first_module = ad_hoc_utils.create_temporary_module(foo)

    # Deliberately re-define ``foo`` so that both modules expose the same name.
    def foo() -> int:
        return 2

    second_module = ad_hoc_utils.create_temporary_module(foo)

    # Should raise without .allow_module_overrides()
    with pytest.raises(ValueError) as build_error:
        await async_driver.Builder().with_modules(first_module, second_module).build()
    assert duplicate_message in str(build_error.value)

    # build_without_init should also raise
    with pytest.raises(ValueError) as no_init_error:
        async_driver.Builder().with_modules(first_module, second_module).build_without_init()
    assert duplicate_message in str(no_init_error.value)

    # Should not raise with .allow_module_overrides()
    permissive_builder = (
        async_driver.Builder().with_modules(first_module, second_module).allow_module_overrides()
    )
    dr = await permissive_builder.build()
    built_output = await dr.execute(final_vars=["foo"])
    assert built_output == {"foo": 2}

    # Same with build_without_init
    dr = (
        async_driver.Builder()
        .with_modules(first_module, second_module)
        .allow_module_overrides()
        .build_without_init()
    )
    no_init_output = await dr.execute(final_vars=["foo"])
    assert no_init_output == {"foo": 2}