| import asyncio |
| import collections |
| import logging |
| import typing |
| from typing import Any, Awaitable, Callable, Dict, Generator, Literal, Optional, Tuple |
| |
| import pytest |
| |
| from burr.core import State |
| from burr.core.action import ( |
| Action, |
| AsyncGenerator, |
| AsyncStreamingAction, |
| Condition, |
| Reducer, |
| Result, |
| SingleStepAction, |
| SingleStepStreamingAction, |
| StreamingAction, |
| action, |
| default, |
| expr, |
| ) |
| from burr.core.application import ( |
| PRIOR_STEP, |
| Application, |
| ApplicationBuilder, |
| ApplicationContext, |
| _adjust_single_step_output, |
| _arun_function, |
| _arun_multi_step_streaming_action, |
| _arun_single_step_action, |
| _arun_single_step_streaming_action, |
| _run_function, |
| _run_multi_step_streaming_action, |
| _run_reducer, |
| _run_single_step_action, |
| _run_single_step_streaming_action, |
| _validate_start, |
| ) |
| from burr.core.graph import Graph, Transition |
| from burr.core.persistence import BaseStatePersister, DevNullPersister, PersistedStateData |
| from burr.lifecycle import ( |
| PostRunStepHook, |
| PostRunStepHookAsync, |
| PreRunStepHook, |
| PreRunStepHookAsync, |
| internal, |
| ) |
| from burr.lifecycle.base import PostApplicationCreateHook |
| from burr.lifecycle.internal import LifecycleAdapterSet |
| from burr.tracking.base import SyncTrackingClient |
| |
| |
class PassedInAction(Action):
    """Test helper: an ``Action`` whose behavior is injected through the constructor.

    The read/write/input key lists are returned verbatim by the matching
    properties, and ``run`` / ``update`` delegate to the supplied callables.
    """

    def __init__(
        self,
        reads: list[str],
        writes: list[str],
        fn: Callable[..., dict],
        update_fn: Callable[[dict, State], State],
        inputs: list[str],
    ):
        """Store the injected configuration.

        :param reads: value returned by the ``reads`` property
        :param writes: value returned by the ``writes`` property
        :param fn: callable invoked by ``run`` with the state and run kwargs
        :param update_fn: callable invoked by ``update`` with ``(result, state)``
        :param inputs: value returned by the ``inputs`` property
        """
        super().__init__()
        self._fn, self._update_fn = fn, update_fn
        self._reads, self._writes, self._inputs = reads, writes, inputs

    @property
    def reads(self) -> list[str]:
        """State keys given at construction as ``reads``."""
        return self._reads

    @property
    def writes(self) -> list[str]:
        """State keys given at construction as ``writes``."""
        return self._writes

    @property
    def inputs(self) -> list[str]:
        """Input names given at construction as ``inputs``."""
        return self._inputs

    def run(self, state: State, **run_kwargs) -> dict:
        """Call the injected ``fn`` with the state and keyword inputs."""
        compute = self._fn
        return compute(state, **run_kwargs)

    def update(self, result: dict, state: State) -> State:
        """Call the injected ``update_fn`` with ``(result, state)``."""
        apply_update = self._update_fn
        return apply_update(result, state)
| |
| |
class PassedInActionAsync(PassedInAction):
    """Async flavor of ``PassedInAction``: ``fn`` must return an awaitable."""

    def __init__(
        self,
        reads: list[str],
        writes: list[str],
        fn: Callable[..., Awaitable[dict]],
        update_fn: Callable[[dict, State], State],
        inputs: list[str],
    ):
        """Forward every argument unchanged to ``PassedInAction.__init__``."""
        config = dict(reads=reads, writes=writes, fn=fn, update_fn=update_fn, inputs=inputs)
        super().__init__(**config)  # type: ignore

    async def run(self, state: State, **run_kwargs) -> dict:
        """Call the injected ``fn`` and await what it returns."""
        pending = self._fn(state, **run_kwargs)
        return await pending
| |
| |
# Synchronous counter: reads "count" (defaulting to 0) and writes count + 1.
base_counter_action = PassedInAction(
    fn=lambda state: {"count": state.get("count", 0) + 1},
    update_fn=lambda result, state: state.update(**result),
    reads=["count"],
    writes=["count"],
    inputs=[],
)

# Same counter, but also adds the required ``additional_increment`` input.
base_counter_action_with_inputs = PassedInAction(
    fn=lambda state, additional_increment: {
        "count": state.get("count", 0) + 1 + additional_increment
    },
    update_fn=lambda result, state: state.update(**result),
    reads=["count"],
    writes=["count"],
    inputs=["additional_increment"],
)
| |
| |
class ActionTracker(PreRunStepHook, PostRunStepHook):
    """Synchronous lifecycle hook that records every pre/post step call.

    Each entry appended to ``pre_called`` / ``post_called`` is an
    ``(action name, locals)`` tuple, where the second element is the
    ``locals()`` of the hook invocation (every argument received, plus ``self``).
    """

    def __init__(self):
        self.pre_called = []
        self.post_called = []

    def pre_run_step(
        self,
        *,
        app_id: str,
        partition_key: str,
        sequence_id: int,
        state: "State",
        action: "Action",
        inputs: Dict[str, Any],
        **future_kwargs: Any,
    ):
        """Record the action name and all received arguments before a step."""
        self.pre_called.append((action.name, locals()))

    def post_run_step(
        self,
        *,
        app_id: str,
        partition_key: str,
        sequence_id: int,
        state: "State",
        action: "Action",
        result: Optional[Dict[str, Any]],
        exception: Optional[Exception],
        **future_kwargs: Any,
    ):
        """Record the action name and all received arguments after a step."""
        # The original had a bare ``locals()`` statement here whose value was
        # discarded; it was a no-op and has been removed.
        self.post_called.append((action.name, locals()))
| |
| |
class ActionTrackerAsync(PreRunStepHookAsync, PostRunStepHookAsync):
    """Async lifecycle hook that records ``(action name, remaining kwargs)`` per call."""

    def __init__(self):
        self.pre_called, self.post_called = [], []

    async def pre_run_step(self, *, action: Action, **future_kwargs):
        """Sleep briefly, then log the pre-step call."""
        await asyncio.sleep(0.0001)
        entry = (action.name, future_kwargs)
        self.pre_called.append(entry)

    async def post_run_step(self, *, action: Action, **future_kwargs):
        """Sleep briefly, then log the post-step call."""
        await asyncio.sleep(0.0001)
        entry = (action.name, future_kwargs)
        self.post_called.append(entry)
| |
| |
async def _counter_update_async(state: State, additional_increment: int = 0) -> dict:
    """Return ``{"count": current + 1 + additional_increment}`` asynchronously.

    ``current`` is ``state.get("count", 0)``.
    """
    # The sleep itself is irrelevant; it only makes this a *truly* async call,
    # which more accurately simulates an async function.
    await asyncio.sleep(0.0001)
    current = state.get("count", 0)
    return {"count": current + 1 + additional_increment}
| |
| |
# Async counter: delegates straight to ``_counter_update_async``.
base_counter_action_async = PassedInActionAsync(
    fn=_counter_update_async,
    update_fn=lambda result, state: state.update(**result),
    reads=["count"],
    writes=["count"],
    inputs=[],
)

# Async counter that forwards the required ``additional_increment`` input.
base_counter_action_with_inputs_async = PassedInActionAsync(
    fn=lambda state, additional_increment: _counter_update_async(
        state, additional_increment=additional_increment
    ),
    update_fn=lambda result, state: state.update(**result),
    reads=["count"],
    writes=["count"],
    inputs=["additional_increment"],
)
| |
| |
class BrokenStepException(Exception):
    """Raised by the deliberately broken actions used in these tests."""
| |
| |
# Action whose run function always raises BrokenStepException (sync variant).
base_broken_action = PassedInAction(
    fn=lambda x: exec("raise(BrokenStepException(x))"),
    update_fn=lambda result, state: state,
    reads=[],
    writes=[],
    inputs=[],
)

# Async variant: the lambda raises as soon as it is called, before any await.
base_broken_action_async = PassedInActionAsync(
    fn=lambda x: exec("raise(BrokenStepException(x))"),
    update_fn=lambda result, state: state,
    reads=[],
    writes=[],
    inputs=[],
)
| |
| |
async def incorrect(x):
    """Ignore ``x`` and return a string instead of the expected result dict."""
    wrong_result = "not a dict"
    return wrong_result
| |
| |
# Sync action whose run function returns a string rather than a dict.
base_action_incorrect_result_type = PassedInAction(
    fn=lambda x: "not a dict",
    update_fn=lambda result, state: state,
    reads=[],
    writes=[],
    inputs=[],
)

# Async counterpart, backed by the ``incorrect`` coroutine function.
base_action_incorrect_result_type_async = PassedInActionAsync(
    fn=incorrect,
    update_fn=lambda result, state: state,
    reads=[],
    writes=[],
    inputs=[],
)
| |
| |
def test__run_function():
    """``_run_function`` runs a synchronous action and returns its result dict."""
    counter = base_counter_action
    outcome = _run_function(counter, State({}), inputs={}, name=counter.name)
    assert outcome == {"count": 1}
| |
| |
def test__run_function_with_inputs():
    """``_run_function`` forwards the supplied inputs to the action."""
    counter = base_counter_action_with_inputs
    provided = {"additional_increment": 1}
    outcome = _run_function(counter, State({}), inputs=provided, name=counter.name)
    assert outcome == {"count": 2}
| |
| |
def test__run_function_cant_run_async():
    """``_run_function`` rejects an async action with a ValueError mentioning "async"."""
    async_counter = base_counter_action_async
    empty_state = State({})
    with pytest.raises(ValueError, match="async"):
        _run_function(async_counter, empty_state, inputs={}, name=async_counter.name)
| |
| |
def test__run_function_incorrect_result_type():
    """``_run_function`` raises when the action returns something other than a dict."""
    bad_action = base_action_incorrect_result_type
    empty_state = State({})
    with pytest.raises(ValueError, match="returned a non-dict"):
        _run_function(bad_action, empty_state, inputs={}, name=bad_action.name)
| |
| |
def test__run_reducer_modifies_state():
    """Running a reducer applies the given result to the state."""
    updating_reducer = PassedInAction(
        fn=...,  # placeholder; this test only runs the reducer
        update_fn=lambda result, state: state.update(**result),
        reads=["count"],
        writes=["count"],
        inputs=[],
    )
    initial = State({"count": 0})
    reduced = _run_reducer(updating_reducer, initial, {"count": 1}, "reducer")
    assert reduced["count"] == 1
| |
| |
def test__run_reducer_deletes_state():
    """A reducer may remove a key from state via ``wipe``."""
    deleting_reducer = PassedInAction(
        fn=...,  # placeholder; this test only runs the reducer
        update_fn=lambda result, state: state.wipe(delete=["count"]),
        reads=["count"],
        writes=[],  # TODO -- figure out how we can better know that it deletes items...
        inputs=[],
    )
    initial = State({"count": 0})
    reduced = _run_reducer(deleting_reducer, initial, {}, "deletion_reducer")
    assert "count" not in reduced
| |
| |
async def test__arun_function():
    """``_arun_function`` awaits an async action and returns its result dict."""
    async_counter = base_counter_action_async
    outcome = await _arun_function(
        async_counter, State({}), inputs={}, name=async_counter.name
    )
    assert outcome == {"count": 1}
| |
| |
async def test__arun_function_incorrect_result_type():
    """``_arun_function`` raises when the async action returns a non-dict."""
    bad_action = base_action_incorrect_result_type_async
    empty_state = State({})
    with pytest.raises(ValueError, match="returned a non-dict"):
        await _arun_function(bad_action, empty_state, inputs={}, name=bad_action.name)
| |
| |
async def test__arun_function_with_inputs():
    """``_arun_function`` forwards the supplied inputs to the async action."""
    async_counter = base_counter_action_with_inputs_async
    provided = {"additional_increment": 1}
    outcome = await _arun_function(
        async_counter, State({}), inputs=provided, name=async_counter.name
    )
    assert outcome == {"count": 2}
