| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import asyncio |
| import collections |
| import datetime |
| import logging |
| import typing |
| import uuid |
| from typing import Any, Awaitable, Callable, Dict, Generator, Literal, Optional, Tuple, Union |
| |
| import pytest |
| |
| from burr.core import State |
| from burr.core.action import ( |
| DEFAULT_SCHEMA, |
| Action, |
| AsyncGenerator, |
| AsyncStreamingAction, |
| Condition, |
| Reducer, |
| Result, |
| SingleStepAction, |
| SingleStepStreamingAction, |
| StreamingAction, |
| action, |
| default, |
| expr, |
| ) |
| from burr.core.application import ( |
| PRIOR_STEP, |
| Application, |
| ApplicationBuilder, |
| ApplicationContext, |
| _adjust_single_step_output, |
| _arun_function, |
| _arun_multi_step_streaming_action, |
| _arun_single_step_action, |
| _arun_single_step_streaming_action, |
| _remap_dunder_parameters, |
| _run_function, |
| _run_multi_step_streaming_action, |
| _run_reducer, |
| _run_single_step_action, |
| _run_single_step_streaming_action, |
| _validate_reducer_writes, |
| _validate_start, |
| ) |
| from burr.core.graph import Graph, GraphBuilder, Transition |
| from burr.core.persistence import ( |
| AsyncDevNullPersister, |
| BaseStateLoader, |
| BaseStatePersister, |
| DevNullPersister, |
| PersistedStateData, |
| SQLLitePersister, |
| ) |
| from burr.core.typing import TypingSystem |
| from burr.lifecycle import ( |
| PostRunStepHook, |
| PostRunStepHookAsync, |
| PreRunStepHook, |
| PreRunStepHookAsync, |
| internal, |
| ) |
| from burr.lifecycle.base import ( |
| ExecuteMethod, |
| PostApplicationCreateHook, |
| PostApplicationExecuteCallHook, |
| PostApplicationExecuteCallHookAsync, |
| PostEndStreamHook, |
| PostStreamItemHook, |
| PostStreamItemHookAsync, |
| PreApplicationExecuteCallHook, |
| PreApplicationExecuteCallHookAsync, |
| PreStartStreamHook, |
| PreStartStreamHookAsync, |
| ) |
| from burr.lifecycle.internal import LifecycleAdapterSet |
| from burr.tracking.base import SyncTrackingClient |
| |
| |
| class PassedInAction(Action): |
| def __init__( |
| self, |
| reads: list[str], |
| writes: list[str], |
| fn: Callable[..., dict], |
| update_fn: Callable[[dict, State], State], |
| inputs: list[str], |
| ): |
| super(PassedInAction, self).__init__() |
| self._reads = reads |
| self._writes = writes |
| self._fn = fn |
| self._update_fn = update_fn |
| self._inputs = inputs |
| |
| def run(self, state: State, **run_kwargs) -> dict: |
| return self._fn(state, **run_kwargs) |
| |
| @property |
| def inputs(self) -> list[str]: |
| return self._inputs |
| |
| def update(self, result: dict, state: State) -> State: |
| return self._update_fn(result, state) |
| |
| @property |
| def reads(self) -> list[str]: |
| return self._reads |
| |
| @property |
| def writes(self) -> list[str]: |
| return self._writes |
| |
| |
| class PassedInActionAsync(PassedInAction): |
| def __init__( |
| self, |
| reads: list[str], |
| writes: list[str], |
| fn: Callable[..., Awaitable[dict]], |
| update_fn: Callable[[dict, State], State], |
| inputs: list[str], |
| ): |
| super().__init__(reads=reads, writes=writes, fn=fn, update_fn=update_fn, inputs=inputs) # type: ignore |
| |
| async def run(self, state: State, **run_kwargs) -> dict: |
| return await self._fn(state, **run_kwargs) |
| |
| |
| base_counter_action = PassedInAction( |
| reads=["count"], |
| writes=["count"], |
| fn=lambda state: {"count": state.get("count", 0) + 1}, |
| update_fn=lambda result, state: state.update(**result), |
| inputs=[], |
| ) |
| |
| base_counter_action_with_inputs = PassedInAction( |
| reads=["count"], |
| writes=["count"], |
| fn=lambda state, additional_increment: { |
| "count": state.get("count", 0) + 1 + additional_increment |
| }, |
| update_fn=lambda result, state: state.update(**result), |
| inputs=["additional_increment"], |
| ) |
| |
| |
| class StreamEventCaptureTracker(PreStartStreamHook, PostStreamItemHook, PostEndStreamHook): |
| def post_end_stream( |
| self, |
| *, |
| action: str, |
| sequence_id: int, |
| app_id: str, |
| partition_key: Optional[str], |
| **future_kwargs: Any, |
| ): |
| self.post_end_stream_calls.append((action, locals())) |
| |
| def __init__(self): |
| self.pre_start_stream_calls = [] |
| self.post_stream_item_calls = [] |
| self.post_end_stream_calls = [] |
| |
| def pre_start_stream( |
| self, |
| *, |
| action: str, |
| sequence_id: int, |
| app_id: str, |
| partition_key: Optional[str], |
| **future_kwargs: Any, |
| ): |
| self.pre_start_stream_calls.append((action, locals())) |
| |
| def post_stream_item( |
| self, |
| *, |
| item: Any, |
| item_index: int, |
| stream_initialize_time: datetime.datetime, |
| first_stream_item_start_time: datetime.datetime, |
| action: str, |
| sequence_id: int, |
| app_id: str, |
| partition_key: Optional[str], |
| **future_kwargs: Any, |
| ): |
| self.post_stream_item_calls.append((action, locals())) |
| |
| |
| class StreamEventCaptureTrackerAsync( |
| PreStartStreamHookAsync, PostStreamItemHookAsync, PostEndStreamHook |
| ): |
| def __init__(self): |
| self.pre_start_stream_calls = [] |
| self.post_stream_item_calls = [] |
| self.post_end_stream_calls = [] |
| |
| async def pre_start_stream( |
| self, |
| *, |
| action: str, |
| sequence_id: int, |
| app_id: str, |
| partition_key: Optional[str], |
| **future_kwargs: Any, |
| ): |
| self.pre_start_stream_calls.append((action, locals())) |
| |
| async def post_stream_item( |
| self, |
| *, |
| item: Any, |
| item_index: int, |
| stream_initialize_time: datetime.datetime, |
| first_stream_item_start_time: datetime.datetime, |
| action: str, |
| sequence_id: int, |
| app_id: str, |
| partition_key: Optional[str], |
| **future_kwargs: Any, |
| ): |
| self.post_stream_item_calls.append((action, locals())) |
| |
| def post_end_stream( |
| self, |
| *, |
| action: str, |
| sequence_id: int, |
| app_id: str, |
| partition_key: Optional[str], |
| **future_kwargs: Any, |
| ): |
| self.post_end_stream_calls.append((action, locals())) |
| |
| |
| class CallCaptureTracker( |
| PreRunStepHook, |
| PostRunStepHook, |
| PreApplicationExecuteCallHook, |
| PostApplicationExecuteCallHook, |
| ): |
| def __init__(self): |
| self.pre_called = [] |
| self.post_called = [] |
| self.pre_run_execute_calls = [] |
| self.post_run_execute_calls = [] |
| |
| def pre_run_step( |
| self, |
| *, |
| app_id: str, |
| partition_key: str, |
| sequence_id: int, |
| state: "State", |
| action: "Action", |
| inputs: Dict[str, Any], |
| **future_kwargs: Any, |
| ): |
| self.pre_called.append((action.name, locals())) |
| |
| def post_run_step( |
| self, |
| *, |
| app_id: str, |
| partition_key: str, |
| sequence_id: int, |
| state: "State", |
| action: "Action", |
| result: Optional[Dict[str, Any]], |
| exception: Exception, |
| **future_kwargs: Any, |
| ): |
| self.post_called.append((action.name, locals())) |
| |
| def pre_run_execute_call( |
| self, |
| *, |
| app_id: str, |
| partition_key: str, |
| state: "State", |
| method: ExecuteMethod, |
| **future_kwargs: Any, |
| ): |
| self.pre_run_execute_calls.append(("pre", locals())) |
| |
| def post_run_execute_call( |
| self, |
| *, |
| app_id: str, |
| partition_key: str, |
| state: "State", |
| method: ExecuteMethod, |
| exception: Optional[Exception], |
| **future_kwargs, |
| ): |
| self.post_run_execute_calls.append(("post", locals())) |
| |
| |
| class ExecuteMethodTrackerAsync( |
| PreApplicationExecuteCallHookAsync, PostApplicationExecuteCallHookAsync |
| ): |
| def __init__(self): |
| self.pre_run_execute_calls = [] |
| self.post_run_execute_calls = [] |
| |
| async def pre_run_execute_call( |
| self, |
| *, |
| app_id: str, |
| partition_key: str, |
| state: "State", |
| method: ExecuteMethod, |
| **future_kwargs: Any, |
| ): |
| self.pre_run_execute_calls.append(("pre", locals())) |
| |
| async def post_run_execute_call( |
| self, |
| *, |
| app_id: str, |
| partition_key: str, |
| state: "State", |
| method: ExecuteMethod, |
| exception: Optional[Exception], |
| **future_kwargs, |
| ): |
| self.post_run_execute_calls.append(("post", locals())) |
| |
| |
| class ActionTrackerAsync(PreRunStepHookAsync, PostRunStepHookAsync): |
| def __init__(self): |
| self.pre_called = [] |
| self.post_called = [] |
| |
| async def pre_run_step(self, *, action: Action, **future_kwargs): |
| await asyncio.sleep(0.0001) |
| self.pre_called.append((action.name, future_kwargs)) |
| |
| async def post_run_step(self, *, action: Action, **future_kwargs): |
| await asyncio.sleep(0.0001) |
| self.post_called.append((action.name, future_kwargs)) |
| |
| |
| async def _counter_update_async(state: State, additional_increment: int = 0) -> dict: |
| await asyncio.sleep(0.0001) # just so we can make this *truly* async |
| # does not matter, but more accurately simulates an async function |
| return {"count": state.get("count", 0) + 1 + additional_increment} |
| |
| |
| base_counter_action_async = PassedInActionAsync( |
| reads=["count"], |
| writes=["count"], |
| fn=_counter_update_async, |
| update_fn=lambda result, state: state.update(**result), |
| inputs=[], |
| ) |
| |
| base_counter_action_with_inputs_async = PassedInActionAsync( |
| reads=["count"], |
| writes=["count"], |
| fn=lambda state, additional_increment: _counter_update_async( |
| state, additional_increment=additional_increment |
| ), |
| update_fn=lambda result, state: state.update(**result), |
| inputs=["additional_increment"], |
| ) |
| |
| |
| class BrokenStepException(Exception): |
| pass |
| |
| |
| base_broken_action = PassedInAction( |
| reads=[], |
| writes=[], |
| fn=lambda x: exec("raise(BrokenStepException(x))"), |
| update_fn=lambda result, state: state, |
| inputs=[], |
| ) |
| |
| base_broken_action_async = PassedInActionAsync( |
| reads=[], |
| writes=[], |
| fn=lambda x: exec("raise(BrokenStepException(x))"), |
| update_fn=lambda result, state: state, |
| inputs=[], |
| ) |
| |
| |
| async def incorrect(x): |
| return "not a dict" |
| |
| |
| base_action_incorrect_result_type = PassedInAction( |
| reads=[], |
| writes=[], |
| fn=lambda x: "not a dict", |
| update_fn=lambda result, state: state, |
| inputs=[], |
| ) |
| |
| base_action_incorrect_result_type_async = PassedInActionAsync( |
| reads=[], |
| writes=[], |
| fn=incorrect, |
| update_fn=lambda result, state: state, |
| inputs=[], |
| ) |
| |
| |
| def test__run_function(): |
| """Tests that we can run a function""" |
| action = base_counter_action |
| state = State({}) |
| result = _run_function(action, state, inputs={}, name=action.name) |
| assert result == {"count": 1} |
| |
| |
| def test__run_function_with_inputs(): |
| """Tests that we can run a function""" |
| action = base_counter_action_with_inputs |
| state = State({}) |
| result = _run_function(action, state, inputs={"additional_increment": 1}, name=action.name) |
| assert result == {"count": 2} |
| |
| |
| def test__run_function_cant_run_async(): |
| """Tests that we can't run an async function""" |
| action = base_counter_action_async |
| state = State({}) |
| with pytest.raises(ValueError, match="async"): |
| _run_function(action, state, inputs={}, name=action.name) |
| |
| |
| def test__run_function_incorrect_result_type(): |
| """Tests that we can run an async function""" |
| action = base_action_incorrect_result_type |
| state = State({}) |
| with pytest.raises(ValueError, match="returned a non-dict"): |
| _run_function(action, state, inputs={}, name=action.name) |
| |
| |
| def test__run_reducer_modifies_state(): |
| """Tests that we can run a reducer and it behaves as expected""" |
| reducer = PassedInAction( |
| reads=["count"], |
| writes=["count"], |
| fn=..., |
| update_fn=lambda result, state: state.update(**result), |
| inputs=[], |
| ) |
| state = State({"count": 0}) |
| state = _run_reducer(reducer, state, {"count": 1}, "reducer") |
| assert state["count"] == 1 |
| |
| |
| def test__run_reducer_deletes_state(): |
| """Tests that we can run a reducer that deletes an item from state""" |
| reducer = PassedInAction( |
| reads=["count"], |
| writes=[], # TODO -- figure out how we can better know that it deletes items...ß |
| fn=..., |
| update_fn=lambda result, state: state.wipe(delete=["count"]), |
| inputs=[], |
| ) |
| state = State({"count": 0}) |
| state = _run_reducer(reducer, state, {}, "deletion_reducer") |
| assert "count" not in state |
| |
| |
| def test__validate_reducer_writes_with_state_keys_returning_list(): |
| """Tests that _validate_reducer_writes works when state.keys() returns a list. |
| |
| This is a regression test for a bug where state.keys() could return a list |
| instead of a set, causing a TypeError when trying to do set subtraction. |
| """ |
| # Create a reducer with some expected writes |
| reducer = PassedInAction( |
| reads=["input"], |
| writes=["output", "result"], |
| fn=..., |
| update_fn=lambda result, state: state.update(**result), |
| inputs=[], |
| ) |
| |
| # Create a state that has all the required writes |
| state = State({"input": 1, "output": 2, "result": 3}) |
| |
| # This should not raise a TypeError even if state.keys() returns a list |
| # (which was the original bug) |
| _validate_reducer_writes(reducer, state, "test_action") |
| |
| |
| def test__validate_reducer_writes_raises_on_missing_keys(): |
| """Tests that _validate_reducer_writes raises ValueError when required keys are missing.""" |
| # Create a reducer with some expected writes |
| reducer = PassedInAction( |
| reads=["input"], |
| writes=["output", "result", "missing_key"], |
| fn=..., |
| update_fn=lambda result, state: state.update(**result), |
| inputs=[], |
| ) |
| |
| # Create a state that is missing some required writes |
| state = State({"input": 1, "output": 2, "result": 3}) |
| |
| # This should raise a ValueError for missing "missing_key" |
| with pytest.raises(ValueError, match="missing_key"): |
| _validate_reducer_writes(reducer, state, "test_action") |
| |
| |
| async def test__arun_function(): |
| """Tests that we can run an async function""" |
| action = base_counter_action_async |
| state = State({}) |
| result = await _arun_function(action, state, inputs={}, name=action.name) |
| assert result == {"count": 1} |
| |
| |
| async def test__arun_function_incorrect_result_type(): |
| """Tests that we can run an async function""" |
| action = base_action_incorrect_result_type_async |
| state = State({}) |
| with pytest.raises(ValueError, match="returned a non-dict"): |
| await _arun_function(action, state, inputs={}, name=action.name) |
| |
| |
| async def test__arun_function_with_inputs(): |
| """Tests that we can run an async function""" |
| action = base_counter_action_with_inputs_async |
| state = State({}) |
| result = await _arun_function( |
| action, state, inputs={"additional_increment": 1}, name=action.name |
| ) |
| assert result == {"count": 2} |
| |
| |
| def test_run_reducer_errors_missing_writes(): |
| class BrokenReducer(Reducer): |
| def update(self, result: dict, state: State) -> State: |
| return state.update(present_value=1) |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["missing_value", "present_value"] |
| |
| reducer = BrokenReducer() |
| state = State() |
| with pytest.raises(ValueError, match="missing_value"): |
| _run_reducer(reducer, state, {}, "broken_reducer") |
| |
| |
| def test_run_single_step_action_errors_missing_writes(): |
| class BrokenAction(SingleStepAction): |
| @property |
| def reads(self) -> list[str]: |
| return [] |
| |
| def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: |
| return {"present_value": 1}, state.update(present_value=1) |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["missing_value", "present_value"] |
| |
| action = BrokenAction() |
| state = State() |
| with pytest.raises(ValueError, match="missing_value"): |
| _run_single_step_action(action, state, inputs={}) |
| |
| |
| async def test_arun_single_step_action_errors_missing_writes(): |
| class BrokenAction(SingleStepAction): |
| @property |
| def reads(self) -> list[str]: |
| return [] |
| |
| async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: |
| await asyncio.sleep(0.0001) # just so we can make this *truly* async |
| return {"present_value": 1}, state.update(present_value=1) |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["missing_value", "present_value"] |
| |
| action = BrokenAction() |
| state = State() |
| with pytest.raises(ValueError, match="missing_value"): |
| await _arun_single_step_action(action, state, inputs={}) |
| |
| |
| def test_run_single_step_streaming_action_errors_missing_write(): |
| class BrokenAction(SingleStepStreamingAction): |
| def stream_run_and_update( |
| self, state: State, **run_kwargs |
| ) -> Generator[Tuple[dict, Optional[State]], None, None]: |
| yield {}, None |
| yield {"present_value": 1}, state.update(present_value=1) |
| |
| @property |
| def reads(self) -> list[str]: |
| return [] |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["missing_value", "present_value"] |
| |
| action = BrokenAction() |
| state = State() |
| with pytest.raises(ValueError, match="missing_value"): |
| gen = _run_single_step_streaming_action( |
| action, |
| state, |
| inputs={}, |
| sequence_id=0, |
| partition_key="partition_key", |
| app_id="app_id", |
| ) |
| collections.deque(gen, maxlen=0) # exhaust the generator |
| |
| |
| async def test_run_single_step_streaming_action_errors_missing_write_async(): |
| class BrokenAction(SingleStepStreamingAction): |
| async def stream_run_and_update( |
| self, state: State, **run_kwargs |
| ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: |
| yield {}, None |
| yield {"present_value": 1}, state.update(present_value=1) |
| |
| @property |
| def reads(self) -> list[str]: |
| return [] |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["missing_value", "present_value"] |
| |
| action = BrokenAction() |
| state = State() |
| with pytest.raises(ValueError, match="missing_value"): |
| gen = _arun_single_step_streaming_action( |
| action, |
| state, |
| inputs={}, |
| sequence_id=0, |
| app_id="app_id", |
| partition_key="partition_key", |
| lifecycle_adapters=LifecycleAdapterSet(), |
| ) |
| [result async for result in gen] # exhaust the generator |
| |
| |
| def test_run_multi_step_streaming_action_errors_missing_write(): |
| class BrokenAction(StreamingAction): |
| def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, None]: |
| yield {} |
| yield {"present_value": 1} |
| |
| def update(self, result: dict, state: State) -> State: |
| return state.update(present_value=1) |
| |
| @property |
| def reads(self) -> list[str]: |
| return [] |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["missing_value", "present_value"] |
| |
| action = BrokenAction() |
| state = State() |
| with pytest.raises(ValueError, match="missing_value"): |
| gen = _run_multi_step_streaming_action( |
| action, |
| state, |
| inputs={}, |
| sequence_id=0, |
| partition_key="partition_key", |
| app_id="app_id", |
| ) |
| collections.deque(gen, maxlen=0) # exhaust the generator |
| |
| |
| class SingleStepCounter(SingleStepAction): |
| def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: |
| result = {"count": state["count"] + 1 + sum([0] + list(run_kwargs.values()))} |
| return result, state.update(**result).append(tracker=result["count"]) |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["count"] |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["count", "tracker"] |
| |
| |
| class SingleStepCounterWithInputs(SingleStepCounter): |
| @property |
| def inputs(self) -> list[str]: |
| return ["additional_increment"] |
| |
| |
| class SingleStepActionIncorrectResultType(SingleStepAction): |
| def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: |
| return "not a dict", state |
| |
| @property |
| def reads(self) -> list[str]: |
| return [] |
| |
| @property |
| def writes(self) -> list[str]: |
| return [] |
| |
| |
| class SingleStepActionIncorrectResultTypeAsync(SingleStepActionIncorrectResultType): |
| async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: |
| return "not a dict", state |
| |
| |
| class SingleStepCounterAsync(SingleStepCounter): |
| async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: |
| await asyncio.sleep(0.0001) # just so we can make this *truly* async |
| return super(SingleStepCounterAsync, self).run_and_update(state, **run_kwargs) |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["count"] |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["count", "tracker"] |
| |
| |
| class SingleStepCounterWithInputsAsync(SingleStepCounterAsync): |
| @property |
| def inputs(self) -> list[str]: |
| return ["additional_increment"] |
| |
| |
| class StreamingCounter(StreamingAction): |
| def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, None]: |
| if "steps_per_count" in run_kwargs: |
| steps_per_count = run_kwargs["granularity"] |
| else: |
| steps_per_count = 10 |
| count = state["count"] |
| for i in range(steps_per_count): |
| yield {"count": count + ((i + 1) / 10)} |
| yield {"count": count + 1} |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["count"] |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["count", "tracker"] |
| |
| def update(self, result: dict, state: State) -> State: |
| return state.update(**result).append(tracker=result["count"]) |
| |
| |
| class AsyncStreamingCounter(AsyncStreamingAction): |
| async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]: |
| if "steps_per_count" in run_kwargs: |
| steps_per_count = run_kwargs["granularity"] |
| else: |
| steps_per_count = 10 |
| count = state["count"] |
| for i in range(steps_per_count): |
| await asyncio.sleep(0.01) |
| yield {"count": count + (i + 1) / 10} |
| await asyncio.sleep(0.01) |
| yield {"count": count + 1} |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["count"] |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["count", "tracker"] |
| |
| def update(self, result: dict, state: State) -> State: |
| return state.update(**result).append(tracker=result["count"]) |
| |
| |
| class SingleStepStreamingCounter(SingleStepStreamingAction): |
| def stream_run_and_update( |
| self, state: State, **run_kwargs |
| ) -> Generator[Tuple[dict, Optional[State]], None, None]: |
| steps_per_count = run_kwargs.get("granularity", 10) |
| count = state["count"] |
| for i in range(steps_per_count): |
| yield {"count": count + ((i + 1) / 10)}, None |
| yield {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1) |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["count"] |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["count", "tracker"] |
| |
| |
| class SingleStepStreamingCounterAsync(SingleStepStreamingAction): |
| async def stream_run_and_update( |
| self, state: State, **run_kwargs |
| ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: |
| steps_per_count = run_kwargs.get("granularity", 10) |
| count = state["count"] |
| for i in range(steps_per_count): |
| await asyncio.sleep(0.01) |
| yield {"count": count + ((i + 1) / 10)}, None |
| await asyncio.sleep(0.01) |
| yield {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1) |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["count"] |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["count", "tracker"] |
| |
| |
| class StreamingActionIncorrectResultType(StreamingAction): |
| def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, dict]: |
| yield {} |
| yield "not a dict" |
| |
| @property |
| def reads(self) -> list[str]: |
| return [] |
| |
| @property |
| def writes(self) -> list[str]: |
| return [] |
| |
| def update(self, result: dict, state: State) -> State: |
| return state |
| |
| |
| class StreamingActionIncorrectResultTypeAsync(AsyncStreamingAction): |
| async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]: |
| yield {} |
| yield "not a dict" |
| |
| @property |
| def reads(self) -> list[str]: |
| return [] |
| |
| @property |
| def writes(self) -> list[str]: |
| return [] |
| |
| def update(self, result: dict, state: State) -> State: |
| return state |
| |
| |
| class StreamingSingleStepActionIncorrectResultType(SingleStepStreamingAction): |
| def stream_run_and_update( |
| self, state: State, **run_kwargs |
| ) -> Generator[Tuple[dict, Optional[State]], None, None]: |
| yield {}, State |
| yield "not a dict", state |
| |
| @property |
| def reads(self) -> list[str]: |
| return [] |
| |
| @property |
| def writes(self) -> list[str]: |
| return [] |
| |
| |
| class StreamingSingleStepActionIncorrectResultTypeAsync(SingleStepStreamingAction): |
| async def stream_run_and_update( |
| self, state: State, **run_kwargs |
| ) -> typing.AsyncGenerator[Tuple[dict, Optional[State]], None]: |
| yield {}, None |
| yield "not a dict", state |
| |
| @property |
| def reads(self) -> list[str]: |
| return [] |
| |
| @property |
| def writes(self) -> list[str]: |
| return [] |
| |
| |
| base_single_step_counter = SingleStepCounter() |
| base_single_step_counter_async = SingleStepCounterAsync() |
| base_single_step_counter_with_inputs = SingleStepCounterWithInputs() |
| base_single_step_counter_with_inputs_async = SingleStepCounterWithInputsAsync() |
| |
| base_streaming_counter = StreamingCounter() |
| base_streaming_single_step_counter = SingleStepStreamingCounter() |
| |
| base_streaming_counter_async = AsyncStreamingCounter() |
| base_streaming_single_step_counter_async = SingleStepStreamingCounterAsync() |
| |
| base_single_step_action_incorrect_result_type = SingleStepActionIncorrectResultType() |
| base_single_step_action_incorrect_result_type_async = SingleStepActionIncorrectResultTypeAsync() |
| |
| |
| def test__run_single_step_action(): |
| action = base_single_step_counter.with_name("counter") |
| state = State({"count": 0, "tracker": []}) |
| result, state = _run_single_step_action(action, state, inputs={}) |
| assert result == {"count": 1} |
| assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} |
| result, state = _run_single_step_action(action, state, inputs={}) |
| assert result == {"count": 2} |
| assert state.subset("count", "tracker").get_all() == {"count": 2, "tracker": [1, 2]} |
| |
| |
| def test__run_single_step_action_incorrect_result_type(): |
| action = base_single_step_action_incorrect_result_type.with_name("counter") |
| state = State({"count": 0, "tracker": []}) |
| with pytest.raises(ValueError, match="returned a non-dict"): |
| _run_single_step_action(action, state, inputs={}) |
| |
| |
| async def test__arun_single_step_action_incorrect_result_type(): |
| action = base_single_step_action_incorrect_result_type_async.with_name("counter") |
| state = State({"count": 0, "tracker": []}) |
| with pytest.raises(ValueError, match="returned a non-dict"): |
| await _arun_single_step_action(action, state, inputs={}) |
| |
| |
| def test__run_single_step_action_with_inputs(): |
| action = base_single_step_counter_with_inputs.with_name("counter") |
| state = State({"count": 0, "tracker": []}) |
| result, state = _run_single_step_action(action, state, inputs={"additional_increment": 1}) |
| assert result == {"count": 2} |
| assert state.subset("count", "tracker").get_all() == {"count": 2, "tracker": [2]} |
| result, state = _run_single_step_action(action, state, inputs={"additional_increment": 1}) |
| assert result == {"count": 4} |
| assert state.subset("count", "tracker").get_all() == {"count": 4, "tracker": [2, 4]} |
| |
| |
| async def test__arun_single_step_action(): |
| action = base_single_step_counter_async.with_name("counter") |
| state = State({"count": 0, "tracker": []}) |
| result, state = await _arun_single_step_action(action, state, inputs={}) |
| assert result == {"count": 1} |
| assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} |
| result, state = await _arun_single_step_action(action, state, inputs={}) |
| assert result == {"count": 2} |
| assert state.subset("count", "tracker").get_all() == {"count": 2, "tracker": [1, 2]} |
| |
| |
| async def test__arun_single_step_action_with_inputs(): |
| action = base_single_step_counter_with_inputs_async.with_name("counter") |
| state = State({"count": 0, "tracker": []}) |
| result, state = await _arun_single_step_action( |
| action, state, inputs={"additional_increment": 1} |
| ) |
| assert result == {"count": 2} |
| assert state.subset("count", "tracker").get_all() == {"count": 2, "tracker": [2]} |
| result, state = await _arun_single_step_action( |
| action, state, inputs={"additional_increment": 1} |
| ) |
| assert result == {"count": 4} |
| assert state.subset("count", "tracker").get_all() == {"count": 4, "tracker": [2, 4]} |
| |
| |
| class SingleStepActionWithDeletion(SingleStepAction): |
| def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: |
| return {}, state.wipe(delete=["to_delete"]) |
| |
| @property |
| def reads(self) -> list[str]: |
| return [] |
| |
| @property |
| def writes(self) -> list[str]: |
| return [] |
| |
| |
| def test__run_single_step_action_deletes_state(): |
| action = SingleStepActionWithDeletion() |
| state = State({"to_delete": 0}) |
| result, state = _run_single_step_action(action, state, inputs={}) |
| assert "to_delete" not in state |
| |
| |
| def test__run_multistep_streaming_action(): |
| action = base_streaming_counter.with_name("counter") |
| state = State({"count": 0, "tracker": []}) |
| generator = _run_multi_step_streaming_action( |
| action, |
| state, |
| inputs={}, |
| sequence_id=0, |
| partition_key="partition_key", |
| app_id="app_id", |
| ) |
| last_result = -1 |
| result = None |
| for result, state in generator: |
| if last_result < 1: |
| # Otherwise you hit floating poit comparison problems |
| assert result["count"] > last_result |
| last_result = result["count"] |
| assert result == {"count": 1} |
| assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} |
| |
| |
| def test__run_multistep_streaming_action_callbacks(): |
| class TrackingCallback(PostStreamItemHook): |
| def __init__(self): |
| self.items = [] |
| |
| def post_stream_item(self, item: Any, **future_kwargs: Any): |
| self.items.append(item) |
| |
| hook = TrackingCallback() |
| |
| action = base_streaming_counter.with_name("counter") |
| state = State({"count": 0, "tracker": []}) |
| generator = _run_multi_step_streaming_action( |
| action, |
| state, |
| inputs={}, |
| sequence_id=0, |
| partition_key="partition_key", |
| app_id="app_id", |
| lifecycle_adapters=LifecycleAdapterSet(hook), |
| ) |
| last_result = -1 |
| result = None |
| for result, state in generator: |
| if last_result < 1: |
| # Otherwise you hit floating poit comparison problems |
| assert result["count"] > last_result |
| last_result = result["count"] |
| assert result == {"count": 1} |
| assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} |
| assert len(hook.items) == 10 # one for each streaming callback |
| |
| |
| async def test__run_multistep_streaming_action_async(): |
| action = base_streaming_counter_async.with_name("counter") |
| state = State({"count": 0, "tracker": []}) |
| generator = _arun_multi_step_streaming_action( |
| action=action, |
| state=state, |
| inputs={}, |
| sequence_id=0, |
| app_id="app_id", |
| partition_key="partition_key", |
| lifecycle_adapters=LifecycleAdapterSet(), |
| ) |
| last_result = -1 |
| result = None |
| async for result, state in generator: |
| if last_result < 1: |
| # Otherwise you hit floating poit comparison problems |
| assert result["count"] > last_result |
| last_result = result["count"] |
| assert result == {"count": 1} |
| assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} |
| |
| |
| async def test__run_multistep_streaming_action_async_callbacks(): |
| class TrackingCallback(PostStreamItemHookAsync): |
| def __init__(self): |
| self.items = [] |
| |
| async def post_stream_item(self, item: Any, **future_kwargs: Any): |
| self.items.append(item) |
| |
| hook = TrackingCallback() |
| action = base_streaming_counter_async.with_name("counter") |
| state = State({"count": 0, "tracker": []}) |
| generator = _arun_multi_step_streaming_action( |
| action=action, |
| state=state, |
| inputs={}, |
| sequence_id=0, |
| app_id="app_id", |
| partition_key="partition_key", |
| lifecycle_adapters=LifecycleAdapterSet(hook), |
| ) |
| last_result = -1 |
| result = None |
| async for result, state in generator: |
| if last_result < 1: |
| # Otherwise you hit floating poit comparison problems |
| assert result["count"] > last_result |
| last_result = result["count"] |
| assert result == {"count": 1} |
| assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} |
| assert len(hook.items) == 10 # one for each streaming callback |
| |
| |
| def test__run_streaming_action_incorrect_result_type(): |
| action = StreamingActionIncorrectResultType() |
| state = State() |
| with pytest.raises(ValueError, match="returned a non-dict"): |
| gen = _run_multi_step_streaming_action( |
| action, |
| state, |
| inputs={}, |
| sequence_id=0, |
| partition_key="partition_key", |
| app_id="app_id", |
| ) |
| collections.deque(gen, maxlen=0) # exhaust the generator |
| |
| |
| async def test__run_streaming_action_incorrect_result_type_async(): |
| action = StreamingActionIncorrectResultTypeAsync() |
| state = State() |
| with pytest.raises(ValueError, match="returned a non-dict"): |
| gen = _arun_multi_step_streaming_action( |
| action=action, |
| state=state, |
| inputs={}, |
| sequence_id=0, |
| app_id="app_id", |
| partition_key="partition_key", |
| lifecycle_adapters=LifecycleAdapterSet(), |
| ) |
| async for _ in gen: |
| pass |
| |
| |
| def test__run_single_step_streaming_action_incorrect_result_type(): |
| action = StreamingSingleStepActionIncorrectResultType() |
| state = State() |
| with pytest.raises(ValueError, match="returned a non-dict"): |
| gen = _run_single_step_streaming_action( |
| action=action, |
| state=state, |
| inputs={}, |
| sequence_id=0, |
| partition_key="partition_key", |
| app_id="app_id", |
| ) |
| collections.deque(gen, maxlen=0) # exhaust the generator |
| |
| |
| async def test__run_single_step_streaming_action_incorrect_result_type_async(): |
| action = StreamingSingleStepActionIncorrectResultTypeAsync() |
| state = State() |
| with pytest.raises(ValueError, match="returned a non-dict"): |
| gen = _arun_single_step_streaming_action( |
| action, |
| state, |
| inputs={}, |
| sequence_id=0, |
| partition_key="partition_key", |
| app_id="app_id", |
| lifecycle_adapters=LifecycleAdapterSet(), |
| ) |
| _ = [item async for item in gen] |
| |
| |
| def test__run_single_step_streaming_action(): |
| action = base_streaming_single_step_counter.with_name("counter") |
| state = State({"count": 0, "tracker": []}) |
| generator = _run_single_step_streaming_action( |
| action, |
| state, |
| inputs={}, |
| sequence_id=0, |
| partition_key="partition_key", |
| app_id="app_id", |
| ) |
| last_result = -1 |
| result, state = None, None |
| for result, state in generator: |
| if last_result < 1: |
| # Otherwise you hit comparison issues |
| # This is because we get to the last one, which is the final result |
| assert result["count"] > last_result |
| last_result = result["count"] |
| assert result == {"count": 1} |
| assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} |
| |
| |
| def test__run_single_step_streaming_action_calls_callbacks(): |
| action = base_streaming_single_step_counter.with_name("counter") |
| |
| class TrackingCallback(PostStreamItemHook): |
| def __init__(self): |
| self.items = [] |
| |
| def post_stream_item(self, item: Any, **future_kwargs: Any): |
| self.items.append(item) |
| |
| hook = TrackingCallback() |
| |
| state = State({"count": 0, "tracker": []}) |
| generator = _run_single_step_streaming_action( |
| action, |
| state, |
| inputs={}, |
| sequence_id=0, |
| partition_key="partition_key", |
| app_id="app_id", |
| lifecycle_adapters=LifecycleAdapterSet(hook), |
| ) |
| last_result = -1 |
| result, state = None, None |
| for result, state in generator: |
| if last_result < 1: |
| # Otherwise you hit comparison issues |
| # This is because we get to the last one, which is the final result |
| assert result["count"] > last_result |
| last_result = result["count"] |
| assert result == {"count": 1} |
| assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} |
| assert len(hook.items) == 10 # one for each streaming callback |
| |
| |
| async def test__run_single_step_streaming_action_async(): |
| async_action = base_streaming_single_step_counter_async.with_name("counter") |
| state = State({"count": 0, "tracker": []}) |
| generator = _arun_single_step_streaming_action( |
| action=async_action, |
| state=state, |
| inputs={}, |
| sequence_id=0, |
| app_id="app_id", |
| partition_key="partition_key", |
| lifecycle_adapters=LifecycleAdapterSet(), |
| ) |
| last_result = -1 |
| result, state = None, None |
| async for result, state in generator: |
| if last_result < 1: |
| # Otherwise you hit comparison issues |
| # This is because we get to the last one, which is the final result |
| assert result["count"] > last_result |
| last_result = result["count"] |
| assert result == {"count": 1} |
| assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} |
| |
| |
| async def test__run_single_step_streaming_action_async_callbacks(): |
| class TrackingCallback(PostStreamItemHookAsync): |
| def __init__(self): |
| self.items = [] |
| |
| async def post_stream_item(self, item: Any, **future_kwargs: Any): |
| self.items.append(item) |
| |
| hook = TrackingCallback() |
| |
| async_action = base_streaming_single_step_counter_async.with_name("counter") |
| state = State({"count": 0, "tracker": []}) |
| generator = _arun_single_step_streaming_action( |
| action=async_action, |
| state=state, |
| inputs={}, |
| sequence_id=0, |
| app_id="app_id", |
| partition_key="partition_key", |
| lifecycle_adapters=LifecycleAdapterSet(hook), |
| ) |
| last_result = -1 |
| result, state = None, None |
| async for result, state in generator: |
| if last_result < 1: |
| # Otherwise you hit comparison issues |
| # This is because we get to the last one, which is the final result |
| assert result["count"] > last_result |
| last_result = result["count"] |
| assert result == {"count": 1} |
| assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} |
| assert len(hook.items) == 10 # one for each streaming callback |
| |
| |
| class SingleStepStreamingCounterWithException(SingleStepStreamingAction): |
| """Yields intermediate items, raises, then yields final state in finally block.""" |
| |
| def stream_run_and_update( |
| self, state: State, **run_kwargs |
| ) -> Generator[Tuple[dict, Optional[State]], None, None]: |
| count = state["count"] |
| try: |
| for i in range(3): |
| yield {"count": count + ((i + 1) / 10)}, None |
| raise RuntimeError("simulated failure") |
| finally: |
| yield {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1) |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["count"] |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["count", "tracker"] |
| |
| |
| class SingleStepStreamingCounterWithExceptionNoState(SingleStepStreamingAction): |
| """Raises without ever yielding a final state update.""" |
| |
| def stream_run_and_update( |
| self, state: State, **run_kwargs |
| ) -> Generator[Tuple[dict, Optional[State]], None, None]: |
| count = state["count"] |
| for i in range(3): |
| yield {"count": count + ((i + 1) / 10)}, None |
| raise RuntimeError("simulated failure with no state") |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["count"] |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["count", "tracker"] |
| |
| |
| class SingleStepStreamingCounterWithExceptionAsync(SingleStepStreamingAction): |
| """Async variant: yields intermediate items, raises, then yields final state in finally.""" |
| |
| async def stream_run_and_update( |
| self, state: State, **run_kwargs |
| ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: |
| count = state["count"] |
| try: |
| for i in range(3): |
| yield {"count": count + ((i + 1) / 10)}, None |
| raise RuntimeError("simulated failure") |
| finally: |
| yield {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1) |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["count"] |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["count", "tracker"] |
| |
| |
| class SingleStepStreamingCounterWithExceptionNoStateAsync(SingleStepStreamingAction): |
| """Async variant: raises without ever yielding a final state update.""" |
| |
| async def stream_run_and_update( |
| self, state: State, **run_kwargs |
| ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: |
| count = state["count"] |
| for i in range(3): |
| yield {"count": count + ((i + 1) / 10)}, None |
| raise RuntimeError("simulated failure with no state") |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["count"] |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["count", "tracker"] |
| |
| |
| class MultiStepStreamingCounterWithException(StreamingAction): |
| """Yields intermediate items, raises, then yields final result in finally block.""" |
| |
| def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, None]: |
| count = state["count"] |
| try: |
| for i in range(3): |
| yield {"count": count + ((i + 1) / 10)} |
| raise RuntimeError("simulated failure") |
| finally: |
| yield {"count": count + 1} |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["count"] |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["count", "tracker"] |
| |
| def update(self, result: dict, state: State) -> State: |
| return state.update(**result).append(tracker=result["count"]) |
| |
| |
| class MultiStepStreamingCounterWithExceptionNoResult(StreamingAction): |
| """Raises without ever yielding any item.""" |
| |
| def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, None]: |
| raise RuntimeError("simulated failure with no result") |
| yield # make this a generator function |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["count"] |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["count", "tracker"] |
| |
| def update(self, result: dict, state: State) -> State: |
| return state.update(**result).append(tracker=result["count"]) |
| |
| |
| class MultiStepStreamingCounterWithExceptionAsync(AsyncStreamingAction): |
| """Async variant: yields intermediate items, raises, then yields final result in finally.""" |
| |
| async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]: |
| count = state["count"] |
| try: |
| for i in range(3): |
| yield {"count": count + ((i + 1) / 10)} |
| raise RuntimeError("simulated failure") |
| finally: |
| yield {"count": count + 1} |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["count"] |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["count", "tracker"] |
| |
| def update(self, result: dict, state: State) -> State: |
| return state.update(**result).append(tracker=result["count"]) |
| |
| |
| class MultiStepStreamingCounterWithExceptionNoResultAsync(AsyncStreamingAction): |
| """Async variant: raises without ever yielding any item.""" |
| |
| async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]: |
| raise RuntimeError("simulated failure with no result") |
| yield # make this an async generator |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["count"] |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["count", "tracker"] |
| |
| def update(self, result: dict, state: State) -> State: |
| return state.update(**result).append(tracker=result["count"]) |
| |
| |
| def test__run_single_step_streaming_action_graceful_exception(): |
| """When the generator raises but yields a final state in finally, stream completes gracefully.""" |
| action = SingleStepStreamingCounterWithException().with_name("counter") |
| state = State({"count": 0, "tracker": []}) |
| generator = _run_single_step_streaming_action( |
| action, state, inputs={}, sequence_id=0, partition_key="pk", app_id="app" |
| ) |
| results = list(generator) |
| intermediate = [(r, s) for r, s in results if s is None] |
| final = [(r, s) for r, s in results if s is not None] |
| assert len(intermediate) == 3 |
| assert len(final) == 1 |
| assert final[0][0] == {"count": 1} |
| assert final[0][1].subset("count", "tracker").get_all() == { |
| "count": 1, |
| "tracker": [1], |
| } |
| |
| |
| def test__run_single_step_streaming_action_exception_propagates(): |
| """When the generator raises without yielding a final state, exception propagates.""" |
| action = SingleStepStreamingCounterWithExceptionNoState().with_name("counter") |
| state = State({"count": 0, "tracker": []}) |
| generator = _run_single_step_streaming_action( |
| action, state, inputs={}, sequence_id=0, partition_key="pk", app_id="app" |
| ) |
| with pytest.raises(RuntimeError, match="simulated failure with no state"): |
| list(generator) |
| |
| |
| async def test__run_single_step_streaming_action_graceful_exception_async(): |
| """Async: when the generator raises but yields a final state in finally, stream completes.""" |
| action = SingleStepStreamingCounterWithExceptionAsync().with_name("counter") |
| state = State({"count": 0, "tracker": []}) |
| generator = _arun_single_step_streaming_action( |
| action=action, |
| state=state, |
| inputs={}, |
| sequence_id=0, |
| app_id="app", |
| partition_key="pk", |
| lifecycle_adapters=LifecycleAdapterSet(), |
| ) |
| results = [] |
| async for item in generator: |
| results.append(item) |
| intermediate = [(r, s) for r, s in results if s is None] |
| final = [(r, s) for r, s in results if s is not None] |
| assert len(intermediate) == 3 |
| assert len(final) == 1 |
| assert final[0][0] == {"count": 1} |
| assert final[0][1].subset("count", "tracker").get_all() == { |
| "count": 1, |
| "tracker": [1], |
| } |
| |
| |
| async def test__run_single_step_streaming_action_exception_propagates_async(): |
| """Async: when the generator raises without yielding a final state, exception propagates.""" |
| action = SingleStepStreamingCounterWithExceptionNoStateAsync().with_name("counter") |
| state = State({"count": 0, "tracker": []}) |
| generator = _arun_single_step_streaming_action( |
| action=action, |
| state=state, |
| inputs={}, |
| sequence_id=0, |
| app_id="app", |
| partition_key="pk", |
| lifecycle_adapters=LifecycleAdapterSet(), |
| ) |
| with pytest.raises(RuntimeError, match="simulated failure with no state"): |
| async for _ in generator: |
| pass |
| |
| |
| def test__run_multi_step_streaming_action_graceful_exception(): |
| """When the generator raises but yields a final result in finally, stream completes.""" |
| action = MultiStepStreamingCounterWithException().with_name("counter") |
| state = State({"count": 0, "tracker": []}) |
| generator = _run_multi_step_streaming_action( |
| action, state, inputs={}, sequence_id=0, partition_key="pk", app_id="app" |
| ) |
| results = list(generator) |
| intermediate = [(r, s) for r, s in results if s is None] |
| final = [(r, s) for r, s in results if s is not None] |
| assert len(intermediate) == 3 |
| assert len(final) == 1 |
| assert final[0][0] == {"count": 1} |
| assert final[0][1].subset("count", "tracker").get_all() == { |
| "count": 1, |
| "tracker": [1], |
| } |
| |
| |
| def test__run_multi_step_streaming_action_exception_propagates(): |
| """When the generator raises without yielding any result, exception propagates.""" |
| action = MultiStepStreamingCounterWithExceptionNoResult().with_name("counter") |
| state = State({"count": 0, "tracker": []}) |
| generator = _run_multi_step_streaming_action( |
| action, state, inputs={}, sequence_id=0, partition_key="pk", app_id="app" |
| ) |
| with pytest.raises(RuntimeError, match="simulated failure with no result"): |
| list(generator) |
| |
| |
| async def test__run_multi_step_streaming_action_graceful_exception_async(): |
| """Async: when the generator raises but yields a final result in finally, stream completes.""" |
| action = MultiStepStreamingCounterWithExceptionAsync().with_name("counter") |
| state = State({"count": 0, "tracker": []}) |
| generator = _arun_multi_step_streaming_action( |
| action=action, |
| state=state, |
| inputs={}, |
| sequence_id=0, |
| app_id="app", |
| partition_key="pk", |
| lifecycle_adapters=LifecycleAdapterSet(), |
| ) |
| results = [] |
| async for item in generator: |
| results.append(item) |
| intermediate = [(r, s) for r, s in results if s is None] |
| final = [(r, s) for r, s in results if s is not None] |
| assert len(intermediate) == 3 |
| assert len(final) == 1 |
| assert final[0][0] == {"count": 1} |
| assert final[0][1].subset("count", "tracker").get_all() == { |
| "count": 1, |
| "tracker": [1], |
| } |
| |
| |
| async def test__run_multi_step_streaming_action_exception_propagates_async(): |
| """Async: when the generator raises without yielding any result, exception propagates.""" |
| action = MultiStepStreamingCounterWithExceptionNoResultAsync().with_name("counter") |
| state = State({"count": 0, "tracker": []}) |
| generator = _arun_multi_step_streaming_action( |
| action=action, |
| state=state, |
| inputs={}, |
| sequence_id=0, |
| app_id="app", |
| partition_key="pk", |
| lifecycle_adapters=LifecycleAdapterSet(), |
| ) |
| with pytest.raises(RuntimeError, match="simulated failure with no result"): |
| async for _ in generator: |
| pass |
| |
| |
| class SingleStepActionWithDeletionAsync(SingleStepActionWithDeletion): |
| async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: |
| return {}, state.wipe(delete=["to_delete"]) |
| |
| |
| async def test__arun_single_step_action_deletes_state(): |
| action = SingleStepActionWithDeletionAsync() |
| state = State({"to_delete": 0}) |
| result, state = await _arun_single_step_action(action, state, inputs={}) |
| assert "to_delete" not in state |
| |
| |
| def test_app_step(): |
| """Tests that we can run a step in an app""" |
| counter_action = base_counter_action.with_name("counter") |
| app = Application( |
| state=State({}), |
| entrypoint="counter", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action], |
| transitions=[Transition(counter_action, counter_action, default)], |
| ), |
| ) |
| action, result, state = app.step() |
| assert app.sequence_id == 1 |
| assert action.name == "counter" |
| assert result == {"count": 1} |
| assert state[PRIOR_STEP] == "counter" # internal contract, not part of the public API |
| |
| |
| def test_app_step_with_inputs(): |
| """Tests that we can run a step in an app""" |
| counter_action = base_single_step_counter_with_inputs.with_name("counter") |
| app = Application( |
| state=State({"count": 0, "tracker": []}), |
| entrypoint="counter", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action], |
| transitions=[Transition(counter_action, counter_action, default)], |
| ), |
| ) |
| action, result, state = app.step(inputs={"additional_increment": 1}) |
| assert action.name == "counter" |
| assert result == {"count": 2} |
| assert state.subset("count", "tracker").get_all() == {"count": 2, "tracker": [2]} |
| |
| |
| def test_app_step_with_inputs_missing(): |
| """Tests that we can run a step in an app""" |
| counter_action = base_single_step_counter_with_inputs.with_name("counter") |
| app = Application( |
| state=State({"count": 0, "tracker": []}), |
| entrypoint="counter", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action], |
| transitions=[Transition(counter_action, counter_action, default)], |
| ), |
| ) |
| with pytest.raises(ValueError, match="missing required inputs"): |
| app.step(inputs={}) |
| |
| |
| def test_app_step_broken(caplog): |
| """Tests that we can run a step in an app""" |
| broken_action = base_broken_action.with_name("broken_action_unique_name") |
| app = Application( |
| state=State({}), |
| entrypoint="broken_action_unique_name", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[broken_action], |
| transitions=[Transition(broken_action, broken_action, default)], |
| ), |
| ) |
| with caplog.at_level(logging.ERROR): # it should say the name, that's the only contract for now |
| with pytest.raises(BrokenStepException): |
| app.step() |
| assert "broken_action_unique_name" in caplog.text |
| |
| |
| def test_app_step_done(): |
| """Tests that when we cannot run a step, we return None""" |
| counter_action = base_counter_action.with_name("counter") |
| app = Application( |
| state=State({}), |
| entrypoint="counter", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action], |
| transitions=[], |
| ), |
| ) |
| app.step() |
| assert app.step() is None |
| |
| |
| async def test_app_astep(): |
| """Tests that we can run an async step in an app""" |
| counter_action = base_counter_action_async.with_name("counter_async") |
| app = Application( |
| state=State({}), |
| entrypoint="counter_async", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action], |
| transitions=[Transition(counter_action, counter_action, default)], |
| ), |
| ) |
| action, result, state = await app.astep() |
| assert app.sequence_id == 1 |
| assert action.name == "counter_async" |
| assert result == {"count": 1} |
| assert state[PRIOR_STEP] == "counter_async" # internal contract, not part of the public API |
| |
| |
| def test_app_step_context(): |
| APP_ID = str(uuid.uuid4()) |
| PARTITION_KEY = str(uuid.uuid4()) |
| |
| @action(reads=[], writes=[]) |
| def test_action(state: State, __context: ApplicationContext) -> State: |
| assert __context.sequence_id == 0 |
| assert __context.partition_key == PARTITION_KEY |
| assert __context.app_id == APP_ID |
| assert __context.action_name == "test_action" |
| return state |
| |
| app = ( |
| ApplicationBuilder() |
| .with_actions(test_action) |
| .with_entrypoint("test_action") |
| .with_transitions() |
| .with_identifiers( |
| app_id=APP_ID, |
| partition_key=PARTITION_KEY, |
| ) |
| .build() |
| ) |
| app.step() |
| |
| |
| async def test_app_astep_context(): |
| """Tests that app.astep correctly passes context.""" |
| APP_ID = str(uuid.uuid4()) |
| PARTITION_KEY = str(uuid.uuid4()) |
| |
| @action(reads=[], writes=[]) |
| def test_action(state: State, __context: ApplicationContext) -> State: |
| assert __context.sequence_id == 0 |
| assert __context.partition_key == PARTITION_KEY |
| assert __context.app_id == APP_ID |
| assert __context.action_name == "test_action" |
| return state |
| |
| app = ( |
| ApplicationBuilder() |
| .with_actions(test_action) |
| .with_entrypoint("test_action") |
| .with_transitions() |
| .with_identifiers( |
| app_id=APP_ID, |
| partition_key=PARTITION_KEY, |
| ) |
| .build() |
| ) |
| await app.astep() |
| |
| |
| async def test_app_astep_with_inputs(): |
| """Tests that we can run an async step in an app""" |
| counter_action = base_single_step_counter_with_inputs_async.with_name("counter_async") |
| app = Application( |
| state=State({"count": 0, "tracker": []}), |
| entrypoint="counter_async", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action], |
| transitions=[Transition(counter_action, counter_action, default)], |
| ), |
| ) |
| action, result, state = await app.astep(inputs={"additional_increment": 1}) |
| assert action.name == "counter_async" |
| assert result == {"count": 2} |
| assert state.subset("count", "tracker").get_all() == {"count": 2, "tracker": [2]} |
| |
| |
| async def test_app_astep_with_inputs_missing(): |
| """Tests that we can run an async step in an app""" |
| counter_action = base_single_step_counter_with_inputs_async.with_name("counter_async") |
| app = Application( |
| state=State({"count": 0, "tracker": []}), |
| entrypoint="counter_async", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action], |
| transitions=[Transition(counter_action, counter_action, default)], |
| ), |
| ) |
| with pytest.raises(ValueError, match="missing required inputs"): |
| await app.astep(inputs={}) |
| |
| |
| async def test_app_astep_broken(caplog): |
| """Tests that we can run a step in an app""" |
| broken_action = base_broken_action_async.with_name("broken_action_unique_name") |
| app = Application( |
| state=State({}), |
| entrypoint="broken_action_unique_name", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[broken_action], |
| transitions=[Transition(broken_action, broken_action, default)], |
| ), |
| ) |
| with caplog.at_level(logging.ERROR): # it should say the name, that's the only contract for now |
| with pytest.raises(BrokenStepException): |
| await app.astep() |
| assert "broken_action_unique_name" in caplog.text |
| |
| |
| async def test_app_astep_done(): |
| """Tests that when we cannot run a step, we return None""" |
| counter_action = base_counter_action_async.with_name("counter_async") |
| app = Application( |
| state=State({}), |
| entrypoint="counter_async", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action], |
| transitions=[], |
| ), |
| ) |
| await app.astep() |
| assert await app.astep() is None |
| |
| |
| # internal API |
| def test_app_many_steps(): |
| counter_action = base_counter_action.with_name("counter") |
| app = Application( |
| state=State({}), |
| entrypoint="counter", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action], |
| transitions=[Transition(counter_action, counter_action, default)], |
| ), |
| ) |
| action, result = None, None |
| for i in range(100): |
| action, result, state = app.step() |
| assert action.name == "counter" |
| assert result == {"count": 100} |
| |
| |
| async def test_app_many_a_steps(): |
| counter_action = base_counter_action_async.with_name("counter_async") |
| app = Application( |
| state=State({}), |
| entrypoint="counter_async", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action], |
| transitions=[Transition(counter_action, counter_action, default)], |
| ), |
| ) |
| action, result = None, None |
| for i in range(100): |
| action, result, state = await app.astep() |
| assert action.name == "counter_async" |
| assert result == {"count": 100} |
| |
| |
| def test_iterate(): |
| result_action = Result("count").with_name("result") |
| counter_action = base_counter_action.with_name("counter") |
| app = Application( |
| state=State({}), |
| entrypoint="counter", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action, result_action], |
| transitions=[ |
| Transition(counter_action, counter_action, Condition.expr("count < 10")), |
| Transition(counter_action, result_action, default), |
| ], |
| ), |
| ) |
| res = [] |
| gen = app.iterate(halt_after=["result"]) |
| counter = 0 |
| try: |
| while True: |
| action, result, state = next(gen) |
| if action.name == "counter": |
| assert state["count"] == counter + 1 |
| assert result["count"] == state["count"] |
| counter = result["count"] |
| else: |
| res.append(result) |
| assert state["count"] == 10 |
| assert result["count"] == 10 |
| except StopIteration as e: |
| stop_iteration_error = e |
| generator_result = stop_iteration_error.value |
| action, result, state = generator_result |
| assert state["count"] == 10 |
| assert result["count"] == 10 |
| assert app.sequence_id == 11 |
| |
| |
| def test_iterate_with_inputs(): |
| result_action = Result("count").with_name("result") |
| counter_action = base_counter_action_with_inputs.with_name("counter") |
| app = Application( |
| state=State({}), |
| entrypoint="counter", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action, result_action], |
| transitions=[ |
| Transition(counter_action, counter_action, Condition.expr("count < 2")), |
| Transition(counter_action, result_action, default), |
| ], |
| ), |
| ) |
| gen = app.iterate( |
| halt_after=["result"], inputs={"additional_increment": 10} |
| ) # make it go quicly to the end |
| while True: |
| try: |
| action, result, state = next(gen) |
| except StopIteration as e: |
| a, r, s = e.value |
| assert r["count"] == 11 # 1 + 10, for the first one |
| break |
| |
| |
| async def test_aiterate(): |
| result_action = Result("count").with_name("result") |
| counter_action = base_counter_action_async.with_name("counter") |
| app = Application( |
| state=State({}), |
| entrypoint="counter", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action, result_action], |
| transitions=[ |
| Transition(counter_action, counter_action, Condition.expr("count < 10")), |
| Transition(counter_action, result_action, default), |
| ], |
| ), |
| ) |
| gen = app.aiterate(halt_after=["result"]) |
| assert app.sequence_id == 0 |
| counter = 0 |
| # Note that we use an async-for loop cause the API is different, this doesn't |
| # return anything (async generators are not allowed to). |
| async for action_, result, state in gen: |
| print("si", app.sequence_id, action_.name, state) |
| if action_.name == "counter": |
| assert state["count"] == result["count"] == counter + 1 |
| counter = result["count"] |
| else: |
| assert state["count"] == result["count"] == 10 |
| assert app.sequence_id == 11 |
| |
| |
| async def test_aiterate_halt_before(): |
| result_action = Result("count").with_name("result") |
| counter_action = base_counter_action_async.with_name("counter") |
| app = Application( |
| state=State({}), |
| entrypoint="counter", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action, result_action], |
| transitions=[ |
| Transition(counter_action, counter_action, Condition.expr("count < 10")), |
| Transition(counter_action, result_action, default), |
| ], |
| ), |
| ) |
| gen = app.aiterate(halt_before=["result"]) |
| counter = 0 |
| # Note that we use an async-for loop cause the API is different, this doesn't |
| # return anything (async generators are not allowed to). |
| async for action_, result, state in gen: |
| if action_.name == "counter": |
| assert state["count"] == counter + 1 |
| counter = result["count"] |
| else: |
| assert result is None |
| assert state["count"] == 10 |
| |
| |
| async def test_app_aiterate_with_inputs(): |
| result_action = Result("count").with_name("result") |
| counter_action = base_counter_action_with_inputs_async.with_name("counter") |
| app = Application( |
| state=State({}), |
| entrypoint="counter", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action, result_action], |
| transitions=[ |
| Transition(counter_action, counter_action, Condition.expr("count < 10")), |
| Transition(counter_action, result_action, default), |
| ], |
| ), |
| ) |
| gen = app.aiterate(halt_after=["result"], inputs={"additional_increment": 10}) |
| async for action_, result, state in gen: |
| if action_.name == "counter": |
| assert result["count"] == state["count"] == 11 |
| else: |
| assert state["count"] == result["count"] == 11 |
| |
| |
| def test_run(): |
| result_action = Result("count").with_name("result") |
| counter_action = base_counter_action.with_name("counter") |
| app = Application( |
| state=State({}), |
| entrypoint="counter", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action, result_action], |
| transitions=[ |
| Transition(counter_action, counter_action, Condition.expr("count < 10")), |
| Transition(counter_action, result_action, default), |
| ], |
| ), |
| ) |
| action, result, state = app.run(halt_after=["result"]) |
| assert state["count"] == 10 |
| assert result["count"] == 10 |
| |
| |
| def test_run_halt_before(): |
| result_action = Result("count").with_name("result") |
| counter_action = base_counter_action.with_name("counter") |
| app = Application( |
| state=State({}), |
| entrypoint="counter", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action, result_action], |
| transitions=[ |
| Transition(counter_action, counter_action, Condition.expr("count < 10")), |
| Transition(counter_action, result_action, default), |
| ], |
| ), |
| ) |
| action, result, state = app.run(halt_before=["result"]) |
| assert state["count"] == 10 |
| assert result is None |
| assert action.name == "result" |
| |
| |
| def test_run_with_inputs(): |
| result_action = Result("count").with_name("result") |
| counter_action = base_counter_action_with_inputs.with_name("counter") |
| app = Application( |
| state=State({}), |
| entrypoint="counter", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action, result_action], |
| transitions=[ |
| Transition(counter_action, counter_action, Condition.expr("count < 10")), |
| Transition(counter_action, result_action, default), |
| ], |
| ), |
| ) |
| action_, result, state = app.run(halt_after=["result"], inputs={"additional_increment": 10}) |
| assert action_.name == "result" |
| assert state["count"] == result["count"] == 11 |
| |
| |
| def test_run_with_inputs_multiple_actions(): |
| """Tests that inputs aren't popped off and are passed through to multiple actions.""" |
| result_action = Result("count").with_name("result") |
| counter_action1 = base_counter_action_with_inputs.with_name("counter1") |
| counter_action2 = base_counter_action_with_inputs.with_name("counter2") |
| app = Application( |
| state=State({}), |
| entrypoint="counter1", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action1, counter_action2, result_action], |
| transitions=[ |
| Transition(counter_action1, counter_action1, Condition.expr("count < 10")), |
| Transition(counter_action1, counter_action2, Condition.expr("count >= 10")), |
| Transition(counter_action2, counter_action2, Condition.expr("count < 20")), |
| Transition(counter_action2, result_action, default), |
| ], |
| ), |
| ) |
| action_, result, state = app.run(halt_after=["result"], inputs={"additional_increment": 8}) |
| assert action_.name == "result" |
| assert state["count"] == result["count"] == 27 |
| assert state["__SEQUENCE_ID"] == 4 |
| |
| |
| async def test_arun(): |
| result_action = Result("count").with_name("result") |
| counter_action = base_counter_action_async.with_name("counter") |
| app = Application( |
| state=State({}), |
| entrypoint="counter", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action, result_action], |
| transitions=[ |
| Transition(counter_action, counter_action, Condition.expr("count < 10")), |
| Transition(counter_action, result_action, default), |
| ], |
| ), |
| ) |
| action_, result, state = await app.arun(halt_after=["result"]) |
| assert state["count"] == result["count"] == 10 |
| assert action_.name == "result" |
| |
| |
| async def test_arun_halt_before(): |
| result_action = Result("count").with_name("result") |
| counter_action = base_counter_action_async.with_name("counter") |
| app = Application( |
| state=State({}), |
| entrypoint="counter", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action, result_action], |
| transitions=[ |
| Transition(counter_action, counter_action, Condition.expr("count < 10")), |
| Transition(counter_action, result_action, default), |
| ], |
| ), |
| ) |
| action_, result, state = await app.arun(halt_before=["result"]) |
| assert state["count"] == 10 |
| assert result is None |
| assert action_.name == "result" |
| |
| |
| async def test_arun_with_inputs(): |
| result_action = Result("count").with_name("result") |
| counter_action = base_counter_action_with_inputs_async.with_name("counter") |
| app = Application( |
| state=State({}), |
| entrypoint="counter", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action, result_action], |
| transitions=[ |
| Transition(counter_action, counter_action, Condition.expr("count < 10")), |
| Transition(counter_action, result_action, default), |
| ], |
| ), |
| ) |
| action_, result, state = await app.arun( |
| halt_after=["result"], inputs={"additional_increment": 10} |
| ) |
| assert state["count"] == result["count"] == 11 |
| assert action_.name == "result" |
| |
| |
| async def test_arun_with_inputs_multiple_actions(): |
| result_action = Result("count").with_name("result") |
| counter_action1 = base_counter_action_with_inputs_async.with_name("counter1") |
| counter_action2 = base_counter_action_with_inputs_async.with_name("counter2") |
| app = Application( |
| state=State({}), |
| entrypoint="counter1", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action1, counter_action2, result_action], |
| transitions=[ |
| Transition(counter_action1, counter_action1, Condition.expr("count < 10")), |
| Transition(counter_action1, counter_action2, Condition.expr("count >= 10")), |
| Transition(counter_action2, counter_action2, Condition.expr("count < 20")), |
| Transition(counter_action2, result_action, default), |
| ], |
| ), |
| ) |
| action_, result, state = await app.arun( |
| halt_after=["result"], inputs={"additional_increment": 8} |
| ) |
| assert state["count"] == result["count"] == 27 |
| assert action_.name == "result" |
| assert state["__SEQUENCE_ID"] == 4 |
| |
| |
| async def test_app_a_run_async_and_sync(): |
| result_action = Result("count").with_name("result") |
| counter_action_sync = base_counter_action_async.with_name("counter_sync") |
| counter_action_async = base_counter_action_async.with_name("counter_async") |
| app = Application( |
| state=State({}), |
| entrypoint="counter_sync", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action_sync, counter_action_async, result_action], |
| transitions=[ |
| Transition( |
| counter_action_sync, |
| counter_action_async, |
| Condition.expr("count < 20"), |
| ), |
| Transition(counter_action_async, counter_action_sync, default), |
| Transition(counter_action_sync, result_action, default), |
| ], |
| ), |
| ) |
| action_, result, state = await app.arun(halt_after=["result"]) |
| assert state["count"] > 20 |
| assert result["count"] > 20 |
| |
| |
| def test_stream_result_halt_after_unique_ordered_sequence_id(): |
| action_tracker = CallCaptureTracker() |
| stream_event_tracker = StreamEventCaptureTracker() |
| counter_action = base_streaming_counter.with_name("counter") |
| counter_action_2 = base_streaming_counter.with_name("counter_2") |
| app = Application( |
| state=State({"count": 0}), |
| entrypoint="counter", |
| adapter_set=LifecycleAdapterSet(action_tracker, stream_event_tracker), |
| partition_key="test", |
| uid="test-123", |
| graph=Graph( |
| actions=[counter_action, counter_action_2], |
| transitions=[ |
| Transition(counter_action, counter_action_2, default), |
| ], |
| ), |
| ) |
| action_, streaming_container = app.stream_result(halt_after=["counter_2"]) |
| results = list(streaming_container) |
| assert len(results) == 10 |
| result, state = streaming_container.get() |
| assert result["count"] == state["count"] == 2 |
| assert state["tracker"] == [1, 2] |
| assert len(action_tracker.pre_called) == 2 |
| assert len(action_tracker.post_called) == 2 |
| assert set(dict(action_tracker.pre_called).keys()) == {"counter", "counter_2"} |
| assert set(dict(action_tracker.post_called).keys()) == {"counter", "counter_2"} |
| assert [item["sequence_id"] for _, item in action_tracker.pre_called] == [ |
| 0, |
| 1, |
| ] # ensure sequence ID is respected |
| assert [item["sequence_id"] for _, item in action_tracker.post_called] == [ |
| 0, |
| 1, |
| ] # ensure sequence ID is respected |
| # One post call/one pre-call, as we call stream_result once |
| assert len(action_tracker.post_run_execute_calls) == 1 |
| assert len(action_tracker.pre_run_execute_calls) == 1 |
| # One call for streaming |
| assert len(stream_event_tracker.pre_start_stream_calls) == 1 |
| # 10 streaming items |
| assert len(stream_event_tracker.post_stream_item_calls) == 10 |
| # One call for streaming |
| assert len(stream_event_tracker.post_end_stream_calls) == 1 |
| |
| |
| async def test_astream_result_halt_after_unique_ordered_sequence_id(): |
| action_tracker = CallCaptureTracker() |
| stream_event_tracker = StreamEventCaptureTrackerAsync() |
| stream_event_tracker_sync = StreamEventCaptureTracker() |
| counter_action = base_streaming_counter_async.with_name("counter") |
| counter_action_2 = base_streaming_counter_async.with_name("counter_2") |
| app = Application( |
| state=State({"count": 0}), |
| entrypoint="counter", |
| adapter_set=LifecycleAdapterSet( |
| action_tracker, stream_event_tracker, stream_event_tracker_sync |
| ), |
| partition_key="test", |
| uid="test-123", |
| graph=Graph( |
| actions=[counter_action, counter_action_2], |
| transitions=[ |
| Transition(counter_action, counter_action_2, default), |
| ], |
| ), |
| ) |
| action_, streaming_async_container = await app.astream_result(halt_after=["counter_2"]) |
| results = [ |
| item async for item in streaming_async_container |
| ] # this should just have the intermediate results |
| # results = list(streaming_container) |
| assert len(results) == 10 |
| result, state = await streaming_async_container.get() |
| assert result["count"] == state["count"] == 2 |
| assert state["tracker"] == [1, 2] |
| assert len(action_tracker.pre_called) == 2 |
| assert len(action_tracker.post_called) == 2 |
| assert set(dict(action_tracker.pre_called).keys()) == {"counter", "counter_2"} |
| assert set(dict(action_tracker.post_called).keys()) == {"counter", "counter_2"} |
| assert [item["sequence_id"] for _, item in action_tracker.pre_called] == [ |
| 0, |
| 1, |
| ] # ensure sequence ID is respected |
| assert [item["sequence_id"] for _, item in action_tracker.post_called] == [ |
| 0, |
| 1, |
| ] # ensure sequence ID is respected |
| assert len(action_tracker.post_run_execute_calls) == 1 |
| assert len(action_tracker.pre_run_execute_calls) == 1 |
| |
| # One call for streaming |
| assert ( |
| len(stream_event_tracker.pre_start_stream_calls) |
| == len(stream_event_tracker_sync.pre_start_stream_calls) |
| == 1 |
| ) |
| # 10 streaming items |
| assert ( |
| len(stream_event_tracker.post_stream_item_calls) |
| == len(stream_event_tracker_sync.post_stream_item_calls) |
| == 10 |
| ) |
| # One call for streaming |
| assert ( |
| len(stream_event_tracker.post_end_stream_calls) |
| == len(stream_event_tracker_sync.post_end_stream_calls) |
| == 1 |
| ) |
| |
| |
| def test_stream_result_halt_after_run_through_streaming(): |
| """Tests that we can pass through streaming results, |
| fully realize them, then get to the streaming results at the end and return the stream |
| """ |
| action_tracker = CallCaptureTracker() |
| stream_event_tracker = StreamEventCaptureTracker() |
| counter_action = base_streaming_single_step_counter.with_name("counter") |
| counter_action_2 = base_streaming_single_step_counter.with_name("counter_2") |
| app = Application( |
| state=State({"count": 0}), |
| entrypoint="counter", |
| adapter_set=LifecycleAdapterSet(action_tracker, stream_event_tracker), |
| partition_key="test", |
| uid="test-123", |
| graph=Graph( |
| actions=[counter_action, counter_action_2], |
| transitions=[ |
| Transition(counter_action, counter_action_2, default), |
| ], |
| ), |
| ) |
| action_, streaming_container = app.stream_result(halt_after=["counter_2"]) |
| results = list(streaming_container) |
| assert len(results) == 10 |
| result, state = streaming_container.get() |
| assert result["count"] == state["count"] == 2 |
| assert state["tracker"] == [1, 2] |
| assert len(action_tracker.pre_called) == 2 |
| assert len(action_tracker.post_called) == 2 |
| assert set(dict(action_tracker.pre_called).keys()) == {"counter", "counter_2"} |
| assert set(dict(action_tracker.post_called).keys()) == {"counter", "counter_2"} |
| assert [item["sequence_id"] for _, item in action_tracker.pre_called] == [ |
| 0, |
| 1, |
| ] # ensure sequence ID is respected |
| assert [item["sequence_id"] for _, item in action_tracker.post_called] == [ |
| 0, |
| 1, |
| ] # ensure sequence ID is respected |
| assert len(action_tracker.post_run_execute_calls) == 1 |
| assert len(action_tracker.pre_run_execute_calls) == 1 |
| # One call for streaming |
| assert len(stream_event_tracker.pre_start_stream_calls) == 1 |
| # 10 streaming items |
| assert len(stream_event_tracker.post_stream_item_calls) == 10 |
| # One call for streaming |
| assert len(stream_event_tracker.post_end_stream_calls) == 1 |
| |
| |
| @pytest.mark.parametrize("exhaust_intermediate_generators", [True, False]) |
| def test_stream_iterate(exhaust_intermediate_generators: bool): |
| """Tests that we can pass through streaming results in streaming iterate. Note that this tests two cases: |
| 1. We exhaust the intermediate generators, and then call get() to get the final result |
| 2. We don't exhaust the intermediate generators, and then call get() to get the final result |
| This ensures that the application effectively does it for us. |
| """ |
| action_tracker = CallCaptureTracker() |
| stream_event_tracker = StreamEventCaptureTracker() |
| counter_action = base_streaming_single_step_counter.with_name("counter") |
| counter_action_2 = base_streaming_single_step_counter.with_name("counter_2") |
| app = Application( |
| state=State({"count": 0}), |
| entrypoint="counter", |
| adapter_set=LifecycleAdapterSet(action_tracker, stream_event_tracker), |
| partition_key="test", |
| uid="test-123", |
| graph=Graph( |
| actions=[counter_action, counter_action_2], |
| transitions=[ |
| Transition(counter_action, counter_action_2, default), |
| ], |
| ), |
| ) |
| for _, streaming_container in app.stream_iterate(halt_after=["counter_2"]): |
| if exhaust_intermediate_generators: |
| results = list(streaming_container) |
| assert len(results) == 10 |
| result, state = streaming_container.get() |
| assert result["count"] == state["count"] == 2 |
| assert state["tracker"] == [1, 2] |
| assert len(action_tracker.pre_called) == 2 |
| assert len(action_tracker.post_called) == 2 |
| assert set(dict(action_tracker.pre_called).keys()) == {"counter", "counter_2"} |
| assert set(dict(action_tracker.post_called).keys()) == {"counter", "counter_2"} |
| assert [item["sequence_id"] for _, item in action_tracker.pre_called] == [ |
| 0, |
| 1, |
| ] # ensure sequence ID is respected |
| assert [item["sequence_id"] for _, item in action_tracker.post_called] == [ |
| 0, |
| 1, |
| ] # ensure sequence ID is respected |
| |
| assert len(stream_event_tracker.pre_start_stream_calls) == 2 |
| assert len(stream_event_tracker.post_end_stream_calls) == 2 |
| assert len(stream_event_tracker.post_stream_item_calls) == 20 |
| assert len(stream_event_tracker.post_stream_item_calls) == 20 |
| |
| |
| @pytest.mark.asyncio |
| @pytest.mark.parametrize("exhaust_intermediate_generators", [True, False]) |
| async def test_astream_iterate(exhaust_intermediate_generators: bool): |
| """Tests that we can pass through streaming results in astream_iterate. Note that this tests two cases: |
| 1. We exhaust the intermediate generators, and then call get() to get the final result |
| 2. We don't exhaust the intermediate generators, and then call get() to get the final result |
| This ensures that the application effectively does it for us. |
| """ |
| action_tracker = CallCaptureTracker() |
| stream_event_tracker = StreamEventCaptureTracker() |
| counter_action = base_streaming_single_step_counter_async.with_name( |
| "counter" |
| ) # Use async action |
| counter_action_2 = base_streaming_single_step_counter_async.with_name( |
| "counter_2" |
| ) # Use async action |
| app = Application( |
| state=State({"count": 0}), |
| entrypoint="counter", |
| adapter_set=LifecycleAdapterSet(action_tracker, stream_event_tracker), |
| partition_key="test", |
| uid="test-123", |
| graph=Graph( |
| actions=[counter_action, counter_action_2], |
| transitions=[ |
| Transition(counter_action, counter_action_2, default), |
| ], |
| ), |
| ) |
| streaming_container = None # Define outside the loop to access later |
| async for _, streaming_container in app.astream_iterate(halt_after=["counter_2"]): |
| if exhaust_intermediate_generators: |
| results = [] |
| async for item in streaming_container: # Use async for |
| results.append(item) |
| assert len(results) == 10 |
| assert streaming_container is not None # Ensure the loop ran |
| result, state = await streaming_container.get() # Use await |
| assert result["count"] == state["count"] == 2 |
| assert state["tracker"] == [1, 2] |
| assert len(action_tracker.pre_called) == 2 |
| assert len(action_tracker.post_called) == 2 |
| assert set(dict(action_tracker.pre_called).keys()) == {"counter", "counter_2"} |
| assert set(dict(action_tracker.post_called).keys()) == {"counter", "counter_2"} |
| assert [item["sequence_id"] for _, item in action_tracker.pre_called] == [ |
| 0, |
| 1, |
| ] # ensure sequence ID is respected |
| assert [item["sequence_id"] for _, item in action_tracker.post_called] == [ |
| 0, |
| 1, |
| ] # ensure sequence ID is respected |
| |
| assert len(stream_event_tracker.pre_start_stream_calls) == 2 |
| assert len(stream_event_tracker.post_end_stream_calls) == 2 |
| assert len(stream_event_tracker.post_stream_item_calls) == 20 |
| |
| |
| async def test_astream_result_halt_after_run_through_streaming(): |
| action_tracker = CallCaptureTracker() |
| stream_event_tracker = StreamEventCaptureTrackerAsync() |
| sync_stream_event_tracker = StreamEventCaptureTracker() |
| |
| counter_action = base_streaming_single_step_counter_async.with_name("counter") |
| counter_action_2 = base_streaming_single_step_counter_async.with_name("counter_2") |
| assert counter_action.is_async() |
| assert counter_action_2.is_async() |
| app = Application( |
| state=State({"count": 0}), |
| entrypoint="counter", |
| adapter_set=LifecycleAdapterSet( |
| action_tracker, stream_event_tracker, sync_stream_event_tracker |
| ), |
| partition_key="test", |
| uid="test-123", |
| graph=Graph( |
| actions=[counter_action, counter_action_2], |
| transitions=[ |
| Transition(counter_action, counter_action_2, default), |
| ], |
| ), |
| ) |
| action_, streaming_container = await app.astream_result(halt_after=["counter_2"]) |
| results = [ |
| item async for item in streaming_container |
| ] # this should just have the intermediate results |
| assert len(results) == 10 |
| result, state = await streaming_container.get() |
| assert result["count"] == state["count"] == 2 |
| assert state["tracker"] == [1, 2] |
| assert len(action_tracker.pre_called) == 2 |
| assert len(action_tracker.post_called) == 2 |
| assert set(dict(action_tracker.pre_called).keys()) == {"counter", "counter_2"} |
| assert set(dict(action_tracker.post_called).keys()) == {"counter", "counter_2"} |
| assert [item["sequence_id"] for _, item in action_tracker.pre_called] == [ |
| 0, |
| 1, |
| ] # ensure sequence ID is respected |
| assert [item["sequence_id"] for _, item in action_tracker.post_called] == [ |
| 0, |
| 1, |
| ] # ensure sequence ID is respected |
| assert len(action_tracker.post_run_execute_calls) == 1 |
| assert len(action_tracker.pre_run_execute_calls) == 1 |
| |
| # One call for streaming |
| assert ( |
| len(stream_event_tracker.pre_start_stream_calls) |
| == len(sync_stream_event_tracker.pre_start_stream_calls) |
| == 1 |
| ) |
| # 10 streaming items |
| assert ( |
| len(stream_event_tracker.post_stream_item_calls) |
| == len(sync_stream_event_tracker.post_stream_item_calls) |
| == 10 |
| ) |
| # One call for streaming |
| assert ( |
| len(stream_event_tracker.post_end_stream_calls) |
| == len(sync_stream_event_tracker.post_end_stream_calls) |
| == 1 |
| ) |
| |
| |
| def test_stream_result_halt_after_run_through_non_streaming(): |
| """Tests what happens when we have an app that runs through non-streaming |
| results before hitting a final streaming result specified by halt_after""" |
| action_tracker = CallCaptureTracker() |
| stream_event_tracker = StreamEventCaptureTracker() |
| counter_non_streaming = base_counter_action.with_name("counter_non_streaming") |
| counter_streaming = base_streaming_single_step_counter.with_name("counter_streaming") |
| |
| app = Application( |
| state=State({"count": 0}), |
| entrypoint="counter_non_streaming", |
| adapter_set=LifecycleAdapterSet(action_tracker, stream_event_tracker), |
| partition_key="test", |
| uid="test-123", |
| graph=Graph( |
| actions=[counter_non_streaming, counter_streaming], |
| transitions=[ |
| Transition(counter_non_streaming, counter_non_streaming, expr("count < 10")), |
| Transition(counter_non_streaming, counter_streaming, default), |
| ], |
| ), |
| ) |
| action_, streaming_container = app.stream_result(halt_after=["counter_streaming"]) |
| results = list(streaming_container) |
| assert len(results) == 10 |
| result, state = streaming_container.get() |
| assert result["count"] == state["count"] == 11 |
| assert len(action_tracker.pre_called) == 11 |
| assert len(action_tracker.post_called) == 11 |
| assert set(dict(action_tracker.pre_called).keys()) == { |
| "counter_streaming", |
| "counter_non_streaming", |
| } |
| assert set(dict(action_tracker.post_called).keys()) == { |
| "counter_streaming", |
| "counter_non_streaming", |
| } |
| assert [item["sequence_id"] for _, item in action_tracker.pre_called] == list( |
| range(0, 11) |
| ) # ensure sequence ID is respected |
| assert [item["sequence_id"] for _, item in action_tracker.post_called] == list( |
| range(0, 11) |
| ) # ensure sequence ID is respected |
| assert len(action_tracker.post_run_execute_calls) == 1 |
| assert len(action_tracker.pre_run_execute_calls) == 1 |
| |
| # One call for streaming |
| assert len(stream_event_tracker.pre_start_stream_calls) == 1 |
| # 10 streaming items |
| assert len(stream_event_tracker.post_stream_item_calls) == 10 |
| # One call for streaming |
| assert len(stream_event_tracker.post_end_stream_calls) == 1 |
| |
| |
| async def test_astream_result_halt_after_run_through_non_streaming(): |
| action_tracker = CallCaptureTracker() |
| stream_event_tracker = StreamEventCaptureTrackerAsync() |
| stream_event_tracker_sync = StreamEventCaptureTracker() |
| counter_non_streaming = base_counter_action_async.with_name("counter_non_streaming") |
| counter_streaming = base_streaming_single_step_counter_async.with_name("counter_streaming") |
| |
| app = Application( |
| state=State({"count": 0}), |
| entrypoint="counter_non_streaming", |
| adapter_set=LifecycleAdapterSet( |
| action_tracker, stream_event_tracker, stream_event_tracker_sync |
| ), |
| partition_key="test", |
| uid="test-123", |
| graph=Graph( |
| actions=[counter_non_streaming, counter_streaming], |
| transitions=[ |
| Transition(counter_non_streaming, counter_non_streaming, expr("count < 10")), |
| Transition(counter_non_streaming, counter_streaming, default), |
| ], |
| ), |
| ) |
| action_, async_streaming_container = await app.astream_result(halt_after=["counter_streaming"]) |
| results = [ |
| item async for item in async_streaming_container |
| ] # this should just have the intermediate results |
| assert len(results) == 10 |
| result, state = await async_streaming_container.get() |
| assert result["count"] == state["count"] == 11 |
| assert len(action_tracker.pre_called) == 11 |
| assert len(action_tracker.post_called) == 11 |
| assert set(dict(action_tracker.pre_called).keys()) == { |
| "counter_streaming", |
| "counter_non_streaming", |
| } |
| assert set(dict(action_tracker.post_called).keys()) == { |
| "counter_streaming", |
| "counter_non_streaming", |
| } |
| assert [item["sequence_id"] for _, item in action_tracker.pre_called] == list( |
| range(0, 11) |
| ) # ensure sequence ID is respected |
| assert [item["sequence_id"] for _, item in action_tracker.post_called] == list( |
| range(0, 11) |
| ) # ensure sequence ID is respected |
| assert len(action_tracker.post_run_execute_calls) == 1 |
| assert len(action_tracker.pre_run_execute_calls) == 1 |
| |
| # One call for streaming |
| assert ( |
| len(stream_event_tracker.pre_start_stream_calls) |
| == len(stream_event_tracker_sync.pre_start_stream_calls) |
| == 1 |
| ) |
| # 10 streaming items |
| assert ( |
| len(stream_event_tracker.post_stream_item_calls) |
| == len(stream_event_tracker_sync.post_stream_item_calls) |
| == 10 |
| ) |
| # One call for streaming |
| assert ( |
| len(stream_event_tracker.post_end_stream_calls) |
| == len(stream_event_tracker_sync.post_end_stream_calls) |
| == 1 |
| ) |
| |
| |
| def test_stream_result_halt_after_run_through_final_non_streaming(): |
| """Tests that we can pass through non-streaming results when streaming is called""" |
| action_tracker = CallCaptureTracker() |
| counter_non_streaming = base_counter_action.with_name("counter_non_streaming") |
| counter_final_non_streaming = base_counter_action.with_name("counter_final_non_streaming") |
| |
| app = Application( |
| state=State({"count": 0}), |
| entrypoint="counter_non_streaming", |
| adapter_set=LifecycleAdapterSet(action_tracker), |
| partition_key="test", |
| uid="test-123", |
| graph=Graph( |
| actions=[counter_non_streaming, counter_final_non_streaming], |
| transitions=[ |
| Transition(counter_non_streaming, counter_non_streaming, expr("count < 10")), |
| Transition(counter_non_streaming, counter_final_non_streaming, default), |
| ], |
| ), |
| ) |
| action, streaming_container = app.stream_result(halt_after=["counter_final_non_streaming"]) |
| results = list(streaming_container) |
| assert len(results) == 0 # nothing to steram |
| result, state = streaming_container.get() |
| assert result["count"] == state["count"] == 11 |
| assert len(action_tracker.pre_called) == 11 |
| assert len(action_tracker.post_called) == 11 |
| assert set(dict(action_tracker.pre_called).keys()) == { |
| "counter_non_streaming", |
| "counter_final_non_streaming", |
| } |
| assert set(dict(action_tracker.post_called).keys()) == { |
| "counter_non_streaming", |
| "counter_final_non_streaming", |
| } |
| assert [item["sequence_id"] for _, item in action_tracker.pre_called] == list( |
| range(0, 11) |
| ) # ensure sequence ID is respected |
| assert [item["sequence_id"] for _, item in action_tracker.post_called] == list( |
| range(0, 11) |
| ) # ensure sequence ID is respected |
| assert len(action_tracker.pre_run_execute_calls) == 1 |
| assert len(action_tracker.post_run_execute_calls) == 1 |
| |
| |
| async def test_astream_result_halt_after_run_through_final_streaming(): |
| """Tests that we can pass through non-streaming results when streaming is called""" |
| action_tracker = CallCaptureTracker() |
| |
| counter_non_streaming = base_counter_action_async.with_name("counter_non_streaming") |
| counter_final_non_streaming = base_counter_action_async.with_name("counter_final_non_streaming") |
| |
| app = Application( |
| state=State({"count": 0}), |
| entrypoint="counter_non_streaming", |
| adapter_set=LifecycleAdapterSet(action_tracker), |
| partition_key="test", |
| uid="test-123", |
| graph=Graph( |
| actions=[counter_non_streaming, counter_final_non_streaming], |
| transitions=[ |
| Transition(counter_non_streaming, counter_non_streaming, expr("count < 10")), |
| Transition(counter_non_streaming, counter_final_non_streaming, default), |
| ], |
| ), |
| ) |
| action, streaming_container = await app.astream_result( |
| halt_after=["counter_final_non_streaming"] |
| ) |
| results = [ |
| item async for item in streaming_container |
| ] # this should just have the intermediate results |
| assert len(results) == 0 # nothing to stream |
| result, state = await streaming_container.get() |
| assert result["count"] == state["count"] == 11 |
| assert len(action_tracker.pre_called) == 11 |
| assert len(action_tracker.post_called) == 11 |
| assert set(dict(action_tracker.pre_called).keys()) == { |
| "counter_non_streaming", |
| "counter_final_non_streaming", |
| } |
| assert set(dict(action_tracker.post_called).keys()) == { |
| "counter_non_streaming", |
| "counter_final_non_streaming", |
| } |
| assert [item["sequence_id"] for _, item in action_tracker.pre_called] == list( |
| range(0, 11) |
| ) # ensure sequence ID is respected |
| assert [item["sequence_id"] for _, item in action_tracker.post_called] == list( |
| range(0, 11) |
| ) # ensure sequence ID is respected |
| assert len(action_tracker.post_run_execute_calls) == 1 |
| assert len(action_tracker.pre_run_execute_calls) == 1 |
| |
| |
| def test_stream_result_halt_before(): |
| action_tracker = CallCaptureTracker() |
| counter_non_streaming = base_counter_action.with_name("counter_non_streaming") |
| counter_streaming = base_streaming_single_step_counter.with_name("counter_final") |
| |
| app = Application( |
| state=State({"count": 0}), |
| entrypoint="counter_non_streaming", |
| partition_key="test", |
| uid="test-123", |
| adapter_set=LifecycleAdapterSet(action_tracker), |
| graph=Graph( |
| actions=[counter_non_streaming, counter_streaming], |
| transitions=[ |
| Transition(counter_non_streaming, counter_non_streaming, expr("count < 10")), |
| Transition(counter_non_streaming, counter_streaming, default), |
| ], |
| ), |
| ) |
| action, streaming_container = app.stream_result(halt_after=[], halt_before=["counter_final"]) |
| results = list(streaming_container) |
| assert len(results) == 0 # nothing to steram |
| result, state = streaming_container.get() |
| assert action.name == "counter_final" # halt before this one |
| assert result is None |
| assert state["count"] == 10 |
| assert [item["sequence_id"] for _, item in action_tracker.pre_called] == list( |
| range(0, 10) |
| ) # ensure sequence ID is respected |
| assert [item["sequence_id"] for _, item in action_tracker.post_called] == list( |
| range(0, 10) |
| ) # ensure sequence ID is respected |
| assert len(action_tracker.post_run_execute_calls) == 1 |
| assert len(action_tracker.pre_run_execute_calls) == 1 |
| |
| |
| async def test_astream_result_halt_before(): |
| action_tracker = CallCaptureTracker() |
| counter_non_streaming = base_counter_action_async.with_name("counter_non_streaming") |
| counter_streaming = base_streaming_single_step_counter_async.with_name("counter_final") |
| |
| app = Application( |
| state=State({"count": 0}), |
| entrypoint="counter_non_streaming", |
| partition_key="test", |
| uid="test-123", |
| adapter_set=LifecycleAdapterSet(action_tracker), |
| graph=Graph( |
| actions=[counter_non_streaming, counter_streaming], |
| transitions=[ |
| Transition(counter_non_streaming, counter_non_streaming, expr("count < 10")), |
| Transition(counter_non_streaming, counter_streaming, default), |
| ], |
| ), |
| ) |
| action, streaming_container = await app.astream_result( |
| halt_after=[], halt_before=["counter_final"] |
| ) |
| results = [ |
| item async for item in streaming_container |
| ] # this should just have the intermediate results |
| assert len(results) == 0 # nothing to stream |
| result, state = await streaming_container.get() |
| assert action.name == "counter_final" # halt before this one |
| assert result is None |
| assert state["count"] == 10 |
| assert [item["sequence_id"] for _, item in action_tracker.pre_called] == list( |
| range(0, 10) |
| ) # ensure sequence ID is respected |
| assert [item["sequence_id"] for _, item in action_tracker.post_called] == list( |
| range(0, 10) |
| ) # ensure sequence ID is respected |
| assert len(action_tracker.post_run_execute_calls) == 1 |
| assert len(action_tracker.pre_run_execute_calls) == 1 |
| |
| |
| def test_app_set_state(): |
| counter_action = base_counter_action.with_name("counter") |
| app = Application( |
| state=State(), |
| entrypoint="counter", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action], |
| transitions=[Transition(counter_action, counter_action, default)], |
| ), |
| ) |
| assert "counter" not in app.state # initial value |
| app.step() |
| assert app.state["count"] == 1 # updated value |
| state = app.state |
| app.update_state(state.update(count=2)) |
| assert app.state["count"] == 2 # updated value |
| |
| |
| def test_app_get_next_step(): |
| counter_action_1 = base_counter_action.with_name("counter_1") |
| counter_action_2 = base_counter_action.with_name("counter_2") |
| counter_action_3 = base_counter_action.with_name("counter_3") |
| app = Application( |
| state=State(), |
| entrypoint="counter_1", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action_1, counter_action_2, counter_action_3], |
| transitions=[ |
| Transition(counter_action_1, counter_action_2, default), |
| Transition(counter_action_2, counter_action_3, default), |
| Transition(counter_action_3, counter_action_1, default), |
| ], |
| ), |
| ) |
| # uninitialized -- counter_1 |
| assert app.get_next_action().name == "counter_1" |
| app.step() |
| # ran counter_1 -- counter_2 |
| assert app.get_next_action().name == "counter_2" |
| app.step() |
| # ran counter_2 -- counter_3 |
| assert app.get_next_action().name == "counter_3" |
| app.step() |
| # ran counter_3 -- back to counter_1 |
| assert app.get_next_action().name == "counter_1" |
| |
| |
| def test_application_builder_complete(): |
| app = ( |
| ApplicationBuilder() |
| .with_state(count=0) |
| .with_actions(counter=base_counter_action, result=Result("count")) |
| .with_transitions( |
| ("counter", "counter", Condition.expr("count < 10")), ("counter", "result") |
| ) |
| .with_entrypoint("counter") |
| .build() |
| ) |
| graph = app.graph |
| assert len(graph.actions) == 2 |
| assert len(graph.transitions) == 2 |
| assert app.get_next_action().name == "counter" |
| |
| |
| def test__validate_start_valid(): |
| _validate_start("counter", {"counter", "result"}) |
| |
| |
| def test__validate_start_not_found(): |
| with pytest.raises(ValueError, match="not found"): |
| _validate_start("counter", {"result"}) |
| |
| |
| def test__adjust_single_step_output_result_and_state(): |
| state = State({"count": 1}) |
| result = {"count": 1} |
| assert _adjust_single_step_output((result, state), "test_action", DEFAULT_SCHEMA) == ( |
| result, |
| state, |
| ) |
| |
| |
| def test__adjust_single_step_output_just_state(): |
| state = State({"count": 1}) |
| assert _adjust_single_step_output(state, "test_action", DEFAULT_SCHEMA) == ( |
| {}, |
| state, |
| ) |
| |
| |
| def test__adjust_single_step_output_errors_incorrect_type(): |
| state = "foo" |
| with pytest.raises(ValueError, match="must return either"): |
| _adjust_single_step_output(state, "test_action", DEFAULT_SCHEMA) |
| |
| |
| def test__adjust_single_step_output_errors_incorrect_result_type(): |
| state = State() |
| result = "bar" |
| with pytest.raises(ValueError, match="non-dict"): |
| _adjust_single_step_output((state, result), "test_action", DEFAULT_SCHEMA) |
| |
| |
| def test_application_builder_unset(): |
| with pytest.raises(ValueError): |
| ApplicationBuilder().build() |
| |
| |
| def test_application_run_step_hooks_sync(): |
| action_tracker = CallCaptureTracker() |
| counter_action = base_counter_action.with_name("counter") |
| result_action = Result("count").with_name("result") |
| app = Application( |
| state=State({}), |
| entrypoint="counter", |
| adapter_set=internal.LifecycleAdapterSet(action_tracker), |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action, result_action], |
| transitions=[ |
| Transition(counter_action, result_action, Condition.expr("count >= 10")), |
| Transition(counter_action, counter_action, default), |
| ], |
| ), |
| ) |
| app.run(halt_after=["result"]) |
| assert set(dict(action_tracker.pre_called).keys()) == {"counter", "result"} |
| assert set(dict(action_tracker.post_called).keys()) == {"counter", "result"} |
| # assert sequence id is incremented |
| assert action_tracker.pre_called[0][1]["sequence_id"] == 1 |
| assert action_tracker.post_called[0][1]["sequence_id"] == 1 |
| assert { |
| "action", |
| "sequence_id", |
| "state", |
| "inputs", |
| "app_id", |
| "partition_key", |
| }.issubset(set(action_tracker.pre_called[0][1].keys())) |
| assert { |
| "sequence_id", |
| "result", |
| "state", |
| "exception", |
| "app_id", |
| "partition_key", |
| }.issubset(set(action_tracker.post_called[0][1].keys())) |
| assert len(action_tracker.pre_called) == 11 |
| assert len(action_tracker.post_called) == 11 |
| # quick inclusion to ensure that the action is not called when we're done running |
| assert len(action_tracker.post_run_execute_calls) == 1 |
| assert len(action_tracker.pre_run_execute_calls) == 1 |
| assert app.step() is None # should be None |
| assert len(action_tracker.post_run_execute_calls) == 2 |
| assert len(action_tracker.pre_run_execute_calls) == 2 |
| |
| |
| async def test_application_run_step_hooks_async(): |
| tracker = ActionTrackerAsync() |
| counter_action = base_counter_action.with_name("counter") |
| result_action = Result("count").with_name("result") |
| app = Application( |
| state=State({}), |
| entrypoint="counter", |
| adapter_set=internal.LifecycleAdapterSet(tracker), |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action, result_action], |
| transitions=[ |
| Transition(counter_action, result_action, Condition.expr("count >= 10")), |
| Transition(counter_action, counter_action, default), |
| ], |
| ), |
| ) |
| await app.arun(halt_after=["result"]) |
| assert set(dict(tracker.pre_called).keys()) == {"counter", "result"} |
| assert set(dict(tracker.post_called).keys()) == {"counter", "result"} |
| # assert sequence id is incremented |
| assert tracker.pre_called[0][1]["sequence_id"] == 1 |
| assert tracker.post_called[0][1]["sequence_id"] == 1 |
| assert { |
| "sequence_id", |
| "state", |
| "inputs", |
| "app_id", |
| "partition_key", |
| }.issubset(set(tracker.pre_called[0][1].keys())) |
| assert { |
| "sequence_id", |
| "result", |
| "state", |
| "exception", |
| "app_id", |
| "partition_key", |
| }.issubset(set(tracker.post_called[0][1].keys())) |
| assert len(tracker.pre_called) == 11 |
| assert len(tracker.post_called) == 11 |
| |
| |
| async def test_application_run_step_runs_hooks(): |
| hooks = [CallCaptureTracker(), ActionTrackerAsync()] |
| |
| counter_action = base_counter_action.with_name("counter") |
| app = Application( |
| state=State({}), |
| entrypoint="counter", |
| adapter_set=internal.LifecycleAdapterSet(*hooks), |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action], |
| transitions=[ |
| Transition(counter_action, counter_action, default), |
| ], |
| ), |
| ) |
| await app.astep() |
| assert len(hooks[0].pre_called) == 1 |
| assert len(hooks[0].post_called) == 1 |
| # assert sequence id is incremented |
| assert hooks[0].pre_called[0][1]["sequence_id"] == 1 |
| assert hooks[0].post_called[0][1]["sequence_id"] == 1 |
| assert { |
| "sequence_id", |
| "state", |
| "inputs", |
| "app_id", |
| "partition_key", |
| "action", |
| }.issubset(set(hooks[0].pre_called[0][1].keys())) |
| assert { |
| "sequence_id", |
| "state", |
| "inputs", |
| "app_id", |
| "partition_key", |
| }.issubset(set(hooks[1].pre_called[0][1].keys())) |
| assert { |
| "sequence_id", |
| "result", |
| "state", |
| "exception", |
| "app_id", |
| "partition_key", |
| }.issubset(set(hooks[0].post_called[0][1].keys())) |
| assert { |
| "sequence_id", |
| "result", |
| "state", |
| "exception", |
| "app_id", |
| "partition_key", |
| }.issubset(set(hooks[1].post_called[0][1].keys())) |
| assert len(hooks[1].pre_called) == 1 |
| assert len(hooks[1].post_called) == 1 |
| assert len(hooks[0].post_run_execute_calls) == 1 |
| assert len(hooks[0].pre_run_execute_calls) == 1 |
| |
| |
| def test_application_post_application_create_hook(): |
| class PostApplicationCreateTracker(PostApplicationCreateHook): |
| def __init__(self): |
| self.called_args = None |
| self.call_count = 0 |
| |
| def post_application_create(self, **kwargs): |
| self.called_args = kwargs |
| self.call_count += 1 |
| |
| tracker = PostApplicationCreateTracker() |
| counter_action = base_counter_action.with_name("counter") |
| result_action = Result("count").with_name("result") |
| Application( |
| state=State({}), |
| entrypoint="counter", |
| adapter_set=internal.LifecycleAdapterSet(tracker), |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action, result_action], |
| transitions=[ |
| Transition(counter_action, result_action, Condition.expr("count >= 10")), |
| Transition(counter_action, counter_action, default), |
| ], |
| ), |
| ) |
| assert "state" in tracker.called_args |
| assert "application_graph" in tracker.called_args |
| assert tracker.call_count == 1 |
| |
| |
| async def test_application_gives_graph(): |
| counter_action = base_counter_action.with_name("counter") |
| result_action = Result("count").with_name("result") |
| app = Application( |
| state=State({}), |
| entrypoint="counter", |
| partition_key="test", |
| uid="test-123", |
| sequence_id=0, |
| graph=Graph( |
| actions=[counter_action, result_action], |
| transitions=[ |
| Transition(counter_action, result_action, Condition.expr("count >= 10")), |
| Transition(counter_action, counter_action, default), |
| ], |
| ), |
| ) |
| graph = app.graph |
| assert len(graph.actions) == 2 |
| assert len(graph.transitions) == 2 |
| assert graph.entrypoint.name == "counter" |
| |
| |
| def test_application_builder_initialize_does_not_allow_state_setting(): |
| with pytest.raises(ValueError, match="Cannot call initialize_from"): |
| ApplicationBuilder().with_entrypoint("foo").with_state(**{"foo": "bar"}).initialize_from( |
| DevNullPersister(), |
| resume_at_next_action=True, |
| default_state={}, |
| default_entrypoint="foo", |
| ) |
| |
| |
| class BrokenPersister(BaseStatePersister): |
| """Broken persistor.""" |
| |
| def load( |
| self, |
| partition_key: str, |
| app_id: Optional[str], |
| sequence_id: Optional[int] = None, |
| **kwargs, |
| ) -> Optional[PersistedStateData]: |
| return dict( |
| partition_key="key", |
| app_id="id", |
| sequence_id=0, |
| position="foo", |
| state=None, |
| created_at="", |
| status="completed", |
| ) |
| |
| def list_app_ids(self, partition_key: str, **kwargs) -> list[str]: |
| return [] |
| |
| def save( |
| self, |
| partition_key: Optional[str], |
| app_id: str, |
| sequence_id: int, |
| position: str, |
| state: State, |
| status: Literal["completed", "failed"], |
| **kwargs, |
| ): |
| return |
| |
| |
| def test_application_builder_initialize_raises_on_broken_persistor(): |
| """Persisters should return None when there is no state to be loaded and the default used.""" |
| counter_action = base_counter_action.with_name("counter") |
| result_action = Result("count").with_name("result") |
| with pytest.raises(ValueError, match="but value for state was None"): |
| ( |
| ApplicationBuilder() |
| .with_actions(counter_action, result_action) |
| .with_transitions(("counter", "result", default)) |
| .initialize_from( |
| BrokenPersister(), |
| resume_at_next_action=True, |
| default_state={}, |
| default_entrypoint="foo", |
| ) |
| .build() |
| ) |
| |
| |
| def test_load_from_sync_cannot_have_async_persistor_error(): |
| builder = ApplicationBuilder() |
| builder.initialize_from( |
| AsyncDevNullPersister(), |
| resume_at_next_action=True, |
| default_state={}, |
| default_entrypoint="foo", |
| ) |
| with pytest.raises( |
| ValueError, |
| match="are building the sync application, but have used an async initializer.", |
| ): |
| # we have not initialized |
| builder._load_from_sync_persister() |
| |
| |
| async def test_load_from_async_cannot_have_sync_persistor_error(): |
| await asyncio.sleep(0.00001) |
| builder = ApplicationBuilder() |
| builder.initialize_from( |
| DevNullPersister(), |
| resume_at_next_action=True, |
| default_state={}, |
| default_entrypoint="foo", |
| ) |
| with pytest.raises( |
| ValueError, |
| match="are building the async application, but have used an sync initializer.", |
| ): |
| # we have not initialized |
| await builder._load_from_async_persister() |
| |
| |
| def test_application_builder_assigns_correct_actions_with_dual_api(): |
| counter_action = base_counter_action.with_name("counter") |
| result_action = Result("count") |
| |
| @action(reads=[], writes=[]) |
| def test_action(state: State) -> State: |
| return state |
| |
| app = ( |
| ApplicationBuilder() |
| .with_state(count=0) |
| .with_actions(counter_action, test_action, result=result_action) |
| .with_transitions() |
| .with_entrypoint("counter") |
| .build() |
| ) |
| graph = app.graph |
| assert {a.name for a in graph.actions} == {"counter", "result", "test_action"} |
| |
| |
| def test__validate_halt_conditions(): |
| counter_action = base_counter_action.with_name("counter") |
| result_action = Result("count") |
| |
| @action(reads=[], writes=[]) |
| def test_action(state: State) -> State: |
| return state |
| |
| app = ( |
| ApplicationBuilder() |
| .with_state(count=0) |
| .with_actions(counter_action, test_action, result=result_action) |
| .with_transitions() |
| .with_entrypoint("counter") |
| .build() |
| ) |
| with pytest.raises(ValueError, match="(?=.*no_exist_1)(?=.*no_exist_2)"): |
| app._validate_halt_conditions(halt_after=["no_exist_1"], halt_before=["no_exist_2"]) |
| |
| |
| def test_application_builder_initialize_raises_on_fork_app_id_not_provided(): |
| """Can't pass in fork_from* without an app_id.""" |
| with pytest.raises(ValueError, match="If you set fork_from_partition_key"): |
| counter_action = base_counter_action.with_name("counter") |
| result_action = Result("count").with_name("result") |
| ( |
| ApplicationBuilder() |
| .with_actions(counter_action, result_action) |
| .with_transitions(("counter", "result", default)) |
| .initialize_from( |
| BrokenPersister(), |
| resume_at_next_action=True, |
| default_state={}, |
| default_entrypoint="foo", |
| fork_from_sequence_id=1, |
| fork_from_partition_key="foo-bar", |
| ) |
| .build() |
| ) |
| |
| |
| class DummyPersister(BaseStatePersister): |
| """Dummy persistor.""" |
| |
| def load( |
| self, |
| partition_key: str, |
| app_id: Optional[str], |
| sequence_id: Optional[int] = None, |
| **kwargs, |
| ) -> Optional[PersistedStateData]: |
| return PersistedStateData( |
| partition_key="user123", |
| app_id="123", |
| sequence_id=5, |
| position="counter", |
| state=State({"count": 5}), |
| created_at="", |
| status="completed", |
| ) |
| |
| def list_app_ids(self, partition_key: str, **kwargs) -> list[str]: |
| return ["123"] |
| |
| def save( |
| self, |
| partition_key: Optional[str], |
| app_id: str, |
| sequence_id: int, |
| position: str, |
| state: State, |
| status: Literal["completed", "failed"], |
| **kwargs, |
| ): |
| return |
| |
| |
| def test_application_builder_initialize_fork_errors_on_same_app_id(): |
| """Tests that we can't have an app_id and fork_from_app_id that's the same""" |
| with pytest.raises(ValueError, match="Cannot fork and save"): |
| counter_action = base_counter_action.with_name("counter") |
| result_action = Result("count").with_name("result") |
| ( |
| ApplicationBuilder() |
| .with_actions(counter_action, result_action) |
| .with_transitions(("counter", "result", default)) |
| .initialize_from( |
| DummyPersister(), |
| resume_at_next_action=True, |
| default_state={}, |
| default_entrypoint="foo", |
| fork_from_app_id="123", |
| fork_from_partition_key="user123", |
| ) |
| .with_identifiers(app_id="123") |
| .build() |
| ) |
| |
| |
| def test_application_builder_initialize_fork_app_id_happy_pth(): |
| """Tests that forking properly works""" |
| counter_action = base_counter_action.with_name("counter") |
| result_action = Result("count").with_name("result") |
| old_app_id = "123" |
| app = ( |
| ApplicationBuilder() |
| .with_actions(counter_action, result_action) |
| .with_transitions(("counter", "result", default)) |
| .initialize_from( |
| DummyPersister(), |
| resume_at_next_action=True, |
| default_state={}, |
| default_entrypoint="counter", |
| fork_from_app_id=old_app_id, |
| fork_from_partition_key="user123", |
| ) |
| .with_identifiers(app_id="test123") |
| .build() |
| ) |
| assert app.uid != old_app_id |
| assert app.state == State({"count": 5, "__PRIOR_STEP": "counter", "__SEQUENCE_ID": 5}) |
| assert app.parent_pointer.app_id == old_app_id |
| |
| |
| class NoOpTracker(SyncTrackingClient): |
| def copy(self): |
| pass |
| |
| def pre_start_stream( |
| self, |
| *, |
| action: str, |
| sequence_id: int, |
| app_id: str, |
| partition_key: Optional[str], |
| **future_kwargs: Any, |
| ): |
| pass |
| |
| def post_stream_item( |
| self, |
| *, |
| item: Any, |
| item_index: int, |
| stream_initialize_time: datetime.datetime, |
| first_stream_item_start_time: datetime.datetime, |
| action: str, |
| sequence_id: int, |
| app_id: str, |
| partition_key: Optional[str], |
| **future_kwargs: Any, |
| ): |
| pass |
| |
| def post_end_stream( |
| self, |
| *, |
| action: str, |
| sequence_id: int, |
| app_id: str, |
| partition_key: Optional[str], |
| **future_kwargs: Any, |
| ): |
| pass |
| |
| def do_log_attributes(self, **future_kwargs: Any): |
| pass |
| |
| def __init__(self, unique_id: str): |
| self.unique_id = unique_id |
| |
| def post_application_create(self, **future_kwargs: Any): |
| pass |
| |
| def pre_run_step(self, **future_kwargs: Any): |
| pass |
| |
| def post_run_step(self, **future_kwargs: Any): |
| pass |
| |
| def pre_start_span(self, **future_kwargs: Any): |
| pass |
| |
| def post_end_span(self, **future_kwargs: Any): |
| pass |
| |
| |
| def test_application_exposes_app_context(): |
| """Tests that we can get the context from the application correctly""" |
| counter_action = base_counter_action.with_name("counter") |
| result_action = Result("count").with_name("result") |
| app = ( |
| ApplicationBuilder() |
| .with_actions(counter_action, result_action) |
| .with_transitions(("counter", "result", default)) |
| .with_tracker(NoOpTracker("unique_tracker_name")) |
| .with_identifiers(app_id="test123", partition_key="user123", sequence_id=5) |
| .with_entrypoint("counter") |
| .with_state(count=0) |
| .build() |
| ) |
| context = app.context |
| assert context.app_id == "test123" |
| assert context.partition_key == "user123" |
| assert context.sequence_id == 5 |
| assert context.tracker.unique_id == "unique_tracker_name" |
| |
| |
| def test_application_exposes_app_context_through_context_manager_sync(): |
| """Tests that we can get the context from the application correctly""" |
| |
| @action(reads=["count"], writes=["count"]) |
| def counter(state: State) -> State: |
| app_context = ApplicationContext.get() |
| assert app_context is not None |
| assert app_context.tracker is not None # NoOpTracker is used |
| assert isinstance(app_context.sequence_id, int) |
| assert app_context.app_id == "test123" |
| return state.update(count=state["count"] + 1) |
| |
| result_action = Result("count").with_name("result") |
| app = ( |
| ApplicationBuilder() |
| .with_actions(result_action, counter=counter) |
| .with_transitions(("counter", "result", default)) |
| .with_tracker(NoOpTracker("unique_tracker_name")) |
| .with_identifiers(app_id="test123", partition_key="user123", sequence_id=5) |
| .with_entrypoint("counter") |
| .with_state(count=0) |
| .build() |
| ) |
| app.run(halt_after=["result"]) |
| |
| |
| async def test_application_exposes_app_context_through_context_manager_async(): |
| """Tests that we can get the context from the application correctly""" |
| |
| @action(reads=["count"], writes=["count"]) |
| async def counter(state: State) -> State: |
| app_context = ApplicationContext.get() |
| assert app_context is not None |
| assert app_context.tracker is not None # NoOpTracker is used |
| assert isinstance(app_context.sequence_id, int) |
| assert app_context.app_id == "test123" |
| return state.update(count=state["count"] + 1) |
| |
| result_action = Result("count").with_name("result") |
| app = ( |
| ApplicationBuilder() |
| .with_actions(result_action, counter=counter) |
| .with_transitions(("counter", "result", default)) |
| .with_tracker(NoOpTracker("unique_tracker_name")) |
| .with_identifiers(app_id="test123", partition_key="user123", sequence_id=5) |
| .with_entrypoint("counter") |
| .with_state(count=0) |
| .build() |
| ) |
| await app.arun(halt_after=["result"]) |
| |
| |
| def test_application_passes_context_when_declared(): |
| """Tests that the context is passed to the function correctly""" |
| context_list = [] |
| |
| @action(reads=["count"], writes=["count"]) |
| def context_counter(state: State, __context: ApplicationContext) -> State: |
| context_list.append(__context) |
| return state.update(count=state["count"] + 1) |
| |
| result_action = Result("count") |
| app = ( |
| ApplicationBuilder() |
| .with_actions(counter=context_counter, result=result_action) |
| .with_transitions(("counter", "counter", expr("count < 10")), ("counter", "result")) |
| .with_tracker(NoOpTracker("unique_tracker_name")) |
| .with_identifiers(app_id="test123", partition_key="user123", sequence_id=5) |
| .with_entrypoint("counter") |
| .with_state(count=0) |
| .build() |
| ) |
| app.run(halt_after=["result"]) |
| sequence_ids = [context.sequence_id for context in context_list] |
| assert sequence_ids == list(range(6, 16)) |
| app_ids = set(context.app_id for context in context_list) |
| assert app_ids == {"test123"} |
| trackers = set(context.tracker.unique_id for context in context_list) |
| assert trackers == {"unique_tracker_name"} |
| |
| |
| def test_optional_context_in_dependency_factories(): |
| """Tests that the context is passed to the function correctly when nulled out. |
| TODO -- get this to test without instantiating an application through the builder -- |
| this is slightly overkill for a bit of code""" |
| context_list = [] |
| |
| @action(reads=["count"], writes=["count"]) |
| def context_counter(state: State, __context: ApplicationContext = None) -> State: |
| context_list.append(__context) |
| return state.update(count=state["count"] + 1) |
| |
| result_action = Result("count") |
| app = ( |
| ApplicationBuilder() |
| .with_actions(counter=context_counter, result=result_action) |
| .with_transitions(("counter", "counter", expr("count < 10")), ("counter", "result")) |
| .with_tracker(NoOpTracker("unique_tracker_name")) |
| .with_identifiers(app_id="test123", partition_key="user123", sequence_id=5) |
| .with_entrypoint("counter") |
| .with_state(count=0) |
| .build() |
| ) |
| inputs = app._process_inputs({}, app.get_next_action()) |
| assert "__context" in inputs # it should be there |
| assert inputs["__context"] is not None # it should not be None |
| assert inputs["__context"].app_id == "test123" # it should be the correct context |
| assert ( |
| inputs["__context"].tracker.unique_id == "unique_tracker_name" |
| ) # it should be the correct context |
| |
| |
| def test_application_with_no_spawning_parent(): |
| """Test that the application does not have a spawning parent when it is not specified""" |
| counter_action = base_counter_action.with_name("counter") |
| result_action = Result("count").with_name("result") |
| app = ( |
| ApplicationBuilder() |
| .with_actions(counter_action, result_action) |
| .with_transitions(("counter", "result", default)) |
| .with_identifiers(app_id="test123", partition_key="user123", sequence_id=5) |
| .with_entrypoint("counter") |
| .with_state(count=0) |
| .build() |
| ) |
| spawned_by = app.spawning_parent_pointer |
| assert spawned_by is None |
| |
| |
| def test_application_with_spawning_parent(): |
| """Tests that the application builder can specify a spawning |
| parent and it gets wired through to the app.""" |
| counter_action = base_counter_action.with_name("counter") |
| result_action = Result("count").with_name("result") |
| app = ( |
| ApplicationBuilder() |
| .with_actions(counter_action, result_action) |
| .with_transitions(("counter", "result", default)) |
| .with_identifiers(app_id="test123", partition_key="user123") |
| .with_spawning_parent(app_id="test123", partition_key="user123", sequence_id=5) |
| .with_entrypoint("counter") |
| .with_state(count=0) |
| .build() |
| ) |
| spawned_by = app.spawning_parent_pointer |
| assert spawned_by is not None |
| assert spawned_by.app_id == "test123" |
| assert spawned_by.partition_key == "user123" |
| assert spawned_by.sequence_id == 5 |
| |
| |
| def test_application_does_not_allow_dunderscore_inputs(): |
| """Tests that the context is passed to the function correctly when nulled out. |
| TODO -- get this to test without instantiating an application through the builder -- |
| this is slightly overkill for a bit of code""" |
| result_action = Result("count").with_name("result") |
| app = ( |
| ApplicationBuilder() |
| .with_actions(result_action) |
| .with_entrypoint("result") |
| .with_transitions() |
| .with_state(count=0) |
| .build() |
| ) |
| with pytest.raises(ValueError, match="double underscore"): |
| app._process_inputs({"__not_allowed": ...}, app.get_next_action()) |
| |
| |
| def test_application_recursive_action_lifecycle_hooks(): |
| """Tests that calling burr within burr works as expected""" |
| |
| class TestingHook(PreApplicationExecuteCallHook, PostApplicationExecuteCallHook): |
| def __init__(self): |
| self.pre_called = [] |
| self.post_called = [] |
| |
| def pre_run_execute_call( |
| self, |
| *, |
| app_id: str, |
| partition_key: str, |
| state: "State", |
| method: ExecuteMethod, |
| **future_kwargs: Any, |
| ): |
| self.pre_called.append((app_id, partition_key, state, method)) |
| |
| def post_run_execute_call( |
| self, |
| *, |
| app_id: str, |
| partition_key: str, |
| state: "State", |
| method: ExecuteMethod, |
| exception: Optional[Exception], |
| **future_kwargs, |
| ): |
| self.post_called.append((app_id, partition_key, state, method, exception)) |
| |
| hook = TestingHook() |
| foo = [] |
| |
| @action( |
| reads=["recursion_count", "total_count"], |
| writes=["recursion_count", "total_count"], |
| ) |
| def recursive_action(state: State) -> State: |
| foo.append(1) |
| recursion_count = state["recursion_count"] |
| if recursion_count == 5: |
| return state |
| # Fork bomb! |
| total_counts = [] |
| for i in range(2): |
| app = ( |
| ApplicationBuilder() |
| .with_graph(graph) |
| .with_hooks(hook) |
| .with_entrypoint("recursive_action") |
| .with_state(recursion_count=recursion_count + 1, total_count=1) |
| .build() |
| ) |
| action, result, state = app.run(halt_after=["recursive_action"]) |
| total_counts.append(state["total_count"]) |
| return state.update( |
| recursion_count=state["recursion_count"], total_count=sum(total_counts) + 1 |
| ) |
| |
| graph = GraphBuilder().with_actions(recursive_action).with_transitions().build() |
| |
| # initial to kick it off |
| result = recursive_action(State({"recursion_count": 0, "total_count": 0})) |
| # Basic sanity checks to demonstrate |
| assert result["recursion_count"] == 5 |
| assert result["total_count"] == 63 # One for each of the calls (sum(2**n for n in range(6))) |
| assert ( |
| len(hook.pre_called) == 62 |
| ) # 63 - the initial one from the call to recursive_action outside the application |
| assert len(hook.post_called) == 62 # ditto |
| |
| |
| class CounterState(State): |
| count: int |
| |
| |
| class SimpleTypingSystem(TypingSystem[CounterState]): |
| def state_type(self) -> type[CounterState]: |
| return CounterState |
| |
| def state_pre_action_run_type(self, action: Action, graph: Graph) -> type[Any]: |
| raise NotImplementedError |
| |
| def state_post_action_run_type(self, action: Action, graph: Graph) -> type[Any]: |
| raise NotImplementedError |
| |
| def construct_data(self, state: State[Any]) -> CounterState: |
| return CounterState({"count": state["count"]}) |
| |
| def construct_state(self, data: Any) -> State[Any]: |
| raise NotImplementedError |
| |
| |
| def test_builder_captures_typing_system(): |
| """Tests that the typing system is captured correctly""" |
| counter_action = base_counter_action.with_name("counter") |
| result_action = Result("count").with_name("result") |
| app = ( |
| ApplicationBuilder() |
| .with_actions(counter_action, result_action) |
| .with_transitions(("counter", "counter", expr("count < 10"))) |
| .with_transitions(("counter", "result", default)) |
| .with_entrypoint("counter") |
| .with_state(count=0) |
| .with_typing(SimpleTypingSystem()) |
| .build() |
| ) |
| assert isinstance(app.state.data, CounterState) |
| _, _, state = app.run(halt_after=["result"]) |
| assert isinstance(state.data, CounterState) |
| assert state.data["count"] == 10 |
| |
| |
| def test_set_sync_state_persister_cannot_have_async_error(): |
| builder = ApplicationBuilder() |
| persister = AsyncDevNullPersister() |
| builder.with_state_persister(persister) |
| with pytest.raises( |
| ValueError, |
| match="are building the sync application, but have used an async persister.", |
| ): |
| # we have not initialized |
| builder._set_sync_state_persister() |
| |
| |
| def test_set_sync_state_persister_is_not_initialized_error(tmp_path): |
| builder = ApplicationBuilder() |
| persister = SQLLitePersister(db_path=":memory:", table_name="test_table") |
| builder.with_state_persister(persister) |
| with pytest.raises(RuntimeError): |
| # we have not initialized |
| builder._set_sync_state_persister() |
| |
| |
| async def test_set_async_state_persister_cannot_have_sync_error(): |
| await asyncio.sleep(0.00001) |
| builder = ApplicationBuilder() |
| persister = DevNullPersister() |
| builder.with_state_persister(persister) |
| with pytest.raises( |
| ValueError, |
| match="are building the async application, but have used an sync persister.", |
| ): |
| # we have not initialized |
| await builder._set_async_state_persister() |
| |
| |
| async def test_set_async_state_persister_is_not_initialized_error(tmp_path): |
| await asyncio.sleep(0.00001) |
| builder = ApplicationBuilder() |
| |
| class FakePersister(AsyncDevNullPersister): |
| async def is_initialized(self): |
| return False |
| |
| persister = FakePersister() |
| builder.with_state_persister(persister) |
| with pytest.raises(RuntimeError): |
| # we have not initialized |
| await builder._set_async_state_persister() |
| |
| |
| def test_with_state_persister_is_initialized_not_implemented(): |
| builder = ApplicationBuilder() |
| |
| class FakePersister(BaseStatePersister): |
| # does not implement is_initialized |
| def list_app_ids(self): |
| return [] |
| |
| def save( |
| self, |
| partition_key: Optional[str], |
| app_id: str, |
| sequence_id: int, |
| position: str, |
| state: State, |
| status: Literal["completed", "failed"], |
| **kwargs, |
| ): |
| pass |
| |
| def load( |
| self, |
| partition_key: str, |
| app_id: Optional[str], |
| sequence_id: Optional[int] = None, |
| **kwargs, |
| ): |
| return None |
| |
| persister = FakePersister() |
| # Add the persister to the builder, expecting no exceptions |
| builder.with_state_persister(persister) |
| |
| |
| class ActionWithoutContext(Action): |
| def run(self, other_param, foo): |
| pass |
| |
| @property |
| def reads(self) -> list[str]: |
| pass |
| |
| @property |
| def writes(self) -> list[str]: |
| pass |
| |
| def update(self, result: dict, state: State) -> State: |
| pass |
| |
| def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]: |
| return ["other_param", "foo"] |
| |
| |
| class ActionWithContext(ActionWithoutContext): |
| def run(self, __context, other_param, foo): |
| pass |
| |
| def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]: |
| return ["other_param", "foo", "__context"] |
| |
| |
| class ActionWithKwargs(ActionWithoutContext): |
| def run(self, other_param, foo, **kwargs): |
| pass |
| |
| def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]: |
| return ["other_param", "foo", "__context"] |
| |
| |
| class ActionWithContextTracer(ActionWithoutContext): |
| def run(self, __context, other_param, foo, __tracer): |
| pass |
| |
| def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]: |
| return ["other_param", "foo", "__context", "__tracer"] |
| |
| |
| def test_remap_context_variable_with_mangled_context_kwargs(): |
| _action = ActionWithKwargs() |
| |
| inputs = { |
| "__context": "context_value", |
| "other_key": "other_value", |
| "foo": "foo_value", |
| } |
| expected = { |
| "__context": "context_value", |
| "other_key": "other_value", |
| "foo": "foo_value", |
| } |
| assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected |
| |
| |
| def test_remap_context_variable_with_mangled_context(): |
| _action = ActionWithContext() |
| |
| inputs = { |
| "__context": "context_value", |
| "other_key": "other_value", |
| "foo": "foo_value", |
| } |
| expected = { |
| f"_{ActionWithContext.__name__}__context": "context_value", |
| "other_key": "other_value", |
| "foo": "foo_value", |
| } |
| assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected |
| |
| |
| def test_remap_context_variable_with_mangled_contexttracer(): |
| _action = ActionWithContextTracer() |
| |
| inputs = { |
| "__context": "context_value", |
| "__tracer": "tracer_value", |
| "other_key": "other_value", |
| "foo": "foo_value", |
| } |
| expected = { |
| f"_{ActionWithContextTracer.__name__}__context": "context_value", |
| "other_key": "other_value", |
| "foo": "foo_value", |
| f"_{ActionWithContextTracer.__name__}__tracer": "tracer_value", |
| } |
| assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected |
| |
| |
| def test_remap_context_variable_without_mangled_context(): |
| _action = ActionWithoutContext() |
| inputs = { |
| "__context": "context_value", |
| "other_key": "other_value", |
| "foo": "foo_value", |
| } |
| expected = { |
| "__context": "context_value", |
| "other_key": "other_value", |
| "foo": "foo_value", |
| } |
| assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected |
| |
| |
| async def test_async_application_builder_initialize_raises_on_broken_persistor(): |
| """Persisters should return None when there is no state to be loaded and the default used.""" |
| await asyncio.sleep(0.00001) |
| counter_action = base_counter_action_async.with_name("counter") |
| result_action = Result("count").with_name("result") |
| |
| class AsyncBrokenPersister(AsyncDevNullPersister): |
| async def load( |
| self, |
| partition_key: str, |
| app_id: Optional[str], |
| sequence_id: Optional[int] = None, |
| **kwargs, |
| ) -> Optional[PersistedStateData]: |
| await asyncio.sleep(0.0001) |
| return dict( |
| partition_key="key", |
| app_id="id", |
| sequence_id=0, |
| position="foo", |
| state=None, |
| created_at="", |
| status="completed", |
| ) |
| |
| with pytest.raises(ValueError, match="but value for state was None"): |
| await ( |
| ApplicationBuilder() |
| .with_actions(counter_action, result_action) |
| .with_transitions(("counter", "result", default)) |
| .initialize_from( |
| AsyncBrokenPersister(), |
| resume_at_next_action=True, |
| default_state={}, |
| default_entrypoint="foo", |
| ) |
| .abuild() |
| ) |
| |
| |
| def test_application__process_control_flow_params(): |
| @action(reads=[], writes=[], tags=["tag1", "tag2"]) |
| def test_action(state: State) -> State: |
| return state |
| |
| @action(reads=[], writes=[], tags=["tag1", "tag3"]) |
| def test_action_2(state: State) -> State: |
| return state |
| |
| app = ( |
| ApplicationBuilder() |
| .with_state(count=0) |
| .with_actions(test_action, test_action_2) |
| .with_transitions(("test_action", "test_action_2")) |
| .with_entrypoint("test_action") |
| .build() |
| ) |
| halt_before, halt_after, inputs = app._process_control_flow_params( |
| halt_after=["@tag:tag1"], halt_before=["@tag:tag2"] |
| ) |
| |
| assert sorted(halt_after) == ["test_action", "test_action_2"] |
| assert halt_before == ["test_action"] |
| assert inputs == {} |
| |
| |
| def test_initialize_from_applies_override_state_values(): |
| class FakeStateLoader(BaseStateLoader): |
| def load(self, partition_key, app_id, sequence_id): |
| return { |
| "state": State({"x": 1}), |
| "position": None, |
| "sequence_id": 0, |
| "status": "completed", |
| } |
| |
| def list_app_ids(self, partition_key): |
| return [] |
| |
| @action(reads=[], writes=[]) |
| def noop(state: State) -> State: |
| return state |
| |
| builder = ( |
| ApplicationBuilder() |
| .initialize_from( |
| initializer=FakeStateLoader(), |
| resume_at_next_action=False, |
| default_state={}, |
| default_entrypoint="noop", |
| override_state_values={"x": 100}, |
| ) |
| .with_actions(noop) |
| .with_transitions() |
| ) |
| |
| app = builder.build() |
| |
| assert app.state["x"] == 100 |