blob: 677cd09bb0f4c0041c5b26ba3fb59f7b363b89a7 [file] [log] [blame]
import pytest
from hamilton import driver
from hamilton.execution.executors import SynchronousLocalTaskExecutor
from hamilton.lifecycle.default import GracefulErrorAdapter
from .resources import graceful_parallel
def _make_driver(
    test_state: str,
    try_parallel=True,
    fail_first=False,
    allow_injection=True,
):
    """Build a Hamilton driver over the ``graceful_parallel`` test module.

    The driver enables experimental dynamic execution and runs remote tasks on
    a ``SynchronousLocalTaskExecutor``. It attaches a ``GracefulErrorAdapter``
    configured with ``error_to_catch=Exception`` and ``sentinel_value=None``.

    :param test_state: forwarded verbatim as the ``test_state`` config value.
    :param try_parallel: forwarded to the adapter as ``try_all_parallel``.
    :param fail_first: when truthy, config ``test_front`` is ``"bad"``;
        otherwise it is ``"good"``.
    :param allow_injection: forwarded to the adapter as ``allow_injection``.
    :return: the built driver.
    """
    executor = SynchronousLocalTaskExecutor()
    graceful_adapter = GracefulErrorAdapter(
        error_to_catch=Exception,
        sentinel_value=None,
        try_all_parallel=try_parallel,
        allow_injection=allow_injection,
    )

    front = "good"
    if fail_first:
        front = "bad"
    config = {"test_state": test_state, "test_front": front}

    # Assemble step by step; reassigning works whether the builder methods
    # return ``self`` or a new builder.
    builder = driver.Builder()
    builder = builder.enable_dynamic_execution(allow_experimental_mode=True)
    builder = builder.with_modules(graceful_parallel)
    builder = builder.with_remote_executor(executor)
    builder = builder.with_adapters(graceful_adapter)
    builder = builder.with_config(config)
    return builder.build()
@pytest.mark.parametrize(
    "test_state, try_parallel, fail_first, expected_value",
    [
        ("early", True, False, [None]),
        ("middle", True, False, [0.0, 6.0, 12.0, 18.0, 24.0, None]),
        ("middle", False, False, [None]),
        ("pass", True, False, [0.0, 6.0, 12.0, 18.0, 24.0, 30.0, None, None, None, None]),
        ("pass", False, True, [None]),
    ],
    ids=[
        "test_early_parallel_fail",
        "test_middle_parallel_fail",
        "test_middle_parallel_as_single",
        "test_inside_block_errors",
        "test_pre_parallel_fail",
    ],
)
def test_parallel_graceful_simple(test_state, try_parallel, fail_first, expected_value) -> None:
    """Check ``distro_gather`` output under the graceful error adapter.

    ``test_state`` and ``fail_first`` are turned into driver config for the
    ``graceful_parallel`` module, and ``try_parallel`` toggles the adapter's
    ``try_all_parallel`` option. The adapter's sentinel is ``None``, so the
    ``None`` entries in ``expected_value`` mark results the adapter replaced.

    Cases expecting exactly ``[None]`` are meant to cover a failure before
    anything is yielded. There, a single failure is expected and the rest of
    the parallel block is skipped.
    """
    # Keyword arguments keep this call robust to parameter reordering and
    # consistent with test_parallel_gather_injection.
    dr = _make_driver(
        test_state,
        try_parallel=try_parallel,
        fail_first=fail_first,
    )
    ans = dr.execute(
        ["distro_gather"],
        inputs={"n": 10},
    )
    assert ans["distro_gather"] == expected_value
def test_parallel_gather_injection() -> None:
    """Compare ``distro_end`` with sentinel injection allowed and disallowed.

    Both runs use the ``"pass"`` state, try every parallel branch and do not
    fail first; only the adapter's ``allow_injection`` flag differs.
    """

    def _run_distro_end(allow_injection: bool):
        # Fresh driver per run so the two configurations never share state.
        runner = _make_driver(
            "pass",
            try_parallel=True,
            fail_first=False,
            allow_injection=allow_injection,
        )
        return runner.execute(["distro_end"], inputs={"n": 10})["distro_end"]

    expected_with_injection = [
        [0.0, None],
        [6.0, None],
        [12.0, None],
        [18.0, None],
        [24.0, 13],
        [30.0, 16],
        [None, 19],
        [None, 22],
        [None, 25],
        [None, 28],
    ]
    assert _run_distro_end(allow_injection=True) == expected_with_injection

    # With injection disallowed the gathering step is skipped for any item
    # whose inputs contain a sentinel, so those whole rows become None.
    expected_without_injection = [None] * 4 + [[24.0, 13], [30.0, 16]] + [None] * 4
    assert _run_distro_end(allow_injection=False) == expected_without_injection