| |
| |
def test_run_reducer_errors_missing_writes():
    """``_run_reducer`` raises when a declared write is absent from the new state."""

    class BrokenReducer(Reducer):
        @property
        def writes(self) -> list[str]:
            # Declares two writes but ``update`` only produces one of them.
            return ["missing_value", "present_value"]

        def update(self, result: dict, state: State) -> State:
            return state.update(present_value=1)

    broken = BrokenReducer()
    empty_state = State()
    with pytest.raises(ValueError, match="missing_value"):
        _run_reducer(broken, empty_state, {}, "broken_reducer")
| |
| |
def test_run_single_step_action_errors_missing_writes():
    """``_run_single_step_action`` raises when a declared write is never written."""

    class BrokenAction(SingleStepAction):
        @property
        def writes(self) -> list[str]:
            # Declares two writes but ``run_and_update`` only produces one.
            return ["missing_value", "present_value"]

        @property
        def reads(self) -> list[str]:
            return []

        def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
            new_state = state.update(present_value=1)
            return {"present_value": 1}, new_state

    broken = BrokenAction()
    empty_state = State()
    with pytest.raises(ValueError, match="missing_value"):
        _run_single_step_action(broken, empty_state, inputs={})
| |
| |
async def test_arun_single_step_action_errors_missing_writes():
    """``_arun_single_step_action`` raises when a declared write is never written."""

    class BrokenAction(SingleStepAction):
        @property
        def writes(self) -> list[str]:
            # Declares two writes but ``run_and_update`` only produces one.
            return ["missing_value", "present_value"]

        @property
        def reads(self) -> list[str]:
            return []

        async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
            await asyncio.sleep(0.0001)  # just so we can make this *truly* async
            new_state = state.update(present_value=1)
            return {"present_value": 1}, new_state

    broken = BrokenAction()
    empty_state = State()
    with pytest.raises(ValueError, match="missing_value"):
        await _arun_single_step_action(broken, empty_state, inputs={})
| |
| |
def test_run_single_step_streaming_action_errors_missing_write():
    """The single-step streaming runner raises when a declared write is never written."""

    class BrokenAction(SingleStepStreamingAction):
        @property
        def reads(self) -> list[str]:
            return []

        @property
        def writes(self) -> list[str]:
            # Declares two writes but the final state only contains one.
            return ["missing_value", "present_value"]

        def stream_run_and_update(
            self, state: State, **run_kwargs
        ) -> Generator[Tuple[dict, Optional[State]], None, None]:
            yield {}, None
            yield {"present_value": 1}, state.update(present_value=1)

    broken = BrokenAction()
    empty_state = State()
    with pytest.raises(ValueError, match="missing_value"):
        stream = _run_single_step_streaming_action(broken, empty_state, inputs={})
        for _ in stream:  # exhaust the generator
            pass
| |
| |
async def test_run_single_step_streaming_action_errors_missing_write_async():
    """The async single-step streaming runner raises when a declared write is missing."""

    class BrokenAction(SingleStepStreamingAction):
        @property
        def reads(self) -> list[str]:
            return []

        @property
        def writes(self) -> list[str]:
            # Declares two writes but the final state only contains one.
            return ["missing_value", "present_value"]

        async def stream_run_and_update(
            self, state: State, **run_kwargs
        ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]:
            yield {}, None
            yield {"present_value": 1}, state.update(present_value=1)

    broken = BrokenAction()
    empty_state = State()
    with pytest.raises(ValueError, match="missing_value"):
        stream = _arun_single_step_streaming_action(broken, empty_state, inputs={})
        async for _ in stream:  # exhaust the generator
            pass
| |
| |
def test_run_multi_step_streaming_action_errors_missing_write():
    """The multi-step streaming runner raises when a declared write is never written."""

    class BrokenAction(StreamingAction):
        @property
        def reads(self) -> list[str]:
            return []

        @property
        def writes(self) -> list[str]:
            # Declares two writes but ``update`` only produces one of them.
            return ["missing_value", "present_value"]

        def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, None]:
            yield {}
            yield {"present_value": 1}

        def update(self, result: dict, state: State) -> State:
            return state.update(present_value=1)

    broken = BrokenAction()
    empty_state = State()
    with pytest.raises(ValueError, match="missing_value"):
        stream = _run_multi_step_streaming_action(broken, empty_state, inputs={})
        for _ in stream:  # exhaust the generator
            pass
| |
| |
class SingleStepCounter(SingleStepAction):
    """Single-step action that increments ``count`` and appends it to ``tracker``."""

    @property
    def reads(self) -> list[str]:
        return ["count"]

    @property
    def writes(self) -> list[str]:
        return ["count", "tracker"]

    def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
        """Add 1 plus the sum of all keyword input values to ``count``."""
        extra = sum(run_kwargs.values())  # 0 when no inputs are supplied
        new_count = state["count"] + 1 + extra
        result = {"count": new_count}
        updated = state.update(**result)
        return result, updated.append(tracker=new_count)
| |
| |
class SingleStepCounterWithInputs(SingleStepCounter):
    """``SingleStepCounter`` that declares ``additional_increment`` as an input."""

    @property
    def inputs(self) -> list[str]:
        declared = ["additional_increment"]
        return declared
| |
| |
class SingleStepActionIncorrectResultType(SingleStepAction):
    """Single-step action whose result is a string rather than a dict."""

    @property
    def reads(self) -> list[str]:
        return []

    @property
    def writes(self) -> list[str]:
        return []

    def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
        """Return a deliberately wrong result type alongside the unchanged state."""
        wrong_result = "not a dict"
        return wrong_result, state
| |
| |
class SingleStepActionIncorrectResultTypeAsync(SingleStepActionIncorrectResultType):
    """Async variant: the coroutine returns a string rather than a dict."""

    async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
        """Return a deliberately wrong result type alongside the unchanged state."""
        wrong_result = "not a dict"
        return wrong_result, state
| |
| |
class SingleStepCounterAsync(SingleStepCounter):
    """Async wrapper over ``SingleStepCounter.run_and_update``."""

    @property
    def reads(self) -> list[str]:
        return ["count"]

    @property
    def writes(self) -> list[str]:
        return ["count", "tracker"]

    async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
        """Sleep briefly, then delegate to the synchronous parent implementation."""
        await asyncio.sleep(0.0001)  # just so we can make this *truly* async
        sync_impl = super().run_and_update
        return sync_impl(state, **run_kwargs)
| |
| |
class SingleStepCounterWithInputsAsync(SingleStepCounterAsync):
    """``SingleStepCounterAsync`` that declares ``additional_increment`` as an input."""

    @property
    def inputs(self) -> list[str]:
        declared = ["additional_increment"]
        return declared
| |
| |
class StreamingCounter(StreamingAction):
    """Streaming action that emits fractional counts, then the incremented count.

    ``update`` writes the final result into state and appends the new count to
    ``tracker``.
    """

    def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, None]:
        """Yield ``granularity`` intermediate results followed by the final one.

        :param state: state to read the current ``count`` from
        :param run_kwargs: optional ``granularity`` (default 10) giving the
            number of intermediate results to yield
        """
        # Fix: the original checked for "steps_per_count" but then read
        # run_kwargs["granularity"], which raised KeyError or ignored the input.
        # Use the same key and default as the single-step streaming counters.
        steps_per_count = run_kwargs.get("granularity", 10)
        count = state["count"]
        for i in range(steps_per_count):
            yield {"count": count + ((i + 1) / 10)}
        yield {"count": count + 1}

    @property
    def reads(self) -> list[str]:
        return ["count"]

    @property
    def writes(self) -> list[str]:
        return ["count", "tracker"]

    def update(self, result: dict, state: State) -> State:
        """Apply ``result`` to state and append its count to ``tracker``."""
        return state.update(**result).append(tracker=result["count"])
| |
| |
class AsyncStreamingCounter(AsyncStreamingAction):
    """Async streaming action that emits fractional counts, then the incremented count.

    ``update`` writes the final result into state and appends the new count to
    ``tracker``.
    """

    async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]:
        """Yield ``granularity`` intermediate results followed by the final one.

        Sleeps 0.01s before every yielded item.

        :param state: state to read the current ``count`` from
        :param run_kwargs: optional ``granularity`` (default 10) giving the
            number of intermediate results to yield
        """
        # Fix: the original checked for "steps_per_count" but then read
        # run_kwargs["granularity"], which raised KeyError or ignored the input.
        # Use the same key and default as the single-step streaming counters.
        steps_per_count = run_kwargs.get("granularity", 10)
        count = state["count"]
        for i in range(steps_per_count):
            await asyncio.sleep(0.01)
            yield {"count": count + (i + 1) / 10}
        await asyncio.sleep(0.01)
        yield {"count": count + 1}

    @property
    def reads(self) -> list[str]:
        return ["count"]

    @property
    def writes(self) -> list[str]:
        return ["count", "tracker"]

    def update(self, result: dict, state: State) -> State:
        """Apply ``result`` to state and append its count to ``tracker``."""
        return state.update(**result).append(tracker=result["count"])
| |
| |
class SingleStepStreamingCounter(SingleStepStreamingAction):
    """Single-step streaming counter: intermediate results carry no state."""

    @property
    def reads(self) -> list[str]:
        return ["count"]

    @property
    def writes(self) -> list[str]:
        return ["count", "tracker"]

    def stream_run_and_update(
        self, state: State, **run_kwargs
    ) -> Generator[Tuple[dict, Optional[State]], None, None]:
        """Yield ``granularity`` (default 10) partial results, then the final pair."""
        total_steps = run_kwargs.get("granularity", 10)
        start = state["count"]
        for step in range(1, total_steps + 1):
            partial = {"count": start + (step / 10)}
            yield partial, None
        # Final item: the incremented count plus the fully updated state.
        final_state = state.update(count=start + 1).append(tracker=start + 1)
        yield {"count": start + 1}, final_state
| |
| |
class SingleStepStreamingCounterAsync(SingleStepStreamingAction):
    """Async single-step streaming counter: intermediate results carry no state."""

    @property
    def reads(self) -> list[str]:
        return ["count"]

    @property
    def writes(self) -> list[str]:
        return ["count", "tracker"]

    async def stream_run_and_update(
        self, state: State, **run_kwargs
    ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]:
        """Yield ``granularity`` (default 10) partial results, then the final pair.

        Sleeps 0.01s before every yielded item.
        """
        total_steps = run_kwargs.get("granularity", 10)
        start = state["count"]
        for step in range(1, total_steps + 1):
            await asyncio.sleep(0.01)
            partial = {"count": start + (step / 10)}
            yield partial, None
        await asyncio.sleep(0.01)
        # Final item: the incremented count plus the fully updated state.
        final_state = state.update(count=start + 1).append(tracker=start + 1)
        yield {"count": start + 1}, final_state
| |
| |
class StreamingActionIncorrectResultType(StreamingAction):
    """Streaming action whose last yielded result is a string rather than a dict."""

    @property
    def reads(self) -> list[str]:
        return []

    @property
    def writes(self) -> list[str]:
        return []

    def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, dict]:
        """Yield one valid (empty) result, then a deliberately invalid one."""
        for item in ({}, "not a dict"):
            yield item

    def update(self, result: dict, state: State) -> State:
        """Leave the state untouched."""
        return state
| |
| |
class StreamingActionIncorrectResultTypeAsync(AsyncStreamingAction):
    """Async streaming action whose last yielded result is a string, not a dict."""

    @property
    def reads(self) -> list[str]:
        return []

    @property
    def writes(self) -> list[str]:
        return []

    async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]:
        """Yield one valid (empty) result, then a deliberately invalid one."""
        for item in ({}, "not a dict"):
            yield item

    def update(self, result: dict, state: State) -> State:
        """Leave the state untouched."""
        return state
| |
| |
class StreamingSingleStepActionIncorrectResultType(SingleStepStreamingAction):
    """Single-step streaming action whose final result is a string, not a dict."""

    def stream_run_and_update(
        self, state: State, **run_kwargs
    ) -> Generator[Tuple[dict, Optional[State]], None, None]:
        """Yield a valid intermediate item, then a final item with a bad result type."""
        # Fix: the original yielded the ``State`` *class* as the intermediate
        # state. The annotation allows only a State instance or None, and the
        # async sibling of this class yields None for its intermediate item.
        yield {}, None
        yield "not a dict", state

    @property
    def reads(self) -> list[str]:
        return []

    @property
    def writes(self) -> list[str]:
        return []
