blob: 55ca44b0893d48e95ed4f9853fe3a274a20dd2fd [file] [log] [blame]
import os.path
import pytest
from hamilton_sdk import adapters
from hamilton import driver
import tests.resources.basic_dag_with_config
import tests.resources.parallel_dag
import tests.resources.parallel_dag_error
from tests import test_tracking
# Keyword arguments shared by every HamiltonTracker built in this module.
# Individual tests merge in their own ``dag_name`` (and sometimes ``tags``).
# NOTE(review): the api_key below is a literal committed to source; the client
# is swapped for test_tracking.MockHamiltonClient, so it is presumably never
# sent anywhere -- confirm, and rotate/replace it if it is a real credential.
adapter_kwargs = {
    "project_id": 19319,
    "api_key": "l-PlUq02JLQR6rAvO4x7VTttNTtprj1Tz5zBZ0ARpQ4olb8TK4hlgY2pennFhvsR1DxpYMQ-TLm0JknXVn7y9A",
    "username": "stefan@dagworks.io",
    "tags": {"env": "dev", "status": "development"},
    "client_factory": test_tracking.MockHamiltonClient,
}
def test_adapters():
    """Execute the basic config DAG with a HamiltonTracker attached.

    Builds a driver over ``tests.resources.basic_dag_with_config`` with the
    config ``{"foo": "baz"}``, requests ``a``, ``b`` and ``c`` with ``a=1`` as
    input, and checks the returned values.
    """
    tracker = adapters.HamiltonTracker(**{**adapter_kwargs, "dag_name": "test_dag"})

    builder = driver.Builder()
    builder = builder.with_modules(tests.resources.basic_dag_with_config)
    builder = builder.with_config({"foo": "baz"})
    builder = builder.with_adapters(tracker)
    dr = builder.build()

    outputs = dr.execute(final_vars=["a", "b", "c"], inputs={"a": 1})

    expected = {"a": 1, "b": 3, "c": 6}
    assert outputs == expected
# def test_async():
# # TODO: complete Async
# kwargs = adapter_kwargs | dict(
# dag_name="async_test_dag",
# )
# [adapters.AsyncHamiltonAdapter(**kwargs)]
def test_parallel_ray():
    """Run the parallel DAG on the Ray task executor without sampling.

    The tracker is tagged with ``sampling_rate="None"``. Only the DAG output is
    checked here; what the tracking client received is not asserted and has to
    be inspected manually in the UI.
    """
    import ray

    from hamilton.plugins import h_ray

    kwargs = adapter_kwargs | dict(dag_name="parallel_test_dag", tags={"sampling_rate": "None"})
    lifecycle_adapters = [adapters.HamiltonTracker(**kwargs)]
    data_dir = os.path.join(os.path.dirname(__file__), "resources", "data")

    # Everything that may start or use Ray runs inside the try block so that
    # Ray is shut down even when building or executing the driver raises;
    # otherwise a failure here would leave Ray running for later tests.
    try:
        remote_executor = h_ray.RayTaskExecutor(None)
        # remote_executor = executors.SynchronousLocalTaskExecutor()
        dr = (
            driver.Builder()
            .enable_dynamic_execution(allow_experimental_mode=True)
            .with_remote_executor(remote_executor)  # We only need to specify remote executor
            # The local executor just runs it synchronously
            .with_modules(tests.resources.parallel_dag)
            .with_adapters(*lifecycle_adapters)
            .build()
        )
        result = dr.execute(final_vars=["statistics_by_city"], inputs={"data_dir": data_dir})[
            "statistics_by_city"
        ]
        print(result)
    finally:
        ray.shutdown()

    expected_cities = {"barcelona", "berlin", "budapest"}
    for val in result.index.values:
        assert val in expected_cities
def test_parallel_ray_sample():
    """Run the parallel DAG on the Ray task executor with sampling enabled.

    Sets ``special_parallel_sample_strategy`` to 0.33 on the tracker and
    records the same value in the ``sampling_rate`` tag. Only the DAG output is
    checked here; what the tracking client received is not asserted and has to
    be inspected manually in the UI.
    """
    import ray

    from hamilton.plugins import h_ray

    special_parallel_sample_strategy = 0.33
    kwargs = adapter_kwargs | dict(
        dag_name="parallel_test_dag", tags={"sampling_rate": str(special_parallel_sample_strategy)}
    )
    lifecycle_adapters = [adapters.HamiltonTracker(**kwargs)]
    lifecycle_adapters[0].special_parallel_sample_strategy = special_parallel_sample_strategy
    data_dir = os.path.join(os.path.dirname(__file__), "resources", "data")

    # Everything that may start or use Ray runs inside the try block so that
    # Ray is shut down even when building or executing the driver raises;
    # otherwise a failure here would leave Ray running for later tests.
    try:
        remote_executor = h_ray.RayTaskExecutor(None)
        # remote_executor = executors.SynchronousLocalTaskExecutor()
        dr = (
            driver.Builder()
            .enable_dynamic_execution(allow_experimental_mode=True)
            .with_remote_executor(remote_executor)  # We only need to specify remote executor
            # The local executor just runs it synchronously
            .with_modules(tests.resources.parallel_dag)
            .with_adapters(*lifecycle_adapters)
            .build()
        )
        result = dr.execute(final_vars=["statistics_by_city"], inputs={"data_dir": data_dir})[
            "statistics_by_city"
        ]
        print(result)
    finally:
        ray.shutdown()

    expected_cities = {"barcelona", "berlin", "budapest"}
    for val in result.index.values:
        assert val in expected_cities
def test_parallel_ray_sample_error():
    """Check that an error in the parallel DAG surfaces as ``ValueError``.

    Uses ``tests.resources.parallel_dag_error`` with
    ``special_parallel_sample_strategy`` set to 0.0 on the tracker. What the
    tracking client received is not asserted and has to be inspected manually
    in the UI.
    """
    import ray

    from hamilton.plugins import h_ray

    special_parallel_sample_strategy = 0.0
    kwargs = adapter_kwargs | dict(
        dag_name="parallel_test_dag", tags={"sampling_rate": str(special_parallel_sample_strategy)}
    )
    lifecycle_adapters = [adapters.HamiltonTracker(**kwargs)]
    lifecycle_adapters[0].special_parallel_sample_strategy = special_parallel_sample_strategy
    data_dir = os.path.join(os.path.dirname(__file__), "resources", "data")

    # Shut Ray down on every exit path: pytest.raises only absorbs ValueError,
    # so any other failure (or no failure at all) would otherwise skip the
    # shutdown and leave Ray running for later tests.
    try:
        remote_executor = h_ray.RayTaskExecutor(None)
        # remote_executor = executors.SynchronousLocalTaskExecutor()
        dr = (
            driver.Builder()
            .enable_dynamic_execution(allow_experimental_mode=True)
            .with_remote_executor(remote_executor)  # We only need to specify remote executor
            # The local executor just runs it synchronously
            .with_modules(tests.resources.parallel_dag_error)
            .with_adapters(*lifecycle_adapters)
            .build()
        )
        with pytest.raises(ValueError):
            dr.execute(final_vars=["statistics_by_city"], inputs={"data_dir": data_dir})
    finally:
        ray.shutdown()
if __name__ == "__main__":
    # Manual entry point for running a single test outside of pytest.
    # Uncomment the call you want to run; only the error case is active.
    #
    # Optional debug logging:
    # logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
    #
    # Other tests in this module:
    # test_adapters()
    # test_async()
    # test_parallel_ray_sample()
    # test_parallel_ray()
    test_parallel_ray_sample_error()