| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import asyncio |
| import concurrent.futures |
| import dataclasses |
| import datetime |
| from random import random |
| from typing import Any, AsyncGenerator, Callable, Dict, Generator, List, Literal, Optional, Union |
| |
| import pytest |
| |
| from burr.common import types as burr_types |
| from burr.core import ( |
| Action, |
| ApplicationBuilder, |
| ApplicationContext, |
| ApplicationGraph, |
| State, |
| action, |
| ) |
| from burr.core.action import Input, Result |
| from burr.core.graph import GraphBuilder |
| from burr.core.parallelism import ( |
| MapActions, |
| MapActionsAndStates, |
| MapStates, |
| RunnableGraph, |
| SubGraphTask, |
| TaskBasedParallelAction, |
| _cascade_adapter, |
| map_reduce_action, |
| ) |
| from burr.core.persistence import BaseStateLoader, BaseStateSaver, PersistedStateData |
| from burr.tracking.base import SyncTrackingClient |
| from burr.visibility import ActionSpan |
| |
| old_action = action |
| |
| |
| async def sleep_random(): |
| await asyncio.sleep(random() / 1000) |
| |
| |
| # Single action/callable subgraph |
| @action(reads=["input_number", "number_to_add"], writes=["output_number"]) |
| def simple_single_fn_subgraph( |
| state: State, additional_number: int = 1, identifying_number: int = 1000 |
| ) -> State: |
| return state.update( |
| output_number=state["input_number"] |
| + state["number_to_add"] |
| + additional_number |
| + identifying_number |
| ) |
| |
| |
| # Single action/callable subgraph |
| @action(reads=["input_number", "number_to_add"], writes=["output_number"]) |
| async def simple_single_fn_subgraph_async( |
| state: State, additional_number: int = 1, identifying_number: int = 1000 |
| ) -> State: |
| await sleep_random() |
| return state.update( |
| output_number=state["input_number"] |
| + state["number_to_add"] |
| + additional_number |
| + identifying_number |
| ) |
| |
| |
| class ClassBasedAction(Action): |
| def __init__(self, identifying_number: int, name: str = "class_based_action"): |
| super().__init__() |
| self._name = name |
| self.identifying_number = identifying_number |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["input_number", "number_to_add"] |
| |
| def run(self, state: State, **run_kwargs) -> dict: |
| return { |
| "output_number": state["input_number"] |
| + state["number_to_add"] |
| + run_kwargs.get("additional_number", 1) |
| + self.identifying_number |
| } |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["output_number"] |
| |
| def update(self, result: dict, state: State) -> State: |
| return state.update(**result) |
| |
| |
| class ClassBasedActionAsync(ClassBasedAction): |
| async def run(self, state: State, **run_kwargs) -> dict: |
| await sleep_random() |
| return super().run(state, **run_kwargs) |
| |
| |
| @action(reads=["input_number"], writes=["current_number"]) |
| def entry_action_for_subgraph(state: State) -> State: |
| return state.update(current_number=state["input_number"]) |
| |
| |
| @action(reads=["current_number", "number_to_add"], writes=["current_number"]) |
| def add_number_to_add(state: State) -> State: |
| return state.update(current_number=state["current_number"] + state["number_to_add"]) |
| |
| |
| @action(reads=["current_number"], writes=["current_number"]) |
| def add_additional_number_to_add( |
| state: State, additional_number: int = 1, identifying_number: int = 3000 |
| ) -> State: |
| return state.update( |
| current_number=state["current_number"] + additional_number + identifying_number |
| ) # 1000 is the one that marks this as different |
| |
| |
| @action(reads=["current_number"], writes=["output_number"]) |
| def final_result(state: State) -> State: |
| return state.update(output_number=state["current_number"]) |
| |
| |
| @action(reads=["input_number"], writes=["current_number"]) |
| async def entry_action_for_subgraph_async(state: State) -> State: |
| await sleep_random() |
| return entry_action_for_subgraph(state) |
| |
| |
| @action(reads=["current_number", "number_to_add"], writes=["current_number"]) |
| async def add_number_to_add_async(state: State) -> State: |
| await sleep_random() |
| return add_number_to_add(state) |
| |
| |
| @action(reads=["current_number"], writes=["current_number"]) |
| async def add_additional_number_to_add_async( |
| state: State, additional_number: int = 1, identifying_number: int = 3000 |
| ) -> State: |
| await sleep_random() |
| return add_additional_number_to_add( |
| state, additional_number=additional_number, identifying_number=identifying_number |
| ) # 1000 is the one that marks this as different |
| |
| |
| @action(reads=["current_number"], writes=["output_number"]) |
| async def final_result_async(state: State) -> State: |
| await sleep_random() |
| return final_result(state) |
| |
| |
| SubGraphType = Union[Action, Callable, RunnableGraph] |
| |
| |
| def create_full_subgraph(identifying_number: int = 0) -> SubGraphType: |
| return RunnableGraph( |
| graph=( |
| GraphBuilder() |
| .with_actions( |
| entry_action_for_subgraph, |
| add_number_to_add, |
| add_additional_number_to_add.bind(identifying_number=identifying_number), |
| final_result, |
| ) |
| .with_transitions( |
| ("entry_action_for_subgraph", "add_number_to_add"), |
| ("add_number_to_add", "add_additional_number_to_add"), |
| ("add_additional_number_to_add", "final_result"), |
| ) |
| .build() |
| ), |
| entrypoint="entry_action_for_subgraph", |
| halt_after=["final_result"], |
| ) |
| |
| |
| def create_full_subgraph_async(identifying_number: int = 0) -> SubGraphType: |
| return RunnableGraph( |
| graph=GraphBuilder() |
| .with_actions( |
| entry_action_for_subgraph=entry_action_for_subgraph_async, |
| add_number_to_add=add_number_to_add_async, |
| add_additional_number_to_add=add_additional_number_to_add_async.bind( |
| identifying_number=identifying_number |
| ), |
| final_result=final_result_async, |
| ) |
| .with_transitions( |
| ("entry_action_for_subgraph", "add_number_to_add"), |
| ("add_number_to_add", "add_additional_number_to_add"), |
| ("add_additional_number_to_add", "final_result"), |
| ) |
| .build(), |
| entrypoint="entry_action_for_subgraph", |
| halt_after=["final_result"], |
| ) |
| |
| |
| FULL_SUBGRAPH: SubGraphType = create_full_subgraph(identifying_number=3000) |
| FULL_SUBGRAPH_ASYNC: SubGraphType = create_full_subgraph_async(identifying_number=3000) |
| |
| |
| @dataclasses.dataclass |
| class RecursiveActionTracked: |
| state_before: Optional[State] |
| state_after: Optional[State] |
| action: Action |
| app_id: str |
| partition_key: str |
| sequence_id: int |
| children: List["RecursiveActionTracked"] = dataclasses.field(default_factory=list) |
| |
| |
| class RecursiveActionTracker(SyncTrackingClient): |
| """Simple test tracking client for a recursive action""" |
| |
| def __init__( |
| self, |
| events: List[RecursiveActionTracked], |
| parent: Optional["RecursiveActionTracker"] = None, |
| ): |
| self.events = events |
| self.parent = parent |
| |
| def copy(self): |
| """Quick way to copy from the current state. This assumes linearity (which is true in this case, as parallelism is delegated)""" |
| if self.events: |
| current_event = self.events[-1] |
| if current_event.state_after is not None: |
| raise ValueError("Don't copy if you're not in the middle of an event") |
| return RecursiveActionTracker(current_event.children, parent=self) |
| raise ValueError("Don't copy if you're not in the middle of an event") |
| |
| def post_application_create( |
| self, |
| *, |
| app_id: str, |
| partition_key: Optional[str], |
| state: "State", |
| application_graph: "ApplicationGraph", |
| parent_pointer: Optional[burr_types.ParentPointer], |
| spawning_parent_pointer: Optional[burr_types.ParentPointer], |
| **future_kwargs: Any, |
| ): |
| pass |
| |
| def pre_run_step( |
| self, |
| *, |
| app_id: str, |
| partition_key: str, |
| sequence_id: int, |
| state: "State", |
| action: "Action", |
| inputs: Dict[str, Any], |
| **future_kwargs: Any, |
| ): |
| self.events.append( |
| RecursiveActionTracked( |
| state_before=state, |
| state_after=None, |
| action=action, |
| app_id=app_id, |
| partition_key=partition_key, |
| sequence_id=sequence_id, |
| ) |
| ) |
| |
| def post_run_step( |
| self, |
| *, |
| app_id: str, |
| partition_key: str, |
| sequence_id: int, |
| state: "State", |
| action: "Action", |
| result: Optional[Dict[str, Any]], |
| exception: Exception, |
| **future_kwargs: Any, |
| ): |
| self.events[-1].state_after = state |
| |
| def pre_start_span( |
| self, |
| *, |
| action: str, |
| action_sequence_id: int, |
| span: "ActionSpan", |
| span_dependencies: list[str], |
| app_id: str, |
| partition_key: Optional[str], |
| **future_kwargs: Any, |
| ): |
| pass |
| |
| def post_end_span( |
| self, |
| *, |
| action: str, |
| action_sequence_id: int, |
| span: "ActionSpan", |
| span_dependencies: list[str], |
| app_id: str, |
| partition_key: Optional[str], |
| **future_kwargs: Any, |
| ): |
| pass |
| |
| def do_log_attributes( |
| self, |
| *, |
| attributes: Dict[str, Any], |
| action: str, |
| action_sequence_id: int, |
| span: Optional["ActionSpan"], |
| tags: dict, |
| app_id: str, |
| partition_key: Optional[str], |
| **future_kwargs: Any, |
| ): |
| pass |
| |
| def pre_start_stream( |
| self, |
| *, |
| action: str, |
| sequence_id: int, |
| app_id: str, |
| partition_key: Optional[str], |
| **future_kwargs: Any, |
| ): |
| pass |
| |
| def post_stream_item( |
| self, |
| *, |
| item: Any, |
| item_index: int, |
| stream_initialize_time: datetime.datetime, |
| first_stream_item_start_time: datetime.datetime, |
| action: str, |
| sequence_id: int, |
| app_id: str, |
| partition_key: Optional[str], |
| **future_kwargs: Any, |
| ): |
| pass |
| |
| def post_end_stream( |
| self, |
| *, |
| action: str, |
| sequence_id: int, |
| app_id: str, |
| partition_key: Optional[str], |
| **future_kwargs: Any, |
| ): |
| pass |
| |
| |
| def _group_events_by_app_id( |
| events: List[RecursiveActionTracked], |
| ) -> Dict[str, List[RecursiveActionTracked]]: |
| grouped_events = {} |
| for event in events: |
| if event.app_id not in grouped_events: |
| grouped_events[event.app_id] = [] |
| grouped_events[event.app_id].append(event) |
| return grouped_events |
| |
| |
| def test_map_actions_default_state(): |
| class MapActionsAllApproaches(MapActions): |
| def actions( |
| self, state: State, inputs: Dict[str, Any], context: ApplicationContext |
| ) -> Generator[Union[Action, Callable, RunnableGraph], None, None]: |
| ... |
| |
| def reduce(self, state: State, states: Generator[State, None, None]) -> State: |
| ... |
| |
| @property |
| def writes(self) -> list[str]: |
| return [] |
| |
| @property |
| def reads(self) -> list[str]: |
| return [] |
| |
| state_to_test = State({"foo": "bar", "baz": "qux"}) |
| assert MapActionsAllApproaches().state(state_to_test, {}).get_all() == state_to_test.get_all() |
| |
| |
| def test_e2e_map_actions_sync_subgraph(): |
| """Tests map actions over multiple action types (runnable graph, function, action class...)""" |
| |
| class MapActionsAllApproaches(MapActions): |
| def actions( |
| self, state: State, inputs: Dict[str, Any], context: ApplicationContext |
| ) -> Generator[Union[Action, Callable, RunnableGraph], None, None]: |
| for graph_ in [ |
| simple_single_fn_subgraph.bind(identifying_number=1000), |
| ClassBasedAction(2000), |
| create_full_subgraph(3000), |
| ]: |
| yield graph_ |
| |
| def state(self, state: State, inputs: Dict[str, Any]): |
| return state.update(input_number=state["input_number_in_state"], number_to_add=10) |
| |
| def reduce(self, state: State, states: Generator[State, None, None]) -> State: |
| # TODO -- ensure that states is in the correct order... |
| # Or decide to key it? |
| new_state = state |
| for output_state in states: |
| new_state = new_state.append(output_numbers_in_state=output_state["output_number"]) |
| return new_state |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["output_numbers_in_state"] |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["input_number_in_state"] |
| |
| app = ( |
| ApplicationBuilder() |
| .with_actions( |
| initial_action=Input("input_number_in_state"), |
| map_action=MapActionsAllApproaches(), |
| final_action=Result("output_numbers_in_state"), |
| ) |
| .with_transitions(("initial_action", "map_action"), ("map_action", "final_action")) |
| .with_entrypoint("initial_action") |
| .with_tracker(RecursiveActionTracker(events := [])) |
| .build() |
| ) |
| action, result, state = app.run( |
| halt_after=["final_action"], inputs={"input_number_in_state": 100} |
| ) |
| assert state["output_numbers_in_state"] == [1111, 2111, 3111] # esnsure order correct |
| assert len(events) == 3 # three parent actions |
| _, map_event, __ = events |
| grouped_events = _group_events_by_app_id(map_event.children) |
| assert len(grouped_events) == 3 # three unique App IDs, one for each launching subgraph |
| |
| |
| async def test_e2e_map_actions_async_subgraph(): |
| """Tests map actions over multiple action types (runnable graph, function, action class...)""" |
| |
| class MapActionsAllApproachesAsync(MapActions): |
| def actions( |
| self, state: State, inputs: Dict[str, Any], context: ApplicationContext |
| ) -> Generator[Union[Action, Callable, RunnableGraph], None, None]: |
| for graph_ in [ |
| simple_single_fn_subgraph_async.bind(identifying_number=1000), |
| ClassBasedActionAsync(2000), |
| create_full_subgraph_async(3000), |
| ]: |
| yield graph_ |
| |
| def is_async(self) -> bool: |
| return True |
| |
| def state(self, state: State, inputs: Dict[str, Any]): |
| return state.update(input_number=state["input_number_in_state"], number_to_add=10) |
| |
| async def reduce(self, state: State, states: AsyncGenerator[State, None]) -> State: |
| # TODO -- ensure that states is in the correct order... |
| # Or decide to key it? |
| new_state = state |
| async for output_state in states: |
| new_state = new_state.append(output_numbers_in_state=output_state["output_number"]) |
| return new_state |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["output_numbers_in_state"] |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["input_number_in_state"] |
| |
| app = ( |
| ApplicationBuilder() |
| .with_actions( |
| initial_action=Input("input_number_in_state"), |
| map_action=MapActionsAllApproachesAsync(), |
| final_action=Result("output_numbers_in_state"), |
| ) |
| .with_transitions(("initial_action", "map_action"), ("map_action", "final_action")) |
| .with_entrypoint("initial_action") |
| .with_tracker(RecursiveActionTracker(events := [])) |
| .build() |
| ) |
| action, result, state = await app.arun( |
| halt_after=["final_action"], inputs={"input_number_in_state": 100} |
| ) |
| assert state["output_numbers_in_state"] == [1111, 2111, 3111] # ensure order correct |
| assert len(events) == 3 # three parent actions |
| _, map_event, __ = events |
| grouped_events = _group_events_by_app_id(map_event.children) |
| assert len(grouped_events) == 3 # three unique App IDs, one for each launching subgraph |
| |
| |
| @pytest.mark.parametrize( |
| "action", |
| [ |
| simple_single_fn_subgraph.bind(identifying_number=0), |
| ClassBasedAction(0), |
| create_full_subgraph(0), |
| ], |
| ) |
| def test_e2e_map_states_sync_subgraph(action: SubGraphType): |
| """Tests the map states action with a subgraph that is run in parallel. |
| Collatz conjecture over different starting points""" |
| |
| class MapStatesSync(MapStates): |
| def states( |
| self, state: State, context: ApplicationContext, inputs: Dict[str, Any] |
| ) -> Generator[State, None, None]: |
| for input_number in state["input_numbers_in_state"]: |
| yield state.update(input_number=input_number, number_to_add=10) |
| |
| def action( |
| self, state: State, inputs: Dict[str, Any] |
| ) -> Union[Action, Callable, RunnableGraph]: |
| return action |
| |
| def is_async(self) -> bool: |
| return False |
| |
| def reduce(self, state: State, states: Generator[State, None, None]) -> State: |
| # TODO -- ensure that states is in the correct order... |
| # Or decide to key it? |
| new_state = state |
| for output_state in states: |
| new_state = new_state.append(output_numbers_in_state=output_state["output_number"]) |
| return new_state |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["output_numbers_in_state"] |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["input_numbers_in_state"] |
| |
| app = ( |
| ApplicationBuilder() |
| .with_actions( |
| initial_action=Input("input_numbers_in_state"), |
| map_action=MapStatesSync(), |
| final_action=Result("output_numbers_in_state"), |
| ) |
| .with_transitions(("initial_action", "map_action"), ("map_action", "final_action")) |
| .with_entrypoint("initial_action") |
| .with_tracker(RecursiveActionTracker(events := [])) |
| .build() |
| ) |
| action, result, state = app.run( |
| halt_after=["final_action"], inputs={"input_numbers_in_state": [100, 200, 300]} |
| ) |
| assert state["output_numbers_in_state"] == [111, 211, 311] # ensure order correct |
| assert len(events) == 3 |
| _, map_event, __ = events |
| grouped_events = _group_events_by_app_id(map_event.children) |
| assert len(grouped_events) == 3 |
| |
| |
| @pytest.mark.parametrize( |
| "action", |
| [ |
| simple_single_fn_subgraph_async.bind(identifying_number=0), |
| ClassBasedActionAsync(0), |
| create_full_subgraph_async(0), |
| ], |
| ) |
| async def test_e2e_map_states_async_subgraph(action: SubGraphType): |
| """Tests the map states action with a subgraph that is run in parallel. |
| Collatz conjecture over different starting points""" |
| |
| class MapStatesAsync(MapStates): |
| def states( |
| self, state: State, context: ApplicationContext, inputs: Dict[str, Any] |
| ) -> Generator[State, None, None]: |
| for input_number in state["input_numbers_in_state"]: |
| yield state.update(input_number=input_number, number_to_add=10) |
| |
| def action( |
| self, state: State, inputs: Dict[str, Any] |
| ) -> Union[Action, Callable, RunnableGraph]: |
| return action |
| |
| def is_async(self) -> bool: |
| return True |
| |
| async def reduce(self, state: State, states: AsyncGenerator[State, None]) -> State: |
| # TODO -- ensure that states is in the correct order... |
| # Or decide to key it? |
| new_state = state |
| async for output_state in states: |
| new_state = new_state.append(output_numbers_in_state=output_state["output_number"]) |
| return new_state |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["output_numbers_in_state"] |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["input_numbers_in_state"] |
| |
| app = ( |
| ApplicationBuilder() |
| .with_actions( |
| initial_action=Input("input_numbers_in_state"), |
| map_action=MapStatesAsync(), |
| final_action=Result("output_numbers_in_state"), |
| ) |
| .with_transitions(("initial_action", "map_action"), ("map_action", "final_action")) |
| .with_entrypoint("initial_action") |
| .with_tracker(RecursiveActionTracker(events := [])) |
| .build() |
| ) |
| action, result, state = await app.arun( |
| halt_after=["final_action"], inputs={"input_numbers_in_state": [100, 200, 300]} |
| ) |
| assert state["output_numbers_in_state"] == [111, 211, 311] # ensure order correct |
| assert len(events) == 3 |
| _, map_event, __ = events |
| grouped_events = _group_events_by_app_id(map_event.children) |
| assert len(grouped_events) == 3 |
| |
| |
| def test_e2e_map_actions_and_states_sync(): |
| """Tests the map states action with a subgraph that is run in parallel. |
| Collatz conjecture over different starting points""" |
| |
| class MapStatesSync(MapActionsAndStates): |
| def actions( |
| self, state: State, context: ApplicationContext, inputs: Dict[str, Any] |
| ) -> Generator[Union[Action, Callable, RunnableGraph], None, None]: |
| for graph_ in [ |
| simple_single_fn_subgraph.bind(identifying_number=1000), |
| ClassBasedAction(2000), |
| create_full_subgraph(3000), |
| ]: |
| yield graph_ |
| |
| def states( |
| self, state: State, context: ApplicationContext, inputs: Dict[str, Any] |
| ) -> Generator[State, None, None]: |
| for input_number in state["input_numbers_in_state"]: |
| yield state.update(input_number=input_number, number_to_add=10) |
| |
| def is_async(self) -> bool: |
| return False |
| |
| def reduce(self, state: State, states: Generator[State, None, None]) -> State: |
| # TODO -- ensure that states is in the correct order... |
| # Or decide to key it? |
| new_state = state |
| for output_state in states: |
| new_state = new_state.append(output_numbers_in_state=output_state["output_number"]) |
| return new_state |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["output_numbers_in_state"] |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["input_numbers_in_state"] |
| |
| app = ( |
| ApplicationBuilder() |
| .with_actions( |
| initial_action=Input("input_numbers_in_state"), |
| map_action=MapStatesSync(), |
| final_action=Result("output_numbers_in_state"), |
| ) |
| .with_transitions(("initial_action", "map_action"), ("map_action", "final_action")) |
| .with_entrypoint("initial_action") |
| .with_tracker(RecursiveActionTracker(events := [])) |
| .build() |
| ) |
| action, result, state = app.run( |
| halt_after=["final_action"], inputs={"input_numbers_in_state": [100, 200, 300]} |
| ) |
| assert state["output_numbers_in_state"] == [ |
| 1111, |
| 1211, |
| 1311, |
| 2111, |
| 2211, |
| 2311, |
| 3111, |
| 3211, |
| 3311, |
| ] |
| assert len(events) == 3 |
| _, map_event, __ = events |
| grouped_events = _group_events_by_app_id(map_event.children) |
| assert len(grouped_events) == 9 # cartesian product of 3 actions and 3 states |
| |
| |
| async def test_e2e_map_actions_and_states_async(): |
| """Tests the map states action with a subgraph that is run in parallel. |
| Collatz conjecture over different starting points""" |
| |
| class MapStatesAsync(MapActionsAndStates): |
| def actions( |
| self, state: State, context: ApplicationContext, inputs: Dict[str, Any] |
| ) -> Generator[Union[Action, Callable, RunnableGraph], None, None]: |
| for graph_ in [ |
| simple_single_fn_subgraph_async.bind(identifying_number=1000), |
| ClassBasedActionAsync(2000), |
| create_full_subgraph_async(3000), |
| ]: |
| yield graph_ |
| |
| def states( |
| self, state: State, context: ApplicationContext, inputs: Dict[str, Any] |
| ) -> AsyncGenerator[State, None]: |
| for input_number in state["input_numbers_in_state"]: |
| yield state.update(input_number=input_number, number_to_add=10) |
| |
| def is_async(self) -> bool: |
| return True |
| |
| async def reduce(self, state: State, states: AsyncGenerator[State, None]) -> State: |
| # TODO -- ensure that states is in the correct order... |
| # Or decide to key it? |
| new_state = state |
| async for output_state in states: |
| new_state = new_state.append(output_numbers_in_state=output_state["output_number"]) |
| return new_state |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["output_numbers_in_state"] |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["input_numbers_in_state"] |
| |
| app = ( |
| ApplicationBuilder() |
| .with_actions( |
| initial_action=Input("input_numbers_in_state"), |
| map_action=MapStatesAsync(), |
| final_action=Result("output_numbers_in_state"), |
| ) |
| .with_transitions(("initial_action", "map_action"), ("map_action", "final_action")) |
| .with_entrypoint("initial_action") |
| .with_tracker(RecursiveActionTracker(events := [])) |
| .build() |
| ) |
| action, result, state = await app.arun( |
| halt_after=["final_action"], inputs={"input_numbers_in_state": [100, 200, 300]} |
| ) |
| assert state["output_numbers_in_state"] == [ |
| 1111, |
| 1211, |
| 1311, |
| 2111, |
| 2211, |
| 2311, |
| 3111, |
| 3211, |
| 3311, |
| ] |
| assert len(events) == 3 |
| _, map_event, __ = events |
| grouped_events = _group_events_by_app_id(map_event.children) |
| assert len(grouped_events) == 9 # cartesian product of 3 actions and 3 states |
| |
| |
| def test_task_level_API_e2e_sync(): |
| """Tests the map states action with a subgraph that is run in parallel. |
| Collatz conjecture over different starting points""" |
| |
| class TaskBasedAction(TaskBasedParallelAction): |
| def tasks( |
| self, state: State, context: ApplicationContext, inputs: Dict[str, Any] |
| ) -> Generator[SubGraphTask, None, None]: |
| for j, action in enumerate( |
| [ |
| simple_single_fn_subgraph.bind(identifying_number=1000), |
| ClassBasedAction(2000), |
| create_full_subgraph(3000), |
| ] |
| ): |
| for i, input_number in enumerate(state["input_numbers_in_state"]): |
| yield SubGraphTask( |
| graph=RunnableGraph.create(action), |
| inputs={}, |
| state=state.update(input_number=input_number, number_to_add=10), |
| application_id=f"{i}_{j}", |
| tracker=context.tracker.copy(), |
| ) |
| |
| def reduce(self, state: State, states: Generator[State, None, None]) -> State: |
| # TODO -- ensure that states is in the correct order... |
| # Or decide to key it? |
| new_state = state |
| for output_state in states: |
| new_state = new_state.append(output_numbers_in_state=output_state["output_number"]) |
| return new_state |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["output_numbers_in_state"] |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["input_numbers_in_state"] |
| |
| app = ( |
| ApplicationBuilder() |
| .with_actions( |
| initial_action=Input("input_numbers_in_state"), |
| map_action=TaskBasedAction(), |
| final_action=Result("output_numbers_in_state"), |
| ) |
| .with_transitions(("initial_action", "map_action"), ("map_action", "final_action")) |
| .with_entrypoint("initial_action") |
| .with_tracker(RecursiveActionTracker(events := [])) |
| .build() |
| ) |
| action, result, state = app.run( |
| halt_after=["final_action"], inputs={"input_numbers_in_state": [100, 200, 300]} |
| ) |
| assert state["output_numbers_in_state"] == [ |
| 1111, |
| 1211, |
| 1311, |
| 2111, |
| 2211, |
| 2311, |
| 3111, |
| 3211, |
| 3311, |
| ] |
| assert len(events) == 3 |
| _, map_event, __ = events |
| grouped_events = _group_events_by_app_id(map_event.children) |
| assert len(grouped_events) == 9 # cartesian product of 3 actions and 3 states |
| |
| |
| async def test_task_level_API_e2e_async(): |
| """Tests the map states action with a subgraph that is run in parallel. |
| Collatz conjecture over different starting points""" |
| |
| class TaskBasedActionAsync(TaskBasedParallelAction): |
| async def tasks( |
| self, state: State, context: ApplicationContext, inputs: Dict[str, Any] |
| ) -> AsyncGenerator[SubGraphTask, None]: |
| for j, action in enumerate( |
| [ |
| simple_single_fn_subgraph.bind(identifying_number=1000), |
| ClassBasedAction(2000), |
| create_full_subgraph(3000), |
| ] |
| ): |
| for i, input_number in enumerate(state["input_numbers_in_state"]): |
| yield SubGraphTask( |
| graph=RunnableGraph.create(action), |
| inputs={}, |
| state=state.update(input_number=input_number, number_to_add=10), |
| application_id=f"{i}_{j}", |
| tracker=context.tracker.copy(), |
| ) |
| |
| async def reduce(self, state: State, states: AsyncGenerator[State, None]) -> State: |
| # TODO -- ensure that states is in the correct order... |
| # Or decide to key it? |
| new_state = state |
| async for output_state in states: |
| new_state = new_state.append(output_numbers_in_state=output_state["output_number"]) |
| return new_state |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["output_numbers_in_state"] |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["input_numbers_in_state"] |
| |
| def is_async(self) -> bool: |
| return True |
| |
| app = ( |
| ApplicationBuilder() |
| .with_actions( |
| initial_action=Input("input_numbers_in_state"), |
| map_action=TaskBasedActionAsync(), |
| final_action=Result("output_numbers_in_state"), |
| ) |
| .with_transitions(("initial_action", "map_action"), ("map_action", "final_action")) |
| .with_entrypoint("initial_action") |
| .with_tracker(RecursiveActionTracker(events := [])) |
| .build() |
| ) |
| action, result, state = await app.arun( |
| halt_after=["final_action"], inputs={"input_numbers_in_state": [100, 200, 300]} |
| ) |
| assert state["output_numbers_in_state"] == [ |
| 1111, |
| 1211, |
| 1311, |
| 2111, |
| 2211, |
| 2311, |
| 3111, |
| 3211, |
| 3311, |
| ] |
| assert len(events) == 3 |
| _, map_event, __ = events |
| grouped_events = _group_events_by_app_id(map_event.children) |
| assert len(grouped_events) == 9 # cartesian product of 3 actions and 3 states |
| |
| |
| def test_map_reduce_function_e2e(): |
| mre = map_reduce_action( |
| action=[ |
| simple_single_fn_subgraph.bind(identifying_number=1000), |
| ClassBasedAction(2000), |
| create_full_subgraph(3000), |
| ], |
| reads=["input_numbers_in_state"], |
| writes=["output_numbers_in_state"], |
| state=lambda state, context, inputs: ( |
| state.update(input_number=input_number, number_to_add=10) |
| for input_number in state["input_numbers_in_state"] |
| ), |
| inputs=[], |
| reducer=lambda state, states: state.extend( |
| output_numbers_in_state=[output_state["output_number"] for output_state in states] |
| ), |
| ) |
| |
| app = ( |
| ApplicationBuilder() |
| .with_actions( |
| initial_action=Input("input_numbers_in_state"), |
| map_action=mre, |
| final_action=Result("output_numbers_in_state"), |
| ) |
| .with_transitions(("initial_action", "map_action"), ("map_action", "final_action")) |
| .with_entrypoint("initial_action") |
| .with_tracker(RecursiveActionTracker(events := [])) |
| .build() |
| ) |
| action, result, state = app.run( |
| halt_after=["final_action"], inputs={"input_numbers_in_state": [100, 200, 300]} |
| ) |
| assert state["output_numbers_in_state"] == [ |
| 1111, |
| 1211, |
| 1311, |
| 2111, |
| 2211, |
| 2311, |
| 3111, |
| 3211, |
| 3311, |
| ] |
| assert len(events) == 3 |
| _, map_event, __ = events |
| grouped_events = _group_events_by_app_id(map_event.children) |
| assert len(grouped_events) == 9 # cartesian product of 3 actions and 3 states |
| |
| |
| class DummyTracker(SyncTrackingClient): |
| def __init__(self, parent: Optional["DummyTracker"] = None): |
| self.parent = parent |
| |
| def copy(self): |
| return DummyTracker(parent=self) |
| |
| def post_application_create( |
| self, |
| *, |
| app_id: str, |
| partition_key: Optional[str], |
| state: "State", |
| application_graph: "ApplicationGraph", |
| parent_pointer: Optional[burr_types.ParentPointer], |
| spawning_parent_pointer: Optional[burr_types.ParentPointer], |
| **future_kwargs: Any, |
| ): |
| pass |
| |
| def pre_run_step( |
| self, |
| *, |
| app_id: str, |
| partition_key: str, |
| sequence_id: int, |
| state: "State", |
| action: "Action", |
| inputs: Dict[str, Any], |
| **future_kwargs: Any, |
| ): |
| pass |
| |
| def post_run_step( |
| self, |
| *, |
| app_id: str, |
| partition_key: str, |
| sequence_id: int, |
| state: "State", |
| action: "Action", |
| result: Optional[Dict[str, Any]], |
| exception: Exception, |
| **future_kwargs: Any, |
| ): |
| pass |
| |
| def pre_start_span( |
| self, |
| *, |
| action: str, |
| action_sequence_id: int, |
| span: "ActionSpan", |
| span_dependencies: list[str], |
| app_id: str, |
| partition_key: Optional[str], |
| **future_kwargs: Any, |
| ): |
| pass |
| |
| def post_end_span( |
| self, |
| *, |
| action: str, |
| action_sequence_id: int, |
| span: "ActionSpan", |
| span_dependencies: list[str], |
| app_id: str, |
| partition_key: Optional[str], |
| **future_kwargs: Any, |
| ): |
| pass |
| |
| def do_log_attributes( |
| self, |
| *, |
| attributes: Dict[str, Any], |
| action: str, |
| action_sequence_id: int, |
| span: Optional["ActionSpan"], |
| tags: dict, |
| app_id: str, |
| partition_key: Optional[str], |
| **future_kwargs: Any, |
| ): |
| pass |
| |
| def pre_start_stream( |
| self, |
| *, |
| action: str, |
| sequence_id: int, |
| app_id: str, |
| partition_key: Optional[str], |
| **future_kwargs: Any, |
| ): |
| pass |
| |
| def post_stream_item( |
| self, |
| *, |
| item: Any, |
| item_index: int, |
| stream_initialize_time: datetime.datetime, |
| first_stream_item_start_time: datetime.datetime, |
| action: str, |
| sequence_id: int, |
| app_id: str, |
| partition_key: Optional[str], |
| **future_kwargs: Any, |
| ): |
| pass |
| |
| def post_end_stream( |
| self, |
| *, |
| action: str, |
| sequence_id: int, |
| app_id: str, |
| partition_key: Optional[str], |
| **future_kwargs: Any, |
| ): |
| pass |
| |
| |
| class DummyPersister(BaseStateSaver, BaseStateLoader): |
| def __init__(self, parent: Optional["DummyPersister"] = None): |
| self.parent = parent |
| |
| def copy(self) -> "DummyPersister": |
| return DummyPersister(parent=self) |
| |
| def save( |
| self, |
| partition_key: Optional[str], |
| app_id: str, |
| sequence_id: int, |
| position: str, |
| state: State, |
| status: Literal["completed", "failed"], |
| **kwargs, |
| ): |
| pass |
| |
| def load( |
| self, partition_key: str, app_id: Optional[str], sequence_id: Optional[int] = None, **kwargs |
| ) -> Optional[PersistedStateData]: |
| pass |
| |
| def list_app_ids(self, partition_key: str, **kwargs) -> list[str]: |
| pass |
| |
| |
| def test_cascade_adapter_cascade(): |
| # Tests that cascading the adapter results in a cloned adapter with `copy()` called |
| adapter = DummyTracker() |
| cascaded = _cascade_adapter("cascade", adapter) |
| assert cascaded.parent is adapter |
| |
| |
| def test_cascade_adapter_none(): |
| # Tests that setting the adapter behavior to None results in no adapter |
| adapter = DummyTracker() |
| cascaded = _cascade_adapter(None, adapter) |
| assert cascaded is None |
| |
| |
| def test_cascade_adapter_fixed(): |
| # Tests that setting the adapter behavior to a fixed value results in that value |
| current_adapter = DummyTracker() |
| next_adapter = DummyTracker() |
| cascaded = _cascade_adapter(next_adapter, current_adapter) |
| assert cascaded is next_adapter |
| |
| |
| def test_map_actions_and_states_uses_same_persister_as_loader(): |
| """This tests the MapActionsAndStates functionality of using the correct persister. Specifically |
| we want it to use the same instance for the saver as it does the loader, as that is |
| what the parent app does.""" |
| |
| class SimpleMapStates(MapActionsAndStates): |
| def actions( |
| self, state: State, context: ApplicationContext, inputs: Dict[str, Any] |
| ) -> Generator[Union[Action, Callable, RunnableGraph], None, None]: |
| for graph_ in [ |
| simple_single_fn_subgraph.bind(identifying_number=1000), |
| ]: |
| yield graph_ |
| |
| def states( |
| self, state: State, context: ApplicationContext, inputs: Dict[str, Any] |
| ) -> Generator[State, None, None]: |
| yield state.update(input_number=0, number_to_add=0) |
| |
| def reduce(self, state: State, states: Generator[State, None, None]) -> State: |
| # TODO -- ensure that states is in the correct order... |
| # Or decide to key it? |
| new_state = state |
| for output_state in states: |
| new_state = new_state.append(output_numbers_in_state=output_state["output_number"]) |
| return new_state |
| |
| @property |
| def writes(self) -> list[str]: |
| return ["output_numbers_in_state"] |
| |
| @property |
| def reads(self) -> list[str]: |
| return ["input_numbers_in_state"] |
| |
| action = SimpleMapStates() |
| persister = DummyPersister() |
| tracker = DummyTracker() |
| |
| task_generator = action.tasks( |
| state=State(), |
| context=ApplicationContext( |
| app_id="app_id", |
| partition_key="partition_key", |
| sequence_id=0, |
| tracker=tracker, |
| state_persister=persister, |
| state_initializer=persister, |
| parallel_executor_factory=lambda: concurrent.futures.ThreadPoolExecutor(), |
| action_name=action.name, |
| ), |
| inputs={}, |
| ) |
| (task,) = task_generator # one task |
| assert task.state_persister is not None |
| assert task.state_initializer is not None |
| assert task.tracker is not None |
| assert task.state_persister is task.state_initializer # This ensures they're the same |