| |
| |
class StreamingSingleStepActionIncorrectResultTypeAsync(SingleStepStreamingAction):
    """Async single-step streaming action whose final result is a string, not a dict."""

    @property
    def reads(self) -> list[str]:
        return []

    @property
    def writes(self) -> list[str]:
        return []

    async def stream_run_and_update(
        self, state: State, **run_kwargs
    ) -> typing.AsyncGenerator[Tuple[dict, Optional[State]], None]:
        """Yield a valid intermediate item, then a final item with a bad result type."""
        intermediate = ({}, None)
        yield intermediate
        final = ("not a dict", state)
        yield final
| |
| |
# Shared single-step counter instances (sync).
base_single_step_counter = SingleStepCounter()
base_single_step_counter_with_inputs = SingleStepCounterWithInputs()

# Shared single-step counter instances (async).
base_single_step_counter_async = SingleStepCounterAsync()
base_single_step_counter_with_inputs_async = SingleStepCounterWithInputsAsync()

# Shared streaming counter instances (sync, then async).
base_streaming_counter = StreamingCounter()
base_streaming_single_step_counter = SingleStepStreamingCounter()
base_streaming_counter_async = AsyncStreamingCounter()
base_streaming_single_step_counter_async = SingleStepStreamingCounterAsync()

# Shared instances that return a non-dict result.
base_single_step_action_incorrect_result_type = SingleStepActionIncorrectResultType()
base_single_step_action_incorrect_result_type_async = SingleStepActionIncorrectResultTypeAsync()
| |
| |
def test__run_single_step_action():
    """Two consecutive runs each increment the count and extend the tracker."""
    counter = base_single_step_counter.with_name("counter")
    current = State({"count": 0, "tracker": []})
    for expected_count, expected_tracker in ((1, [1]), (2, [1, 2])):
        result, current = _run_single_step_action(counter, current, inputs={})
        assert result == {"count": expected_count}
        assert current.subset("count", "tracker").get_all() == {
            "count": expected_count,
            "tracker": expected_tracker,
        }
| |
| |
def test__run_single_step_action_incorrect_result_type():
    """A single-step action returning a non-dict result raises ValueError."""
    bad_action = base_single_step_action_incorrect_result_type.with_name("counter")
    initial = State({"count": 0, "tracker": []})
    with pytest.raises(ValueError, match="returned a non-dict"):
        _run_single_step_action(bad_action, initial, inputs={})
| |
| |
async def test__arun_single_step_action_incorrect_result_type():
    """An async single-step action returning a non-dict result raises ValueError."""
    bad_action = base_single_step_action_incorrect_result_type_async.with_name("counter")
    initial = State({"count": 0, "tracker": []})
    with pytest.raises(ValueError, match="returned a non-dict"):
        await _arun_single_step_action(bad_action, initial, inputs={})
| |
| |
def test__run_single_step_action_with_inputs():
    """With ``additional_increment=1`` each run adds 2 to the count."""
    counter = base_single_step_counter_with_inputs.with_name("counter")
    current = State({"count": 0, "tracker": []})
    for expected_count, expected_tracker in ((2, [2]), (4, [2, 4])):
        result, current = _run_single_step_action(
            counter, current, inputs={"additional_increment": 1}
        )
        assert result == {"count": expected_count}
        assert current.subset("count", "tracker").get_all() == {
            "count": expected_count,
            "tracker": expected_tracker,
        }
| |
| |
async def test__arun_single_step_action():
    """Two consecutive async runs each increment the count and extend the tracker."""
    counter = base_single_step_counter_async.with_name("counter")
    current = State({"count": 0, "tracker": []})
    for expected_count, expected_tracker in ((1, [1]), (2, [1, 2])):
        result, current = await _arun_single_step_action(counter, current, inputs={})
        assert result == {"count": expected_count}
        assert current.subset("count", "tracker").get_all() == {
            "count": expected_count,
            "tracker": expected_tracker,
        }
| |
| |
async def test__arun_single_step_action_with_inputs():
    """With ``additional_increment=1`` each async run adds 2 to the count."""
    counter = base_single_step_counter_with_inputs_async.with_name("counter")
    current = State({"count": 0, "tracker": []})
    for expected_count, expected_tracker in ((2, [2]), (4, [2, 4])):
        result, current = await _arun_single_step_action(
            counter, current, inputs={"additional_increment": 1}
        )
        assert result == {"count": expected_count}
        assert current.subset("count", "tracker").get_all() == {
            "count": expected_count,
            "tracker": expected_tracker,
        }
| |
| |
class SingleStepActionWithDeletion(SingleStepAction):
    """Single-step action that wipes the ``to_delete`` key from state."""

    @property
    def reads(self) -> list[str]:
        return []

    @property
    def writes(self) -> list[str]:
        return []

    def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
        """Return an empty result and the state with ``to_delete`` removed."""
        pruned = state.wipe(delete=["to_delete"])
        return {}, pruned
| |
| |
def test__run_single_step_action_deletes_state():
    """A single-step action may remove a key from state."""
    deleter = SingleStepActionWithDeletion()
    _, final_state = _run_single_step_action(deleter, State({"to_delete": 0}), inputs={})
    assert "to_delete" not in final_state
| |
| |
def test__run_multistep_streaming_action():
    """Streamed counts increase up to 1, and the final state reflects the increment."""
    counter = base_streaming_counter.with_name("counter")
    state = State({"count": 0, "tracker": []})
    stream = _run_multi_step_streaming_action(counter, state, inputs={})
    previous_count = -1
    result = None
    for result, state in stream:
        current_count = result["count"]
        if previous_count < 1:
            # Otherwise you hit floating point comparison problems
            assert current_count > previous_count
        previous_count = current_count
    assert result == {"count": 1}
    tracked = state.subset("count", "tracker").get_all()
    assert tracked == {"count": 1, "tracker": [1]}
| |
| |
async def test__run_multistep_streaming_action_async():
    """Async variant: streamed counts increase up to 1, then the final state is checked."""
    counter = base_streaming_counter_async.with_name("counter")
    state = State({"count": 0, "tracker": []})
    stream = _arun_multi_step_streaming_action(counter, state, inputs={})
    previous_count = -1
    result = None
    async for result, state in stream:
        current_count = result["count"]
        if previous_count < 1:
            # Otherwise you hit floating point comparison problems
            assert current_count > previous_count
        previous_count = current_count
    assert result == {"count": 1}
    tracked = state.subset("count", "tracker").get_all()
    assert tracked == {"count": 1, "tracker": [1]}
| |
| |
def test__run_streaming_action_incorrect_result_type():
    """Exhausting a streaming action that yields a non-dict raises ValueError."""
    bad_action = StreamingActionIncorrectResultType()
    empty_state = State()
    with pytest.raises(ValueError, match="returned a non-dict"):
        stream = _run_multi_step_streaming_action(bad_action, empty_state, inputs={})
        for _ in stream:  # exhaust the generator
            pass
| |
| |
async def test__run_streaming_action_incorrect_result_type_async():
    """Exhausting an async streaming action that yields a non-dict raises ValueError."""
    bad_action = StreamingActionIncorrectResultTypeAsync()
    empty_state = State()
    with pytest.raises(ValueError, match="returned a non-dict"):
        stream = _arun_multi_step_streaming_action(bad_action, empty_state, inputs={})
        _ = [item async for item in stream]  # exhaust the generator
| |
| |
def test__run_single_step_streaming_action_incorrect_result_type():
    """A single-step streaming action with a non-dict final result raises ValueError."""
    bad_action = StreamingSingleStepActionIncorrectResultType()
    empty_state = State()
    with pytest.raises(ValueError, match="returned a non-dict"):
        stream = _run_single_step_streaming_action(bad_action, empty_state, inputs={})
        for _ in stream:  # exhaust the generator
            pass
| |
| |
async def test__run_single_step_streaming_action_incorrect_result_type_async():
    """Async variant: a non-dict final result raises ValueError when the stream is consumed."""
    bad_action = StreamingSingleStepActionIncorrectResultTypeAsync()
    empty_state = State()
    with pytest.raises(ValueError, match="returned a non-dict"):
        stream = _arun_single_step_streaming_action(bad_action, empty_state, inputs={})
        async for _ in stream:  # exhaust the generator
            pass
| |
| |
def test__run_single_step_streaming_action():
    """Streamed counts increase up to 1, and the final state reflects the increment."""
    counter = base_streaming_single_step_counter.with_name("counter")
    initial = State({"count": 0, "tracker": []})
    stream = _run_single_step_streaming_action(counter, initial, inputs={})
    previous_count = -1
    result, state = None, None
    for result, state in stream:
        current_count = result["count"]
        if previous_count < 1:
            # Otherwise you hit comparison issues
            # This is because we get to the last one, which is the final result
            assert current_count > previous_count
        previous_count = current_count
    assert result == {"count": 1}
    tracked = state.subset("count", "tracker").get_all()
    assert tracked == {"count": 1, "tracker": [1]}
| |
| |
async def test__run_single_step_streaming_action_async():
    """Async variant: streamed counts increase up to 1, then the final state is checked."""
    counter = base_streaming_single_step_counter_async.with_name("counter")
    initial = State({"count": 0, "tracker": []})
    stream = _arun_single_step_streaming_action(counter, initial, inputs={})
    previous_count = -1
    result, state = None, None
    async for result, state in stream:
        current_count = result["count"]
        if previous_count < 1:
            # Otherwise you hit comparison issues
            # This is because we get to the last one, which is the final result
            assert current_count > previous_count
        previous_count = current_count
    assert result == {"count": 1}
    tracked = state.subset("count", "tracker").get_all()
    assert tracked == {"count": 1, "tracker": [1]}
| |
| |
class SingleStepActionWithDeletionAsync(SingleStepActionWithDeletion):
    """Async variant of the action that wipes ``to_delete`` from state."""

    async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
        """Return an empty result and the state with ``to_delete`` removed."""
        pruned = state.wipe(delete=["to_delete"])
        return {}, pruned
| |
| |
async def test__arun_single_step_action_deletes_state():
    """An async single-step action may remove a key from state."""
    deleter = SingleStepActionWithDeletionAsync()
    _, final_state = await _arun_single_step_action(
        deleter, State({"to_delete": 0}), inputs={}
    )
    assert "to_delete" not in final_state
| |
| |
def test_app_step():
    """A single ``step`` runs the entrypoint action and advances the sequence ID."""
    counter = base_counter_action.with_name("counter")
    app = Application(
        state=State({}),
        entrypoint="counter",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=Graph(
            actions=[counter],
            transitions=[Transition(counter, counter, default)],
        ),
    )
    ran_action, result, new_state = app.step()
    assert app.sequence_id == 1
    assert ran_action.name == "counter"
    assert result == {"count": 1}
    # internal contract, not part of the public API
    assert new_state[PRIOR_STEP] == "counter"
| |
| |
def test_app_step_with_inputs():
    """``step`` passes the supplied inputs through to the action."""
    counter = base_single_step_counter_with_inputs.with_name("counter")
    app = Application(
        state=State({"count": 0, "tracker": []}),
        entrypoint="counter",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=Graph(
            actions=[counter],
            transitions=[Transition(counter, counter, default)],
        ),
    )
    ran_action, result, new_state = app.step(inputs={"additional_increment": 1})
    assert ran_action.name == "counter"
    assert result == {"count": 2}
    tracked = new_state.subset("count", "tracker").get_all()
    assert tracked == {"count": 2, "tracker": [2]}
| |
| |
def test_app_step_with_inputs_missing():
    """``step`` raises ValueError when a required input is not provided."""
    counter = base_single_step_counter_with_inputs.with_name("counter")
    app = Application(
        state=State({"count": 0, "tracker": []}),
        entrypoint="counter",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=Graph(
            actions=[counter],
            transitions=[Transition(counter, counter, default)],
        ),
    )
    no_inputs = {}
    with pytest.raises(ValueError, match="missing required inputs"):
        app.step(inputs=no_inputs)
| |
| |
def test_app_step_broken(caplog):
    """A failing step re-raises the action's exception and logs the action name."""
    failing = base_broken_action.with_name("broken_action_unique_name")
    app = Application(
        state=State({}),
        entrypoint="broken_action_unique_name",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=Graph(
            actions=[failing],
            transitions=[Transition(failing, failing, default)],
        ),
    )
    # it should say the name, that's the only contract for now
    with caplog.at_level(logging.ERROR), pytest.raises(BrokenStepException):
        app.step()
    logged = caplog.text
    assert "broken_action_unique_name" in logged
