blob: 7fe6df5681b905dcfd895a3a21be88872f01e5af [file]
from typing import Callable
import pytest
from burr.core import Action, Condition, Result, State, default
from burr.core.graph import GraphBuilder, _validate_actions, _validate_transitions
# TODO: share a single version of this helper among tests; it is duplicated in test_application.py
class PassedInAction(Action):
    """Test double for ``Action`` whose behavior is supplied via callables.

    :param reads: state keys this action declares it reads.
    :param writes: state keys this action declares it writes.
    :param fn: called as ``fn(state, **run_kwargs)`` by :meth:`run`; returns the result dict.
    :param update_fn: called as ``update_fn(result, state)`` by :meth:`update`; returns the new state.
    :param inputs: names exposed through the :attr:`inputs` property.
    """

    def __init__(
        self,
        reads: list[str],
        writes: list[str],
        fn: Callable[..., dict],
        update_fn: Callable[[dict, State], State],
        inputs: list[str],
    ):
        super().__init__()
        # Store every constructor argument verbatim; the properties/methods below just expose them.
        self._reads = reads
        self._writes = writes
        self._inputs = inputs
        self._fn = fn
        self._update_fn = update_fn

    # --- declared interface -------------------------------------------------

    @property
    def reads(self) -> list[str]:
        return self._reads

    @property
    def writes(self) -> list[str]:
        return self._writes

    @property
    def inputs(self) -> list[str]:
        return self._inputs

    # --- behavior, delegated to the injected callables ----------------------

    def run(self, state: State, **run_kwargs) -> dict:
        """Delegate to ``fn``, forwarding the state and any run keyword arguments."""
        return self._fn(state, **run_kwargs)

    def update(self, result: dict, state: State) -> State:
        """Delegate to ``update_fn`` to fold ``result`` into ``state``."""
        return self._update_fn(result, state)
def test__validate_transitions_correct():
    """A conditional self-loop plus a default exit between known actions validates without error."""
    transitions = [
        ("counter", "counter", Condition.expr("count < 10")),
        ("counter", "result", default),
    ]
    known_actions = {"counter", "result"}
    _validate_transitions(transitions, known_actions)
def test__validate_transitions_missing_action():
    """A transition targeting an action absent from the action set raises ValueError."""
    transitions = [
        ("counter", "counter", Condition.expr("count < 10")),
        ("counter", "result", default),
    ]
    # "result" is referenced above but deliberately left out of the known actions.
    known_actions = {"counter"}
    with pytest.raises(ValueError, match="not found"):
        _validate_transitions(transitions, known_actions)
def test__validate_transitions_redundant_transition():
    """A transition listed after a ``default`` from the same source raises ValueError."""
    transitions = [
        ("counter", "counter", Condition.expr("count < 10")),
        ("counter", "result", default),
        ("counter", "counter", default),  # this is unreachable as we already have a default
    ]
    known_actions = {"counter", "result"}
    with pytest.raises(ValueError, match="redundant"):
        _validate_transitions(transitions, known_actions)
def test__validate_actions_valid():
    """A single-action list passes action validation."""
    actions = [Result("test")]
    _validate_actions(actions)
def test__validate_actions_empty():
    """An empty action list is rejected with a ValueError."""
    no_actions = []
    with pytest.raises(ValueError, match="at least one"):
        _validate_actions(no_actions)
def _increment_count(state):
    """Return a result dict holding the state's ``count`` (0 if unset) plus one."""
    return {"count": state.get("count", 0) + 1}


def _merge_result_into_state(result, state):
    """Write every key of ``result`` into ``state`` via ``State.update``."""
    return state.update(**result)


# Shared counter action: reads and writes "count", takes no run inputs.
base_counter_action = PassedInAction(
    reads=["count"],
    writes=["count"],
    fn=_increment_count,
    update_fn=_merge_result_into_state,
    inputs=[],
)
def _build_counter_graph():
    """Build the two-action counter graph shared by the graph-builder tests.

    ``counter`` loops back to itself while ``count < 10``. The condition-less
    ``("counter", "result")`` transition presumably falls back to ``default``
    and routes to ``result`` otherwise.
    """
    return (
        GraphBuilder()
        .with_actions(counter=base_counter_action, result=Result("count"))
        .with_transitions(
            ("counter", "counter", Condition.expr("count < 10")),
            ("counter", "result"),
        )
        .build()
    )


def test_graph_builder_builds():
    """The builder registers both actions and both transitions."""
    graph = _build_counter_graph()
    assert len(graph.actions) == 2
    assert len(graph.transitions) == 2


def test_graph_builder_get_next_node():
    """With no prior step and entrypoint ``counter``, the next node is ``counter``."""
    graph = _build_counter_graph()
    assert len(graph.actions) == 2
    assert len(graph.transitions) == 2
    next_node = graph.get_next_node(None, State({"count": 0}), entrypoint="counter")
    assert next_node.name == "counter"