| |
| |
def test_app_step_done():
    """Tests that when we cannot run a step, we return None"""
    counter = base_counter_action.with_name("counter")
    # No transitions: after the entrypoint runs there is nothing left to execute.
    app = Application(
        state=State({}),
        entrypoint="counter",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=Graph(actions=[counter], transitions=[]),
    )
    app.step()
    second_step = app.step()
    assert second_step is None
| |
| |
async def test_app_astep():
    """A single ``astep`` runs the async entrypoint action and advances the sequence ID."""
    counter = base_counter_action_async.with_name("counter_async")
    app = Application(
        state=State({}),
        entrypoint="counter_async",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=Graph(
            actions=[counter],
            transitions=[Transition(counter, counter, default)],
        ),
    )
    ran_action, result, new_state = await app.astep()
    assert app.sequence_id == 1
    assert ran_action.name == "counter_async"
    assert result == {"count": 1}
    # internal contract, not part of the public API
    assert new_state[PRIOR_STEP] == "counter_async"
| |
| |
async def test_app_astep_with_inputs():
    """``astep`` passes the supplied inputs through to the async action."""
    counter = base_single_step_counter_with_inputs_async.with_name("counter_async")
    app = Application(
        state=State({"count": 0, "tracker": []}),
        entrypoint="counter_async",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=Graph(
            actions=[counter],
            transitions=[Transition(counter, counter, default)],
        ),
    )
    ran_action, result, new_state = await app.astep(inputs={"additional_increment": 1})
    assert ran_action.name == "counter_async"
    assert result == {"count": 2}
    tracked = new_state.subset("count", "tracker").get_all()
    assert tracked == {"count": 2, "tracker": [2]}
| |
| |
async def test_app_astep_with_inputs_missing():
    """Checks that an async step raises ``ValueError`` when given empty inputs."""
    async_counter = base_single_step_counter_with_inputs_async.with_name("counter_async")
    self_loop = Transition(async_counter, async_counter, default)
    app = Application(
        state=State({"count": 0, "tracker": []}),
        entrypoint="counter_async",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=Graph(actions=[async_counter], transitions=[self_loop]),
    )
    no_inputs = {}
    with pytest.raises(ValueError, match="missing required inputs"):
        await app.astep(inputs=no_inputs)
| |
| |
async def test_app_astep_broken(caplog):
    """Checks that a failing async step re-raises and logs the action's name."""
    failing_name = "broken_action_unique_name"
    failing_action = base_broken_action_async.with_name(failing_name)
    self_loop = Transition(failing_action, failing_action, default)
    app = Application(
        state=State({}),
        entrypoint="broken_action_unique_name",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=Graph(actions=[failing_action], transitions=[self_loop]),
    )
    # it should say the name, that's the only contract for now
    with caplog.at_level(logging.ERROR):
        with pytest.raises(BrokenStepException):
            await app.astep()
        captured_log = caplog.text
        assert failing_name in captured_log
| |
| |
async def test_app_astep_done():
    """Checks that ``astep()`` returns ``None`` once there is nothing left to run."""
    async_counter = base_counter_action_async.with_name("counter_async")
    # No transitions: after the entrypoint has run there is no next action.
    terminal_graph = Graph(actions=[async_counter], transitions=[])
    app = Application(
        state=State({}),
        entrypoint="counter_async",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=terminal_graph,
    )
    await app.astep()  # runs the only action
    second_attempt = await app.astep()
    assert second_attempt is None
| |
| |
| # internal API |
def test_app_many_steps():
    """Steps a self-looping counter 100 times and checks the last action/result."""
    counter = base_counter_action.with_name("counter")
    self_loop = Transition(counter, counter, default)
    app = Application(
        state=State({}),
        entrypoint="counter",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=Graph(actions=[counter], transitions=[self_loop]),
    )
    last_action, last_result = None, None
    for _ in range(100):
        last_action, last_result, _state = app.step()
    assert last_action.name == "counter"
    assert last_result == {"count": 100}
| |
| |
async def test_app_many_a_steps():
    """Async-steps a self-looping counter 100 times and checks the last action/result."""
    async_counter = base_counter_action_async.with_name("counter_async")
    self_loop = Transition(async_counter, async_counter, default)
    app = Application(
        state=State({}),
        entrypoint="counter_async",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=Graph(actions=[async_counter], transitions=[self_loop]),
    )
    last_action, last_result = None, None
    for _ in range(100):
        last_action, last_result, _state = await app.astep()
    assert last_action.name == "counter_async"
    assert last_result == {"count": 100}
| |
| |
def test_iterate():
    """Drives ``iterate()`` by hand and checks each yielded step and the return value.

    The counter loops while ``count < 10`` and then hands off to ``result``,
    which is the ``halt_after`` target.
    """
    result_action = Result("count").with_name("result")
    counter_action = base_counter_action.with_name("counter")
    graph = Graph(
        actions=[counter_action, result_action],
        transitions=[
            Transition(counter_action, counter_action, Condition.expr("count < 10")),
            Transition(counter_action, result_action, default),
        ],
    )
    app = Application(
        state=State({}),
        entrypoint="counter",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=graph,
    )
    res = []
    stepper = app.iterate(halt_after=["result"])
    previous_count = 0
    while True:
        try:
            ran_action, step_result, step_state = next(stepper)
        except StopIteration as finished:
            # The generator's return value rides on the StopIteration.
            generator_result = finished.value
            break
        if ran_action.name == "counter":
            assert step_state["count"] == previous_count + 1
            assert step_result["count"] == step_state["count"]
            previous_count = step_result["count"]
        else:
            res.append(step_result)
            assert step_state["count"] == 10
            assert step_result["count"] == 10
    _final_action, final_result, final_state = generator_result
    assert final_state["count"] == 10
    assert final_result["count"] == 10
    assert app.sequence_id == 11
| |
| |
def test_iterate_with_inputs():
    """Checks that ``iterate()`` forwards ``inputs`` and returns the final result."""
    result_action = Result("count").with_name("result")
    counter_action = base_counter_action_with_inputs.with_name("counter")
    graph = Graph(
        actions=[counter_action, result_action],
        transitions=[
            Transition(counter_action, counter_action, Condition.expr("count < 2")),
            Transition(counter_action, result_action, default),
        ],
    )
    app = Application(
        state=State({}),
        entrypoint="counter",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=graph,
    )
    # The large increment makes it go quickly to the end.
    stepper = app.iterate(halt_after=["result"], inputs={"additional_increment": 10})
    try:
        while True:
            next(stepper)
    except StopIteration as finished:
        _final_action, final_result, _final_state = finished.value
        assert final_result["count"] == 11  # 1 + 10, for the first one
| |
| |
async def test_aiterate():
    """Consumes ``aiterate()`` and checks every yielded step plus the final sequence ID.

    The counter loops while ``count < 10`` and then hands off to ``result``,
    which is the ``halt_after`` target.
    """
    result_action = Result("count").with_name("result")
    counter_action = base_counter_action_async.with_name("counter")
    app = Application(
        state=State({}),
        entrypoint="counter",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=Graph(
            actions=[counter_action, result_action],
            transitions=[
                Transition(counter_action, counter_action, Condition.expr("count < 10")),
                Transition(counter_action, result_action, default),
            ],
        ),
    )
    gen = app.aiterate(halt_after=["result"])
    # Creating the async generator must not run anything yet.
    assert app.sequence_id == 0
    counter = 0
    # Note that we use an async-for loop cause the API is different, this doesn't
    # return anything (async generators are not allowed to).
    async for action_, result, state in gen:
        if action_.name == "counter":
            assert state["count"] == result["count"] == counter + 1
            counter = result["count"]
        else:
            assert state["count"] == result["count"] == 10
    assert app.sequence_id == 11
| |
| |
async def test_aiterate_halt_before():
    """Consumes ``aiterate()`` with ``halt_before`` and checks every yielded step."""
    result_action = Result("count").with_name("result")
    counter_action = base_counter_action_async.with_name("counter")
    graph = Graph(
        actions=[counter_action, result_action],
        transitions=[
            Transition(counter_action, counter_action, Condition.expr("count < 10")),
            Transition(counter_action, result_action, default),
        ],
    )
    app = Application(
        state=State({}),
        entrypoint="counter",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=graph,
    )
    previous_count = 0
    # An async-for loop is used because async generators cannot return a value,
    # so there is no final return tuple to inspect here.
    async for ran_action, step_result, step_state in app.aiterate(halt_before=["result"]):
        if ran_action.name != "counter":
            # Any non-counter action yielded here must carry no result.
            assert step_result is None
            assert step_state["count"] == 10
            continue
        assert step_state["count"] == previous_count + 1
        previous_count = step_result["count"]
| |
| |
async def test_app_aiterate_with_inputs():
    """Checks that ``aiterate()`` forwards ``inputs`` to the actions it runs."""
    result_action = Result("count").with_name("result")
    counter_action = base_counter_action_with_inputs_async.with_name("counter")
    graph = Graph(
        actions=[counter_action, result_action],
        transitions=[
            Transition(counter_action, counter_action, Condition.expr("count < 10")),
            Transition(counter_action, result_action, default),
        ],
    )
    app = Application(
        state=State({}),
        entrypoint="counter",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=graph,
    )
    stepper = app.aiterate(halt_after=["result"], inputs={"additional_increment": 10})
    async for ran_action, step_result, step_state in stepper:
        # Both branches expect 11; they differ only in comparison order.
        if ran_action.name == "counter":
            assert step_result["count"] == step_state["count"] == 11
        else:
            assert step_state["count"] == step_result["count"] == 11
| |
| |
def test_run():
    """Runs the app to the ``result`` action and checks the final count."""
    result_action = Result("count").with_name("result")
    counter_action = base_counter_action.with_name("counter")
    graph = Graph(
        actions=[counter_action, result_action],
        transitions=[
            Transition(counter_action, counter_action, Condition.expr("count < 10")),
            Transition(counter_action, result_action, default),
        ],
    )
    app = Application(
        state=State({}),
        entrypoint="counter",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=graph,
    )
    _final_action, final_result, final_state = app.run(halt_after=["result"])
    assert final_state["count"] == 10
    assert final_result["count"] == 10
| |
| |
def test_run_halt_before():
    """Runs with ``halt_before`` and checks that the halting action was not executed."""
    result_action = Result("count").with_name("result")
    counter_action = base_counter_action.with_name("counter")
    graph = Graph(
        actions=[counter_action, result_action],
        transitions=[
            Transition(counter_action, counter_action, Condition.expr("count < 10")),
            Transition(counter_action, result_action, default),
        ],
    )
    app = Application(
        state=State({}),
        entrypoint="counter",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=graph,
    )
    halted_at, final_result, final_state = app.run(halt_before=["result"])
    assert final_state["count"] == 10
    # Halting *before* an action means it produced no result.
    assert final_result is None
    assert halted_at.name == "result"
| |
| |
def test_run_with_inputs():
    """Checks that ``run()`` forwards ``inputs`` to the actions it executes."""
    result_action = Result("count").with_name("result")
    counter_action = base_counter_action_with_inputs.with_name("counter")
    graph = Graph(
        actions=[counter_action, result_action],
        transitions=[
            Transition(counter_action, counter_action, Condition.expr("count < 10")),
            Transition(counter_action, result_action, default),
        ],
    )
    app = Application(
        state=State({}),
        entrypoint="counter",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=graph,
    )
    run_inputs = {"additional_increment": 10}
    final_action, final_result, final_state = app.run(halt_after=["result"], inputs=run_inputs)
    assert final_action.name == "result"
    assert final_state["count"] == final_result["count"] == 11
| |
| |
def test_run_with_inputs_multiple_actions():
    """Checks that inputs survive across steps and reach more than one action.

    Inputs must not be popped after first use: both counters need
    ``additional_increment`` for the final count to reach 27.
    """
    result_action = Result("count").with_name("result")
    first_counter = base_counter_action_with_inputs.with_name("counter1")
    second_counter = base_counter_action_with_inputs.with_name("counter2")
    graph = Graph(
        actions=[first_counter, second_counter, result_action],
        transitions=[
            Transition(first_counter, first_counter, Condition.expr("count < 10")),
            Transition(first_counter, second_counter, Condition.expr("count >= 10")),
            Transition(second_counter, second_counter, Condition.expr("count < 20")),
            Transition(second_counter, result_action, default),
        ],
    )
    app = Application(
        state=State({}),
        entrypoint="counter1",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=graph,
    )
    run_inputs = {"additional_increment": 8}
    final_action, final_result, final_state = app.run(halt_after=["result"], inputs=run_inputs)
    assert final_action.name == "result"
    assert final_state["count"] == final_result["count"] == 27
    assert final_state["__SEQUENCE_ID"] == 4
| |
| |
async def test_arun():
    """Async-runs the app to the ``result`` action and checks the final count."""
    result_action = Result("count").with_name("result")
    counter_action = base_counter_action_async.with_name("counter")
    graph = Graph(
        actions=[counter_action, result_action],
        transitions=[
            Transition(counter_action, counter_action, Condition.expr("count < 10")),
            Transition(counter_action, result_action, default),
        ],
    )
    app = Application(
        state=State({}),
        entrypoint="counter",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=graph,
    )
    final_action, final_result, final_state = await app.arun(halt_after=["result"])
    assert final_state["count"] == final_result["count"] == 10
    assert final_action.name == "result"
| |
| |
async def test_arun_halt_before():
    """Async-runs with ``halt_before`` and checks the halting action was not executed."""
    result_action = Result("count").with_name("result")
    counter_action = base_counter_action_async.with_name("counter")
    graph = Graph(
        actions=[counter_action, result_action],
        transitions=[
            Transition(counter_action, counter_action, Condition.expr("count < 10")),
            Transition(counter_action, result_action, default),
        ],
    )
    app = Application(
        state=State({}),
        entrypoint="counter",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=graph,
    )
    halted_at, final_result, final_state = await app.arun(halt_before=["result"])
    assert final_state["count"] == 10
    # Halting *before* an action means it produced no result.
    assert final_result is None
    assert halted_at.name == "result"
| |
| |
async def test_arun_with_inputs():
    """Checks that ``arun()`` forwards ``inputs`` to the actions it executes."""
    result_action = Result("count").with_name("result")
    counter_action = base_counter_action_with_inputs_async.with_name("counter")
    graph = Graph(
        actions=[counter_action, result_action],
        transitions=[
            Transition(counter_action, counter_action, Condition.expr("count < 10")),
            Transition(counter_action, result_action, default),
        ],
    )
    app = Application(
        state=State({}),
        entrypoint="counter",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=graph,
    )
    run_inputs = {"additional_increment": 10}
    final_action, final_result, final_state = await app.arun(
        halt_after=["result"], inputs=run_inputs
    )
    assert final_state["count"] == final_result["count"] == 11
    assert final_action.name == "result"
| |
| |
async def test_arun_with_inputs_multiple_actions():
    """Checks that inputs given to ``arun()`` reach more than one action.

    Both counters need ``additional_increment`` for the final count to reach 27.
    """
    result_action = Result("count").with_name("result")
    first_counter = base_counter_action_with_inputs_async.with_name("counter1")
    second_counter = base_counter_action_with_inputs_async.with_name("counter2")
    graph = Graph(
        actions=[first_counter, second_counter, result_action],
        transitions=[
            Transition(first_counter, first_counter, Condition.expr("count < 10")),
            Transition(first_counter, second_counter, Condition.expr("count >= 10")),
            Transition(second_counter, second_counter, Condition.expr("count < 20")),
            Transition(second_counter, result_action, default),
        ],
    )
    app = Application(
        state=State({}),
        entrypoint="counter1",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=graph,
    )
    run_inputs = {"additional_increment": 8}
    final_action, final_result, final_state = await app.arun(
        halt_after=["result"], inputs=run_inputs
    )
    assert final_state["count"] == final_result["count"] == 27
    assert final_action.name == "result"
    assert final_state["__SEQUENCE_ID"] == 4
| |
| |
async def test_app_a_run_async_and_sync():
    """Checks that ``arun()`` can drive a graph mixing sync and async actions.

    The sync and async counters alternate until the sync counter sees a count
    of 20 or more, at which point it falls through to ``result``.
    """
    result_action = Result("count").with_name("result")
    # The sync counter must come from the synchronous base action; otherwise
    # this test never exercises a sync action under arun().
    counter_action_sync = base_counter_action.with_name("counter_sync")
    counter_action_async = base_counter_action_async.with_name("counter_async")
    app = Application(
        state=State({}),
        entrypoint="counter_sync",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=Graph(
            actions=[counter_action_sync, counter_action_async, result_action],
            transitions=[
                Transition(counter_action_sync, counter_action_async, Condition.expr("count < 20")),
                Transition(counter_action_async, counter_action_sync, default),
                Transition(counter_action_sync, result_action, default),
            ],
        ),
    )
    action_, result, state = await app.arun(halt_after=["result"])
    assert action_.name == "result"
    assert state["count"] > 20
    assert result["count"] > 20
| |
| |
def test_stream_result_halt_after_unique_ordered_sequence_id():
    """Streams through two ``base_streaming_counter`` actions and checks hook ordering.

    Each action must fire the pre/post hooks once, with sequence IDs 0 then 1.
    """
    tracker = ActionTracker()
    first_counter = base_streaming_counter.with_name("counter")
    second_counter = base_streaming_counter.with_name("counter_2")
    app = Application(
        state=State({"count": 0}),
        entrypoint="counter",
        adapter_set=LifecycleAdapterSet(tracker),
        partition_key="test",
        uid="test-123",
        graph=Graph(
            actions=[first_counter, second_counter],
            transitions=[Transition(first_counter, second_counter, default)],
        ),
    )
    _halted_at, container = app.stream_result(halt_after=["counter_2"])
    streamed_items = list(container)
    assert len(streamed_items) == 10
    final_result, final_state = container.get()
    assert final_result["count"] == final_state["count"] == 2
    assert final_state["tracker"] == [1, 2]
    expected_names = {"counter", "counter_2"}
    assert len(tracker.pre_called) == 2
    assert len(tracker.post_called) == 2
    assert {name for name, _ in tracker.pre_called} == expected_names
    assert {name for name, _ in tracker.post_called} == expected_names
    # ensure sequence ID is respected
    pre_sequence_ids = [hook_kwargs["sequence_id"] for _, hook_kwargs in tracker.pre_called]
    assert pre_sequence_ids == [0, 1]
    post_sequence_ids = [hook_kwargs["sequence_id"] for _, hook_kwargs in tracker.post_called]
    assert post_sequence_ids == [0, 1]
| |
| |
async def test_astream_result_halt_after_unique_ordered_sequence_id():
    """Async-streams through two ``base_streaming_counter_async`` actions and checks hooks.

    Each action must fire the pre/post hooks once, with sequence IDs 0 then 1.
    """
    tracker = ActionTracker()
    first_counter = base_streaming_counter_async.with_name("counter")
    second_counter = base_streaming_counter_async.with_name("counter_2")
    app = Application(
        state=State({"count": 0}),
        entrypoint="counter",
        adapter_set=LifecycleAdapterSet(tracker),
        partition_key="test",
        uid="test-123",
        graph=Graph(
            actions=[first_counter, second_counter],
            transitions=[Transition(first_counter, second_counter, default)],
        ),
    )
    _halted_at, container = await app.astream_result(halt_after=["counter_2"])
    # this should just have the intermediate results
    streamed_items = [item async for item in container]
    assert len(streamed_items) == 10
    final_result, final_state = await container.get()
    assert final_result["count"] == final_state["count"] == 2
    assert final_state["tracker"] == [1, 2]
    expected_names = {"counter", "counter_2"}
    assert len(tracker.pre_called) == 2
    assert len(tracker.post_called) == 2
    assert {name for name, _ in tracker.pre_called} == expected_names
    assert {name for name, _ in tracker.post_called} == expected_names
    # ensure sequence ID is respected
    pre_sequence_ids = [hook_kwargs["sequence_id"] for _, hook_kwargs in tracker.pre_called]
    assert pre_sequence_ids == [0, 1]
    post_sequence_ids = [hook_kwargs["sequence_id"] for _, hook_kwargs in tracker.post_called]
    assert post_sequence_ids == [0, 1]
| |
| |
def test_stream_result_halt_after_run_through_streaming():
    """Checks streaming when an earlier streaming action precedes the halting one.

    The first streaming action is passed through and fully realized; the
    stream handed back to the caller belongs to the final one.
    """
    tracker = ActionTracker()
    first_counter = base_streaming_single_step_counter.with_name("counter")
    second_counter = base_streaming_single_step_counter.with_name("counter_2")
    app = Application(
        state=State({"count": 0}),
        entrypoint="counter",
        adapter_set=LifecycleAdapterSet(tracker),
        partition_key="test",
        uid="test-123",
        graph=Graph(
            actions=[first_counter, second_counter],
            transitions=[Transition(first_counter, second_counter, default)],
        ),
    )
    _halted_at, container = app.stream_result(halt_after=["counter_2"])
    streamed_items = list(container)
    assert len(streamed_items) == 10
    final_result, final_state = container.get()
    assert final_result["count"] == final_state["count"] == 2
    assert final_state["tracker"] == [1, 2]
    expected_names = {"counter", "counter_2"}
    assert len(tracker.pre_called) == 2
    assert len(tracker.post_called) == 2
    assert {name for name, _ in tracker.pre_called} == expected_names
    assert {name for name, _ in tracker.post_called} == expected_names
    # ensure sequence ID is respected
    pre_sequence_ids = [hook_kwargs["sequence_id"] for _, hook_kwargs in tracker.pre_called]
    assert pre_sequence_ids == [0, 1]
    post_sequence_ids = [hook_kwargs["sequence_id"] for _, hook_kwargs in tracker.post_called]
    assert post_sequence_ids == [0, 1]
| |
| |
async def test_astream_result_halt_after_run_through_streaming():
    """Async variant: an earlier streaming action precedes the halting streaming action."""
    tracker = ActionTracker()
    first_counter = base_streaming_single_step_counter_async.with_name("counter")
    second_counter = base_streaming_single_step_counter_async.with_name("counter_2")
    # Both actions must report themselves as async.
    assert first_counter.is_async()
    assert second_counter.is_async()
    app = Application(
        state=State({"count": 0}),
        entrypoint="counter",
        adapter_set=LifecycleAdapterSet(tracker),
        partition_key="test",
        uid="test-123",
        graph=Graph(
            actions=[first_counter, second_counter],
            transitions=[Transition(first_counter, second_counter, default)],
        ),
    )
    _halted_at, container = await app.astream_result(halt_after=["counter_2"])
    # this should just have the intermediate results
    streamed_items = [item async for item in container]
    assert len(streamed_items) == 10
    final_result, final_state = await container.get()
    assert final_result["count"] == final_state["count"] == 2
    assert final_state["tracker"] == [1, 2]
    expected_names = {"counter", "counter_2"}
    assert len(tracker.pre_called) == 2
    assert len(tracker.post_called) == 2
    assert {name for name, _ in tracker.pre_called} == expected_names
    assert {name for name, _ in tracker.post_called} == expected_names
    # ensure sequence ID is respected
    pre_sequence_ids = [hook_kwargs["sequence_id"] for _, hook_kwargs in tracker.pre_called]
    assert pre_sequence_ids == [0, 1]
    post_sequence_ids = [hook_kwargs["sequence_id"] for _, hook_kwargs in tracker.post_called]
    assert post_sequence_ids == [0, 1]
| |
| |
def test_stream_result_halt_after_run_through_non_streaming():
    """Checks streaming when non-streaming steps run before the halting streaming action.

    The non-streaming counter loops while ``count < 10`` and then hands off to
    the streaming counter named in ``halt_after``.
    """
    tracker = ActionTracker()
    plain_counter = base_counter_action.with_name("counter_non_streaming")
    streaming_counter = base_streaming_single_step_counter.with_name("counter_streaming")
    graph = Graph(
        actions=[plain_counter, streaming_counter],
        transitions=[
            Transition(plain_counter, plain_counter, expr("count < 10")),
            Transition(plain_counter, streaming_counter, default),
        ],
    )
    app = Application(
        state=State({"count": 0}),
        entrypoint="counter_non_streaming",
        adapter_set=LifecycleAdapterSet(tracker),
        partition_key="test",
        uid="test-123",
        graph=graph,
    )
    _halted_at, container = app.stream_result(halt_after=["counter_streaming"])
    streamed_items = list(container)
    assert len(streamed_items) == 10
    final_result, final_state = container.get()
    assert final_result["count"] == final_state["count"] == 11
    expected_names = {"counter_streaming", "counter_non_streaming"}
    assert len(tracker.pre_called) == 11
    assert len(tracker.post_called) == 11
    assert {name for name, _ in tracker.pre_called} == expected_names
    assert {name for name, _ in tracker.post_called} == expected_names
    # ensure sequence ID is respected
    expected_sequence_ids = list(range(0, 11))
    pre_sequence_ids = [hook_kwargs["sequence_id"] for _, hook_kwargs in tracker.pre_called]
    assert pre_sequence_ids == expected_sequence_ids
    post_sequence_ids = [hook_kwargs["sequence_id"] for _, hook_kwargs in tracker.post_called]
    assert post_sequence_ids == expected_sequence_ids
| |
| |
async def test_astream_result_halt_after_run_through_non_streaming():
    """Async variant: non-streaming steps run before the halting streaming action."""
    tracker = ActionTracker()
    plain_counter = base_counter_action_async.with_name("counter_non_streaming")
    streaming_counter = base_streaming_single_step_counter_async.with_name("counter_streaming")
    graph = Graph(
        actions=[plain_counter, streaming_counter],
        transitions=[
            Transition(plain_counter, plain_counter, expr("count < 10")),
            Transition(plain_counter, streaming_counter, default),
        ],
    )
    app = Application(
        state=State({"count": 0}),
        entrypoint="counter_non_streaming",
        adapter_set=LifecycleAdapterSet(tracker),
        partition_key="test",
        uid="test-123",
        graph=graph,
    )
    _halted_at, container = await app.astream_result(halt_after=["counter_streaming"])
    # this should just have the intermediate results
    streamed_items = [item async for item in container]
    assert len(streamed_items) == 10
    final_result, final_state = await container.get()
    assert final_result["count"] == final_state["count"] == 11
    expected_names = {"counter_streaming", "counter_non_streaming"}
    assert len(tracker.pre_called) == 11
    assert len(tracker.post_called) == 11
    assert {name for name, _ in tracker.pre_called} == expected_names
    assert {name for name, _ in tracker.post_called} == expected_names
    # ensure sequence ID is respected
    expected_sequence_ids = list(range(0, 11))
    pre_sequence_ids = [hook_kwargs["sequence_id"] for _, hook_kwargs in tracker.pre_called]
    assert pre_sequence_ids == expected_sequence_ids
    post_sequence_ids = [hook_kwargs["sequence_id"] for _, hook_kwargs in tracker.post_called]
    assert post_sequence_ids == expected_sequence_ids
| |
| |
def test_stream_result_halt_after_run_through_final_non_streaming():
    """Checks ``stream_result()`` when the halting action is itself non-streaming.

    Nothing is streamed, but the final result and state are still available
    from the container.
    """
    tracker = ActionTracker()
    looping_counter = base_counter_action.with_name("counter_non_streaming")
    final_counter = base_counter_action.with_name("counter_final_non_streaming")
    graph = Graph(
        actions=[looping_counter, final_counter],
        transitions=[
            Transition(looping_counter, looping_counter, expr("count < 10")),
            Transition(looping_counter, final_counter, default),
        ],
    )
    app = Application(
        state=State({"count": 0}),
        entrypoint="counter_non_streaming",
        adapter_set=LifecycleAdapterSet(tracker),
        partition_key="test",
        uid="test-123",
        graph=graph,
    )
    _halted_at, container = app.stream_result(halt_after=["counter_final_non_streaming"])
    streamed_items = list(container)
    assert len(streamed_items) == 0  # nothing to stream
    final_result, final_state = container.get()
    assert final_result["count"] == final_state["count"] == 11
    expected_names = {"counter_non_streaming", "counter_final_non_streaming"}
    assert len(tracker.pre_called) == 11
    assert len(tracker.post_called) == 11
    assert {name for name, _ in tracker.pre_called} == expected_names
    assert {name for name, _ in tracker.post_called} == expected_names
    # ensure sequence ID is respected
    expected_sequence_ids = list(range(0, 11))
    pre_sequence_ids = [hook_kwargs["sequence_id"] for _, hook_kwargs in tracker.pre_called]
    assert pre_sequence_ids == expected_sequence_ids
    post_sequence_ids = [hook_kwargs["sequence_id"] for _, hook_kwargs in tracker.post_called]
    assert post_sequence_ids == expected_sequence_ids
| |
| |
async def test_astream_result_halt_after_run_through_final_streaming():
    """Checks ``astream_result()`` when the halting action is itself non-streaming.

    Nothing is streamed, but the final result and state are still available
    from the container.
    """
    tracker = ActionTracker()
    looping_counter = base_counter_action_async.with_name("counter_non_streaming")
    final_counter = base_counter_action_async.with_name("counter_final_non_streaming")
    graph = Graph(
        actions=[looping_counter, final_counter],
        transitions=[
            Transition(looping_counter, looping_counter, expr("count < 10")),
            Transition(looping_counter, final_counter, default),
        ],
    )
    app = Application(
        state=State({"count": 0}),
        entrypoint="counter_non_streaming",
        adapter_set=LifecycleAdapterSet(tracker),
        partition_key="test",
        uid="test-123",
        graph=graph,
    )
    _halted_at, container = await app.astream_result(
        halt_after=["counter_final_non_streaming"]
    )
    # this should just have the intermediate results
    streamed_items = [item async for item in container]
    assert len(streamed_items) == 0  # nothing to stream
    final_result, final_state = await container.get()
    assert final_result["count"] == final_state["count"] == 11
    expected_names = {"counter_non_streaming", "counter_final_non_streaming"}
    assert len(tracker.pre_called) == 11
    assert len(tracker.post_called) == 11
    assert {name for name, _ in tracker.pre_called} == expected_names
    assert {name for name, _ in tracker.post_called} == expected_names
    # ensure sequence ID is respected
    expected_sequence_ids = list(range(0, 11))
    pre_sequence_ids = [hook_kwargs["sequence_id"] for _, hook_kwargs in tracker.pre_called]
    assert pre_sequence_ids == expected_sequence_ids
    post_sequence_ids = [hook_kwargs["sequence_id"] for _, hook_kwargs in tracker.post_called]
    assert post_sequence_ids == expected_sequence_ids
| |
| |
def test_stream_result_halt_before():
    """Checks ``stream_result()`` with ``halt_before``: nothing streams, no result."""
    tracker = ActionTracker()
    looping_counter = base_counter_action.with_name("counter_non_streaming")
    final_counter = base_streaming_single_step_counter.with_name("counter_final")
    graph = Graph(
        actions=[looping_counter, final_counter],
        transitions=[
            Transition(looping_counter, looping_counter, expr("count < 10")),
            Transition(looping_counter, final_counter, default),
        ],
    )
    app = Application(
        state=State({"count": 0}),
        entrypoint="counter_non_streaming",
        partition_key="test",
        uid="test-123",
        adapter_set=LifecycleAdapterSet(tracker),
        graph=graph,
    )
    halted_at, container = app.stream_result(halt_after=[], halt_before=["counter_final"])
    streamed_items = list(container)
    assert len(streamed_items) == 0  # nothing to stream
    final_result, final_state = container.get()
    assert halted_at.name == "counter_final"  # halt before this one
    assert final_result is None
    assert final_state["count"] == 10
    # ensure sequence ID is respected
    expected_sequence_ids = list(range(0, 10))
    pre_sequence_ids = [hook_kwargs["sequence_id"] for _, hook_kwargs in tracker.pre_called]
    assert pre_sequence_ids == expected_sequence_ids
    post_sequence_ids = [hook_kwargs["sequence_id"] for _, hook_kwargs in tracker.post_called]
    assert post_sequence_ids == expected_sequence_ids
| |
| |
async def test_astream_result_halt_before():
    """Checks ``astream_result()`` with ``halt_before``: nothing streams, no result."""
    tracker = ActionTracker()
    looping_counter = base_counter_action_async.with_name("counter_non_streaming")
    final_counter = base_streaming_single_step_counter_async.with_name("counter_final")
    graph = Graph(
        actions=[looping_counter, final_counter],
        transitions=[
            Transition(looping_counter, looping_counter, expr("count < 10")),
            Transition(looping_counter, final_counter, default),
        ],
    )
    app = Application(
        state=State({"count": 0}),
        entrypoint="counter_non_streaming",
        partition_key="test",
        uid="test-123",
        adapter_set=LifecycleAdapterSet(tracker),
        graph=graph,
    )
    halted_at, container = await app.astream_result(
        halt_after=[], halt_before=["counter_final"]
    )
    # this should just have the intermediate results
    streamed_items = [item async for item in container]
    assert len(streamed_items) == 0  # nothing to stream
    final_result, final_state = await container.get()
    assert halted_at.name == "counter_final"  # halt before this one
    assert final_result is None
    assert final_state["count"] == 10
    # ensure sequence ID is respected
    expected_sequence_ids = list(range(0, 10))
    pre_sequence_ids = [hook_kwargs["sequence_id"] for _, hook_kwargs in tracker.pre_called]
    assert pre_sequence_ids == expected_sequence_ids
    post_sequence_ids = [hook_kwargs["sequence_id"] for _, hook_kwargs in tracker.post_called]
    assert post_sequence_ids == expected_sequence_ids
| |
| |
def test_app_set_state():
    """Checks that ``update_state()`` replaces the application's state.

    The state starts without ``count``, gains ``count == 1`` after one step,
    and then reflects the externally supplied value of 2.
    """
    counter_action = base_counter_action.with_name("counter")
    app = Application(
        state=State(),
        entrypoint="counter",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=Graph(
            actions=[counter_action],
            transitions=[Transition(counter_action, counter_action, default)],
        ),
    )
    # Check the state key ("count"), not the action name ("counter"); the
    # latter is never a state key, so testing it would pass vacuously.
    assert "count" not in app.state  # initial value
    app.step()
    assert app.state["count"] == 1  # updated value
    state = app.state
    app.update_state(state.update(count=2))
    assert app.state["count"] == 2  # updated value
| |
| |
def test_app_get_next_step():
    """``get_next_action`` follows the default transitions around a three-action cycle."""
    first = base_counter_action.with_name("counter_1")
    second = base_counter_action.with_name("counter_2")
    third = base_counter_action.with_name("counter_3")
    cycle = Graph(
        actions=[first, second, third],
        transitions=[
            Transition(first, second, default),
            Transition(second, third, default),
            Transition(third, first, default),
        ],
    )
    app = Application(
        state=State(),
        entrypoint="counter_1",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=cycle,
    )
    # Before any step the entrypoint is next; each step advances one action.
    for expected_name in ("counter_1", "counter_2", "counter_3"):
        assert app.get_next_action().name == expected_name
        app.step()
    # After running counter_3 the cycle wraps back to counter_1.
    assert app.get_next_action().name == "counter_1"
| |
| |
def test_application_builder_complete():
    """A fully specified builder produces an app with the expected graph and entrypoint."""
    builder = (
        ApplicationBuilder()
        .with_state(count=0)
        .with_actions(counter=base_counter_action, result=Result("count"))
        .with_transitions(
            ("counter", "counter", Condition.expr("count < 10")),
            ("counter", "result"),
        )
        .with_entrypoint("counter")
    )
    app = builder.build()
    built_graph = app.graph
    assert len(built_graph.actions) == 2
    assert len(built_graph.transitions) == 2
    assert app.get_next_action().name == "counter"
| |
| |
def test__validate_start_valid():
    """An entrypoint contained in the set of action names validates without raising."""
    known_actions = {"counter", "result"}
    _validate_start("counter", known_actions)
| |
| |
def test__validate_start_not_found():
    """An entrypoint missing from the action names raises ``ValueError`` ("not found")."""
    known_actions = {"result"}
    with pytest.raises(ValueError, match="not found"):
        _validate_start("counter", known_actions)
| |
| |
def test__adjust_single_step_output_result_and_state():
    """A ``(result, state)`` tuple is handed back unchanged."""
    expected_state = State({"count": 1})
    expected_result = {"count": 1}
    adjusted = _adjust_single_step_output((expected_result, expected_state), "test_action")
    assert adjusted == (expected_result, expected_state)
| |
| |
def test__adjust_single_step_output_just_state():
    """A bare ``State`` is paired with an empty result dict."""
    only_state = State({"count": 1})
    adjusted = _adjust_single_step_output(only_state, "test_action")
    assert adjusted == ({}, only_state)
| |
| |
def test__adjust_single_step_output_errors_incorrect_type():
    """An output that is neither a ``State`` nor a tuple raises ``ValueError``."""
    not_a_state = "foo"
    with pytest.raises(ValueError, match="must return either"):
        _adjust_single_step_output(not_a_state, "test_action")
| |
| |
def test__adjust_single_step_output_errors_incorrect_result_type():
    """A ``(result, state)`` tuple whose result is not a dict raises ``ValueError``."""
    state = State()
    result = "bar"
    with pytest.raises(ValueError, match="non-dict"):
        # Pass the tuple in (result, state) order, as the happy-path test for this
        # helper does, so the string in the result slot is the only invalid element.
        # The previous (state, result) order only failed because a State landed in
        # the result position.
        _adjust_single_step_output((result, state), "test_action")
| |
| |
def test_application_builder_unset():
    """Building from a builder on which nothing was configured raises ``ValueError``."""
    with pytest.raises(ValueError):
        unconfigured_builder = ApplicationBuilder()
        unconfigured_builder.build()
| |
| |
def test_application_run_step_hooks_sync():
    """Synchronous pre/post step hooks fire once per executed step during ``run``.

    The counter loops until ``count >= 10`` and then hands over to ``result``, so
    eleven steps (and eleven pre/post hook calls) are expected.
    """
    hook_tracker = ActionTracker()
    counter_action = base_counter_action.with_name("counter")
    result_action = Result("count").with_name("result")
    graph = Graph(
        actions=[counter_action, result_action],
        transitions=[
            Transition(counter_action, result_action, Condition.expr("count >= 10")),
            Transition(counter_action, counter_action, default),
        ],
    )
    app = Application(
        state=State({}),
        entrypoint="counter",
        adapter_set=internal.LifecycleAdapterSet(hook_tracker),
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=graph,
    )
    app.run(halt_after=["result"])

    assert {name for name, _ in hook_tracker.pre_called} == {"counter", "result"}
    assert {name for name, _ in hook_tracker.post_called} == {"counter", "result"}

    first_pre_kwargs = hook_tracker.pre_called[0][1]
    first_post_kwargs = hook_tracker.post_called[0][1]
    # assert sequence id is incremented
    assert first_pre_kwargs["sequence_id"] == 1
    assert first_post_kwargs["sequence_id"] == 1

    required_pre_keys = {
        "action",
        "sequence_id",
        "state",
        "inputs",
        "app_id",
        "partition_key",
    }
    required_post_keys = {
        "sequence_id",
        "result",
        "state",
        "exception",
        "app_id",
        "partition_key",
    }
    assert required_pre_keys.issubset(first_pre_kwargs.keys())
    assert required_post_keys.issubset(first_post_kwargs.keys())

    assert len(hook_tracker.pre_called) == 11
    assert len(hook_tracker.post_called) == 11
    # quick inclusion to ensure that the action is not called when we're done running
    assert app.step() is None  # should be None
| |
| |
async def test_application_run_step_hooks_async():
    """Asynchronous pre/post step hooks fire once per executed step during ``arun``.

    Same graph as the synchronous variant: ten counter steps followed by ``result``.
    """
    hook_tracker = ActionTrackerAsync()
    counter_action = base_counter_action.with_name("counter")
    result_action = Result("count").with_name("result")
    graph = Graph(
        actions=[counter_action, result_action],
        transitions=[
            Transition(counter_action, result_action, Condition.expr("count >= 10")),
            Transition(counter_action, counter_action, default),
        ],
    )
    app = Application(
        state=State({}),
        entrypoint="counter",
        adapter_set=internal.LifecycleAdapterSet(hook_tracker),
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=graph,
    )
    await app.arun(halt_after=["result"])

    assert {name for name, _ in hook_tracker.pre_called} == {"counter", "result"}
    assert {name for name, _ in hook_tracker.post_called} == {"counter", "result"}

    first_pre_kwargs = hook_tracker.pre_called[0][1]
    first_post_kwargs = hook_tracker.post_called[0][1]
    # assert sequence id is incremented
    assert first_pre_kwargs["sequence_id"] == 1
    assert first_post_kwargs["sequence_id"] == 1

    required_pre_keys = {
        "sequence_id",
        "state",
        "inputs",
        "app_id",
        "partition_key",
    }
    required_post_keys = {
        "sequence_id",
        "result",
        "state",
        "exception",
        "app_id",
        "partition_key",
    }
    assert required_pre_keys.issubset(first_pre_kwargs.keys())
    assert required_post_keys.issubset(first_post_kwargs.keys())

    assert len(hook_tracker.pre_called) == 11
    assert len(hook_tracker.post_called) == 11
| |
| |
async def test_application_run_step_runs_hooks():
    """A single ``astep`` triggers both the sync and the async tracker exactly once."""
    sync_tracker = ActionTracker()
    async_tracker = ActionTrackerAsync()

    counter_action = base_counter_action.with_name("counter")
    app = Application(
        state=State({}),
        entrypoint="counter",
        adapter_set=internal.LifecycleAdapterSet(sync_tracker, async_tracker),
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=Graph(
            actions=[counter_action],
            transitions=[Transition(counter_action, counter_action, default)],
        ),
    )
    await app.astep()

    assert len(sync_tracker.pre_called) == 1
    assert len(sync_tracker.post_called) == 1

    sync_pre_kwargs = sync_tracker.pre_called[0][1]
    sync_post_kwargs = sync_tracker.post_called[0][1]
    # assert sequence id is incremented
    assert sync_pre_kwargs["sequence_id"] == 1
    assert sync_post_kwargs["sequence_id"] == 1

    # Keys every pre-step hook must receive; the sync hook is also checked for "action".
    shared_pre_keys = {
        "sequence_id",
        "state",
        "inputs",
        "app_id",
        "partition_key",
    }
    post_keys = {
        "sequence_id",
        "result",
        "state",
        "exception",
        "app_id",
        "partition_key",
    }
    assert (shared_pre_keys | {"action"}).issubset(sync_pre_kwargs.keys())
    assert shared_pre_keys.issubset(async_tracker.pre_called[0][1].keys())
    assert post_keys.issubset(sync_post_kwargs.keys())
    assert post_keys.issubset(async_tracker.post_called[0][1].keys())

    assert len(async_tracker.pre_called) == 1
    assert len(async_tracker.post_called) == 1
| |
| |
def test_application_post_application_create_hook():
    """Constructing an ``Application`` invokes the post-create hook exactly once.

    The hook must receive at least ``state`` and ``application_graph`` as keyword
    arguments.
    """

    class CreateHookRecorder(PostApplicationCreateHook):
        """Records how often the hook ran and the kwargs of the latest call."""

        def __init__(self):
            self.call_count = 0
            self.called_args = None

        def post_application_create(self, **kwargs):
            self.call_count += 1
            self.called_args = kwargs

    recorder = CreateHookRecorder()
    counter_action = base_counter_action.with_name("counter")
    result_action = Result("count").with_name("result")
    graph = Graph(
        actions=[counter_action, result_action],
        transitions=[
            Transition(counter_action, result_action, Condition.expr("count >= 10")),
            Transition(counter_action, counter_action, default),
        ],
    )
    # Only construction matters here; the application object itself is not used.
    Application(
        state=State({}),
        entrypoint="counter",
        adapter_set=internal.LifecycleAdapterSet(recorder),
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=graph,
    )
    assert "state" in recorder.called_args
    assert "application_graph" in recorder.called_args
    assert recorder.call_count == 1
| |
| |
async def test_application_gives_graph():
    """``app.graph`` exposes the actions, transitions and entrypoint of the application."""
    counter_action = base_counter_action.with_name("counter")
    result_action = Result("count").with_name("result")
    transitions = [
        Transition(counter_action, result_action, Condition.expr("count >= 10")),
        Transition(counter_action, counter_action, default),
    ]
    app = Application(
        state=State({}),
        entrypoint="counter",
        partition_key="test",
        uid="test-123",
        sequence_id=0,
        graph=Graph(actions=[counter_action, result_action], transitions=transitions),
    )
    exposed_graph = app.graph
    assert len(exposed_graph.actions) == 2
    assert len(exposed_graph.transitions) == 2
    assert exposed_graph.entrypoint.name == "counter"
| |
| |
def test_application_builder_initialize_does_not_allow_state_setting():
    """Using ``with_state`` together with ``initialize_from`` raises ``ValueError``."""
    with pytest.raises(ValueError, match="Cannot call initialize_from"):
        builder = ApplicationBuilder().with_entrypoint("foo").with_state(foo="bar")
        builder.initialize_from(
            DevNullPersister(),
            resume_at_next_action=True,
            default_state={},
            default_entrypoint="foo",
        )
| |
| |
class BrokenPersister(BaseStatePersister):
    """Persister test double whose ``load`` always returns a record with ``state=None``.

    Used to check how the builder reacts to a persister that hands back a record
    without usable state.
    """

    def load(
        self, partition_key: str, app_id: Optional[str], sequence_id: Optional[int] = None, **kwargs
    ) -> Optional[PersistedStateData]:
        """Return a fixed record, ignoring all arguments; its ``state`` is ``None``."""
        return {
            "partition_key": "key",
            "app_id": "id",
            "sequence_id": 0,
            "position": "foo",
            "state": None,
            "created_at": "",
            "status": "completed",
        }

    def list_app_ids(self, partition_key: str, **kwargs) -> list[str]:
        """Report that no applications are stored."""
        return list()

    def save(
        self,
        partition_key: Optional[str],
        app_id: str,
        sequence_id: int,
        position: str,
        state: State,
        status: Literal["completed", "failed"],
        **kwargs,
    ):
        """Discard the data; nothing is persisted."""
        return None
| |
| |
def test_application_builder_initialize_raises_on_broken_persistor():
    """Persisters should return None when there is no state to be loaded and the default used.

    ``BrokenPersister`` instead returns a record whose state is ``None``; initializing
    and building from it must raise ``ValueError``.
    """
    # Setup stays outside ``pytest.raises`` so that only the calls under test can
    # satisfy the expectation; the same with_actions/with_transitions calls succeed
    # in the fork happy-path test.
    counter_action = base_counter_action.with_name("counter")
    result_action = Result("count").with_name("result")
    builder = (
        ApplicationBuilder()
        .with_actions(counter_action, result_action)
        .with_transitions(("counter", "result", default))
    )
    with pytest.raises(ValueError, match="but value for state was None"):
        builder.initialize_from(
            BrokenPersister(),
            resume_at_next_action=True,
            default_state={},
            default_entrypoint="foo",
        ).build()
| |
| |
def test_application_builder_assigns_correct_actions_with_dual_api():
    """``with_actions`` accepts positional and keyword actions in the same call.

    Positional actions keep their own names; the keyword action is registered under
    the keyword (``result``).
    """

    @action(reads=[], writes=[])
    def test_action(state: State) -> State:
        return state

    named_counter = base_counter_action.with_name("counter")
    unnamed_result = Result("count")
    builder = (
        ApplicationBuilder()
        .with_state(count=0)
        .with_actions(named_counter, test_action, result=unnamed_result)
        .with_transitions()
        .with_entrypoint("counter")
    )
    app = builder.build()
    registered_names = {registered.name for registered in app.graph.actions}
    assert registered_names == {"counter", "result", "test_action"}
| |
| |
def test__validate_halt_conditions():
    """Unknown action names in ``halt_after``/``halt_before`` are both reported.

    The regex uses two look-aheads so the error message must mention each missing name.
    """

    @action(reads=[], writes=[])
    def test_action(state: State) -> State:
        return state

    named_counter = base_counter_action.with_name("counter")
    unnamed_result = Result("count")
    builder = (
        ApplicationBuilder()
        .with_state(count=0)
        .with_actions(named_counter, test_action, result=unnamed_result)
        .with_transitions()
        .with_entrypoint("counter")
    )
    app = builder.build()
    both_names_pattern = "(?=.*no_exist_1)(?=.*no_exist_2)"
    with pytest.raises(ValueError, match=both_names_pattern):
        app._validate_halt_conditions(halt_after=["no_exist_1"], halt_before=["no_exist_2"])
| |
| |
def test_application_builder_initialize_raises_on_fork_app_id_not_provided():
    """Can't pass in fork_from* without an app_id."""
    # Setup stays outside ``pytest.raises`` so that only the calls under test can
    # satisfy the expectation; the same with_actions/with_transitions calls succeed
    # in the fork happy-path test.
    counter_action = base_counter_action.with_name("counter")
    result_action = Result("count").with_name("result")
    builder = (
        ApplicationBuilder()
        .with_actions(counter_action, result_action)
        .with_transitions(("counter", "result", default))
    )
    with pytest.raises(ValueError, match="If you set fork_from_partition_key"):
        builder.initialize_from(
            BrokenPersister(),
            resume_at_next_action=True,
            default_state={},
            default_entrypoint="foo",
            fork_from_sequence_id=1,
            fork_from_partition_key="foo-bar",
        ).build()
| |
| |
class DummyPersister(BaseStatePersister):
    """Persister test double that always loads one fixed, valid record.

    The record is for app ``"123"`` / partition ``"user123"`` at position ``counter``
    with ``count == 5`` and sequence ID 5.
    """

    def load(
        self, partition_key: str, app_id: Optional[str], sequence_id: Optional[int] = None, **kwargs
    ) -> Optional[PersistedStateData]:
        """Return the fixed record, ignoring all arguments."""
        stored_state = State({"count": 5})
        record = PersistedStateData(
            partition_key="user123",
            app_id="123",
            sequence_id=5,
            position="counter",
            state=stored_state,
            created_at="",
            status="completed",
        )
        return record

    def list_app_ids(self, partition_key: str, **kwargs) -> list[str]:
        """Report the single known application ID."""
        known_ids = ["123"]
        return known_ids

    def save(
        self,
        partition_key: Optional[str],
        app_id: str,
        sequence_id: int,
        position: str,
        state: State,
        status: Literal["completed", "failed"],
        **kwargs,
    ):
        """Discard the data; nothing is persisted."""
        return None
| |
| |
def test_application_builder_initialize_fork_errors_on_same_app_id():
    """Tests that we can't have an app_id and fork_from_app_id that's the same"""
    # Setup stays outside ``pytest.raises`` so that only the calls under test can
    # satisfy the expectation; the same with_actions/with_transitions calls succeed
    # in the fork happy-path test.
    counter_action = base_counter_action.with_name("counter")
    result_action = Result("count").with_name("result")
    builder = (
        ApplicationBuilder()
        .with_actions(counter_action, result_action)
        .with_transitions(("counter", "result", default))
    )
    with pytest.raises(ValueError, match="Cannot fork and save"):
        (
            builder.initialize_from(
                DummyPersister(),
                resume_at_next_action=True,
                default_state={},
                default_entrypoint="foo",
                fork_from_app_id="123",
                fork_from_partition_key="user123",
            )
            .with_identifiers(app_id="123")  # same ID as the app being forked from
            .build()
        )
| |
| |
def test_application_builder_initialize_fork_app_id_happy_pth():
    """Forking from a persisted app yields a new app that carries over the stored state.

    The new application keeps its own ID, starts from the persisted state and points
    back at the app it was forked from.
    """
    forked_from_id = "123"
    counter_action = base_counter_action.with_name("counter")
    result_action = Result("count").with_name("result")
    builder = (
        ApplicationBuilder()
        .with_actions(counter_action, result_action)
        .with_transitions(("counter", "result", default))
        .initialize_from(
            DummyPersister(),
            resume_at_next_action=True,
            default_state={},
            default_entrypoint="counter",
            fork_from_app_id=forked_from_id,
            fork_from_partition_key="user123",
        )
        .with_identifiers(app_id="test123")
    )
    app = builder.build()
    expected_state = State({"count": 5, "__PRIOR_STEP": "counter", "__SEQUENCE_ID": 5})
    assert app.uid != forked_from_id
    assert app.state == expected_state
    assert app.parent_pointer.app_id == forked_from_id
| |
| |
class NoOpTracker(SyncTrackingClient):
    """Tracking client whose hooks do nothing.

    It only carries a ``unique_id`` so tests can recognise which tracker instance they
    were handed.
    """

    def __init__(self, unique_id: str):
        self.unique_id = unique_id

    def post_application_create(self, **future_kwargs: Any):
        """Ignore the application-created event."""
        return None

    def pre_run_step(self, **future_kwargs: Any):
        """Ignore the pre-step event."""
        return None

    def post_run_step(self, **future_kwargs: Any):
        """Ignore the post-step event."""
        return None

    def pre_start_span(self, **future_kwargs: Any):
        """Ignore the span-start event."""
        return None

    def post_end_span(self, **future_kwargs: Any):
        """Ignore the span-end event."""
        return None
| |
| |
def test_application_exposes_app_context():
    """``app.context`` reflects the identifiers and tracker given to the builder."""
    counter_action = base_counter_action.with_name("counter")
    result_action = Result("count").with_name("result")
    builder = (
        ApplicationBuilder()
        .with_actions(counter_action, result_action)
        .with_transitions(("counter", "result", default))
        .with_tracker(NoOpTracker("unique_tracker_name"))
        .with_identifiers(app_id="test123", partition_key="user123", sequence_id=5)
        .with_entrypoint("counter")
        .with_state(count=0)
    )
    app_context = builder.build().context
    assert app_context.app_id == "test123"
    assert app_context.partition_key == "user123"
    assert app_context.sequence_id == 5
    assert app_context.tracker.unique_id == "unique_tracker_name"
| |
| |
def test_application_exposes_app_context_through_context_manager_sync():
    """``ApplicationContext.get()`` returns the running app's context inside a sync action.

    The assertions live in the action body, so the test passes when ``run`` completes
    without one of them failing.
    """

    @action(reads=["count"], writes=["count"])
    def counter(state: State) -> State:
        current_context = ApplicationContext.get()
        assert current_context is not None
        assert current_context.tracker is not None  # NoOpTracker is used
        assert isinstance(current_context.sequence_id, int)
        assert current_context.app_id == "test123"
        incremented = state["count"] + 1
        return state.update(count=incremented)

    result_action = Result("count").with_name("result")
    builder = (
        ApplicationBuilder()
        .with_actions(result_action, counter=counter)
        .with_transitions(("counter", "result", default))
        .with_tracker(NoOpTracker("unique_tracker_name"))
        .with_identifiers(app_id="test123", partition_key="user123", sequence_id=5)
        .with_entrypoint("counter")
        .with_state(count=0)
    )
    app = builder.build()
    app.run(halt_after=["result"])
| |
| |
async def test_application_exposes_app_context_through_context_manager_async():
    """``ApplicationContext.get()`` returns the running app's context inside an async action.

    The assertions live in the action body, so the test passes when ``arun`` completes
    without one of them failing.
    """

    @action(reads=["count"], writes=["count"])
    async def counter(state: State) -> State:
        current_context = ApplicationContext.get()
        assert current_context is not None
        assert current_context.tracker is not None  # NoOpTracker is used
        assert isinstance(current_context.sequence_id, int)
        assert current_context.app_id == "test123"
        incremented = state["count"] + 1
        return state.update(count=incremented)

    result_action = Result("count").with_name("result")
    builder = (
        ApplicationBuilder()
        .with_actions(result_action, counter=counter)
        .with_transitions(("counter", "result", default))
        .with_tracker(NoOpTracker("unique_tracker_name"))
        .with_identifiers(app_id="test123", partition_key="user123", sequence_id=5)
        .with_entrypoint("counter")
        .with_state(count=0)
    )
    app = builder.build()
    await app.arun(halt_after=["result"])
| |
| |
def test_application_passes_context_when_declared():
    """An action declaring ``__context`` receives the application context on every step.

    The counter runs ten times starting from sequence ID 5, so the contexts it sees
    must carry sequence IDs 6 through 15, all with the same app ID and tracker.
    """
    seen_contexts = []

    @action(reads=["count"], writes=["count"])
    def context_counter(state: State, __context: ApplicationContext) -> State:
        seen_contexts.append(__context)
        return state.update(count=state["count"] + 1)

    builder = (
        ApplicationBuilder()
        .with_actions(counter=context_counter, result=Result("count"))
        .with_transitions(("counter", "counter", expr("count < 10")), ("counter", "result"))
        .with_tracker(NoOpTracker("unique_tracker_name"))
        .with_identifiers(app_id="test123", partition_key="user123", sequence_id=5)
        .with_entrypoint("counter")
        .with_state(count=0)
    )
    app = builder.build()
    app.run(halt_after=["result"])

    assert [ctx.sequence_id for ctx in seen_contexts] == list(range(6, 16))
    assert {ctx.app_id for ctx in seen_contexts} == {"test123"}
    assert {ctx.tracker.unique_id for ctx in seen_contexts} == {"unique_tracker_name"}
| |
| |
def test_optional_context_in_dependency_factories():
    """``__context`` is injected even when the action declares it with a ``None`` default.

    TODO -- get this to test without instantiating an application through the builder --
    this is slightly overkill for a bit of code"""
    context_list = []

    @action(reads=["count"], writes=["count"])
    def context_counter(state: State, __context: ApplicationContext = None) -> State:
        context_list.append(__context)
        return state.update(count=state["count"] + 1)

    builder = (
        ApplicationBuilder()
        .with_actions(counter=context_counter, result=Result("count"))
        .with_transitions(("counter", "counter", expr("count < 10")), ("counter", "result"))
        .with_tracker(NoOpTracker("unique_tracker_name"))
        .with_identifiers(app_id="test123", partition_key="user123", sequence_id=5)
        .with_entrypoint("counter")
        .with_state(count=0)
    )
    app = builder.build()
    # Resolve inputs directly rather than running the action.
    processed = app._process_inputs({}, app.get_next_action())
    assert "__context" in processed  # it should be there
    injected_context = processed["__context"]
    assert injected_context is not None  # it should not be None
    assert injected_context.app_id == "test123"  # it should be the correct context
    # ...including the tracker that was configured on the builder
    assert injected_context.tracker.unique_id == "unique_tracker_name"
| |
| |
def test_application_with_no_spawning_parent():
    """Without ``with_spawning_parent`` the app's spawning parent pointer is ``None``."""
    counter_action = base_counter_action.with_name("counter")
    result_action = Result("count").with_name("result")
    builder = (
        ApplicationBuilder()
        .with_actions(counter_action, result_action)
        .with_transitions(("counter", "result", default))
        .with_identifiers(app_id="test123", partition_key="user123", sequence_id=5)
        .with_entrypoint("counter")
        .with_state(count=0)
    )
    app = builder.build()
    assert app.spawning_parent_pointer is None
| |
| |
def test_application_with_spawning_parent():
    """A spawning parent given to the builder is exposed on the built application."""
    counter_action = base_counter_action.with_name("counter")
    result_action = Result("count").with_name("result")
    builder = (
        ApplicationBuilder()
        .with_actions(counter_action, result_action)
        .with_transitions(("counter", "result", default))
        .with_identifiers(app_id="test123", partition_key="user123")
        .with_spawning_parent(app_id="test123", partition_key="user123", sequence_id=5)
        .with_entrypoint("counter")
        .with_state(count=0)
    )
    parent_pointer = builder.build().spawning_parent_pointer
    assert parent_pointer is not None
    assert parent_pointer.app_id == "test123"
    assert parent_pointer.partition_key == "user123"
    assert parent_pointer.sequence_id == 5
| |
| |
def test_application_does_not_allow_dunderscore_inputs():
    """Inputs whose names begin with a double underscore are rejected with ``ValueError``.

    TODO -- get this to test without instantiating an application through the builder --
    this is slightly overkill for a bit of code"""
    result_action = Result("count").with_name("result")
    builder = (
        ApplicationBuilder()
        .with_actions(result_action)
        .with_entrypoint("result")
        .with_transitions()
        .with_state(count=0)
    )
    app = builder.build()
    forbidden_inputs = {"__not_allowed": ...}
    with pytest.raises(ValueError, match="double underscore"):
        app._process_inputs(forbidden_inputs, app.get_next_action())