blob: 9930367ebcb3a429888736db10b01eaeb4e49274 [file]
import contextvars
import dataclasses
import functools
import logging
import pprint
import uuid
from contextlib import AbstractContextManager
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Callable,
Dict,
Generator,
List,
Literal,
Optional,
Set,
Tuple,
Union,
cast,
)
from burr import telemetry, visibility
from burr.common import types as burr_types
from burr.core import persistence, validation
from burr.core.action import (
Action,
AsyncStreamingAction,
AsyncStreamingResultContainer,
Condition,
Function,
Reducer,
SingleStepAction,
SingleStepStreamingAction,
StreamingAction,
StreamingResultContainer,
)
from burr.core.graph import Graph, GraphBuilder
from burr.core.persistence import BaseStateLoader, BaseStateSaver
from burr.core.state import State
from burr.core.validation import BASE_ERROR_MESSAGE
from burr.lifecycle.base import LifecycleAdapter
from burr.lifecycle.internal import LifecycleAdapterSet
if TYPE_CHECKING:
from burr.tracking.base import TrackingClient
logger = logging.getLogger(__name__)
# Reserved internal state keys. The double-underscore prefix marks them as
# internal: _state_update never wipes "__"-prefixed keys, and _process_inputs
# rejects caller-supplied "__"-prefixed inputs.
PRIOR_STEP = "__PRIOR_STEP"
SEQUENCE_ID = "__SEQUENCE_ID"
def _validate_result(result: dict, name: str) -> None:
if not isinstance(result, dict):
raise ValueError(
f"Action {name} returned a non-dict result: {result}. "
f"All results must be dictionaries."
)
def _raise_fn_return_validation_error(output: Any, action_name: str):
    """Raises a ValueError describing an invalid single-step action return value.

    :param output: The offending return value
    :param action_name: Name of the action that produced it
    :raises ValueError: always
    """
    message = (
        BASE_ERROR_MESSAGE
        + f"Single step action: {action_name} must return either *just* the newly updated State or a tuple of (result: dict, state: State). "
        f"Got: {output} of type {type(output)} instead, which is not valid."
    )
    raise ValueError(message)
def _adjust_single_step_output(output: Union[State, Tuple[dict, State]], action_name: str):
    """Normalizes a single-step action's output to a (result, state) pair.

    A bare ``State`` becomes ``({}, state)``; a well-formed 2-tuple of
    (dict, State) is passed through; anything else raises.

    :param output: Raw return value of the action
    :param action_name: Name of the action (for error messages)
    :return: Tuple of (result dict, new state)
    :raises ValueError: if the output does not match either accepted shape
    """
    if isinstance(output, State):
        return {}, output
    if isinstance(output, tuple):
        if len(output) != 2:
            _raise_fn_return_validation_error(output, action_name)
        result, new_state = output
        _validate_result(result, action_name)
        if not isinstance(new_state, State):
            _raise_fn_return_validation_error(output, action_name)
        return output
    _raise_fn_return_validation_error(output, action_name)
def _run_function(function: Function, state: State, inputs: Dict[str, Any], name: str) -> dict:
    """Synchronously executes a function-style action and validates its result.

    The state handed to the function is subset to the keys it declares it
    reads.

    :param function: Function to run
    :param state: State at time of execution
    :param inputs: Inputs to the function
    :param name: Action name, used in error messages
    :return: The (dict) result of running the function
    :raises ValueError: if the function is async, or the result is not a dict
    """
    if function.is_async():
        raise ValueError(
            f"Cannot run async: {name} "
            "in non-async context. Use astep()/aiterate()/arun() "
            "instead...)"
        )
    readable_state = state.subset(*function.reads)
    function.validate_inputs(inputs)
    output = function.run(readable_state, **inputs)
    _validate_result(output, name)
    return output
async def _arun_function(
    function: Function, state: State, inputs: Dict[str, Any], name: str
) -> dict:
    """Asynchronously executes a function-style action and validates its result.

    Async counterpart of ``_run_function``; the state is likewise subset to
    the declared reads.
    """
    readable_state = state.subset(*function.reads)
    function.validate_inputs(inputs)
    output = await function.run(readable_state, **inputs)
    _validate_result(output, name)
    return output
def _state_update(state_to_modify: State, modified_state: State) -> State:
    """This is a hack to apply state updates and ensure that we are respecting deletions. Specifically, the process is:

    1. We subset the state to what we want to read
    2. We perform a set of state-specific writes to it
    3. We measure which ones were deleted
    4. We then merge the whole state back in
    5. We then delete the keys that were deleted

    This is suboptimal -- we should not be observing the state, we should be using the state commands and layering in deltas.
    That said, we currently eagerly evaluate the state at all operations, which means we have to do it this way. See
    https://github.com/DAGWorks-Inc/burr/issues/33 for a more details plan.

    This function was written to solve this issue: https://github.com/DAGWorks-Inc/burr/issues/28.

    :param state_to_modify: The original state to modify
    :param modified_state: The state realized after the update() function ran
    :return: The merged state, with user-deleted keys wiped
    """
    removed = set(state_to_modify.keys()) - set(modified_state.keys())
    # TODO -- unify the logic of choosing whether a key is internal or not
    # Right now this is __sequence_id and __prior_step, but it could be more.
    # Internal ("__"-prefixed) keys are never wiped.
    user_removed = [key for key in removed if not key.startswith("__")]
    return state_to_modify.merge(modified_state).wipe(delete=user_removed)
def _validate_reducer_writes(reducer: Reducer, state: State, name: str) -> None:
    """Checks that every key the reducer declared it writes is present in the state.

    :param reducer: The reducer whose declared writes are checked
    :param state: State after the reducer ran
    :param name: Action name, used in the error message
    :raises ValueError: if any declared write key is absent from the state
    """
    declared = reducer.writes
    absent = set(declared) - state.keys()
    if absent:
        raise ValueError(
            f"State is missing write keys after running: {name}. Missing keys are: {absent}. "
            f"Has writes: {declared}"
        )
def _run_reducer(reducer: Reducer, state: State, result: dict, name: str) -> State:
    """Applies the reducer to produce the new state, guarding state writes.

    Any newly introduced keys must be declared in ``reducer.writes``; all
    declared writes must exist afterward.

    :param reducer: Reducer to apply
    :param state: State before the update
    :param result: Result of the corresponding run step
    :param name: Action name, used in error messages
    :return: The merged new state (respecting deletions via ``_state_update``)
    :raises ValueError: on undeclared or missing writes
    """
    # TODO -- better guarding on state reads/writes
    updated = reducer.update(result, state)
    introduced = set(updated.keys()) - set(state.keys())
    undeclared = introduced - set(reducer.writes)
    if undeclared:
        raise ValueError(
            f"Action {name} attempted to write to keys {undeclared} "
            f"that it did not declare. It declared: ({reducer.writes})!"
        )
    _validate_reducer_writes(reducer, updated, name)
    return _state_update(state, updated)
def _create_dict_string(kwargs: dict) -> str:
"""This is a utility function to create a string representation of a dict.
This is the state that was passed into the function usually. This is useful for debugging,
as it can be printed out to see what the state was.
:param kwargs: The inputs to the function that errored.
:return: The string representation of the inputs, truncated appropriately.
"""
pp = pprint.PrettyPrinter(width=80)
inputs = {}
for k, v in kwargs.items():
item_repr = repr(v)
if len(item_repr) > 50:
item_repr = item_repr[:50] + "..."
else:
item_repr = v
inputs[k] = item_repr
input_string = pp.pformat(inputs)
if len(input_string) > 1000:
input_string = input_string[:1000] + "..."
return input_string
def _format_BASE_ERROR_MESSAGE(action: Action, input_state: State, inputs: dict) -> str:
    """Formats the error banner logged when an action raises.

    Includes the action name plus the state and inputs at the time of the
    failure, framed by 80-character asterisk borders.
    """
    header = BASE_ERROR_MESSAGE
    header += f"> Action: `{action.name}` encountered an error!"
    # Pad the header out toward 80 columns and close the line with '<'
    # (padding collapses to "" if the header is already past 80 chars)
    header += " " * (80 - len(header) - 1)
    header += "<"
    body = "\n> State (at time of action):\n" + _create_dict_string(input_state.get_all())
    body += "\n> Inputs (at time of action):\n" + _create_dict_string(inputs)
    border = "*" * 80
    return "\n" + border + "\n" + header + body + "\n" + border
def _run_single_step_action(
    action: SingleStepAction, state: State, inputs: Optional[Dict[str, Any]]
) -> Tuple[Dict[str, Any], State]:
    """Runs a single step action. This API is internal-facing and a bit in flux, but
    it corresponds to the SingleStepAction class.

    :param action: Action to run
    :param state: State to run with
    :param inputs: Inputs to pass directly to the action
    :return: The result of running the action, and the new state
    :raises ValueError: if the action's output shape or state writes are invalid
    """
    # TODO -- guard all reads/writes with a subset of the state
    action.validate_inputs(inputs)
    result, new_state = _adjust_single_step_output(
        action.run_and_update(state, **inputs), action.name
    )
    # The original code validated the result twice; once is sufficient
    # (this also covers the bare-State return path, which yields {}).
    _validate_result(result, action.name)
    _validate_reducer_writes(action, new_state, action.name)
    return result, _state_update(state, new_state)
def _run_single_step_streaming_action(
    action: SingleStepStreamingAction, state: State, inputs: Optional[Dict[str, Any]]
) -> Generator[Tuple[dict, Optional[State]], None, None]:
    """Runs a single step streaming action. This API is internal-facing.
    This normalizes + validates the output.

    The underlying generator must yield (result, None) for every intermediate
    item and, as its *final* yield, (result, state_update) with a non-None
    state update.

    :param action: Streaming action to run
    :param state: State at time of execution
    :param inputs: Inputs to pass directly to the action
    :return: Generator of (result, state_update); state_update is None until the final pair
    :raises ValueError: if a yield is not a tuple, or the stream ends with no state update
    """
    action.validate_inputs(inputs)
    generator = action.stream_run_and_update(state, **inputs)
    result = None
    state_update = None
    for item in generator:
        if not isinstance(item, tuple):
            # TODO -- consider adding support for just returning a result.
            raise ValueError(
                f"Action {action.name} must yield a tuple of (result, state_update). "
                f"For all non-final results (intermediate),"
                f"the state update must be None"
            )
        result, state_update = item
        if state_update is None:
            # Intermediate result -- forward it; the final pair is emitted below
            yield result, None
    if state_update is None:
        # The generator ran out without ever providing a state update
        raise ValueError(
            f"Action {action.name} did not return a state update. For streaming actions, the last yield "
            f"statement must be a tuple of (result, state_update). For example, yield dict(foo='bar'), state.update(foo='bar')"
        )
    _validate_result(result, action.name)
    _validate_reducer_writes(action, state_update, action.name)
    yield result, state_update
async def _arun_single_step_streaming_action(
    action: SingleStepStreamingAction, state: State, inputs: Optional[Dict[str, Any]]
) -> AsyncGenerator[Tuple[dict, Optional[State]], None]:
    """Runs a single step streaming action in async. See the synchronous version for more details.

    :param action: Streaming action to run
    :param state: State at time of execution
    :param inputs: Inputs to pass directly to the action
    :return: Async generator of (result, state_update); state_update is None until the final pair
    :raises ValueError: if a yield is not a tuple, or the stream ends with no state update
    """
    action.validate_inputs(inputs)
    generator = action.stream_run_and_update(state, **inputs)
    result = None
    state_update = None
    async for item in generator:
        if not isinstance(item, tuple):
            # TODO -- consider adding support for just returning a result.
            raise ValueError(
                f"Action {action.name} must yield a tuple of (result, state_update). "
                f"For all non-final results (intermediate),"
                f"the state update must be None"
            )
        result, state_update = item
        if state_update is None:
            # Intermediate result -- forward it; the final pair is emitted below
            yield result, None
    if state_update is None:
        # The generator ran out without ever providing a state update
        raise ValueError(
            f"Action {action.name} did not return a state update. For async actions, the last yield "
            f"statement must be a tuple of (result, state_update). For example, yield dict(foo='bar'), state.update(foo='bar')"
        )
    _validate_result(result, action.name)
    _validate_reducer_writes(action, state_update, action.name)
    # TODO -- add guard against zero-length stream
    yield result, state_update
def _run_multi_step_streaming_action(
    action: StreamingAction, state: State, inputs: Optional[Dict[str, Any]]
) -> Generator[Tuple[dict, Optional[State]], None, None]:
    """Runs a multi-step streaming action (one with separate run/reduce steps).

    Internal-facing. Converts the action's stream to the common
    (result, state_update) shape: every intermediate item is yielded with a
    None update, and only the final item carries the reduced state. This
    buffers one item ahead so we know which yield is the last one (and
    therefore when to reduce and validate).
    """
    action.validate_inputs(inputs)
    stream = action.stream_run(state, **inputs)
    pending = None
    for chunk in stream:
        # Hold the newest chunk back; emit the previous one as intermediate.
        # This is slightly eager, but only in the multi-step streaming case.
        previous, pending = pending, chunk
        if previous is not None:
            yield previous, None
    state_update = _run_reducer(action, state, pending, action.name)
    _validate_result(pending, action.name)
    _validate_reducer_writes(action, state_update, action.name)
    yield pending, state_update
async def _arun_multi_step_streaming_action(
    action: AsyncStreamingAction, state: State, inputs: Optional[Dict[str, Any]]
) -> AsyncGenerator[Tuple[dict, Optional[State]], None]:
    """Runs a multi-step streaming action in async. See the synchronous version for more details.

    Buffers one item ahead so the final chunk can be paired with the reduced
    state update; every earlier chunk is yielded with a None update.
    """
    action.validate_inputs(inputs)
    stream = action.stream_run(state, **inputs)
    pending = None
    async for chunk in stream:
        # Hold the newest chunk back; emit the previous one as intermediate.
        # This is slightly eager, but only in the multi-step streaming case.
        previous, pending = pending, chunk
        if previous is not None:
            yield previous, None
    state_update = _run_reducer(action, state, pending, action.name)
    _validate_result(pending, action.name)
    _validate_reducer_writes(action, state_update, action.name)
    yield pending, state_update
async def _arun_single_step_action(
    action: SingleStepAction, state: State, inputs: Optional[Dict[str, Any]]
) -> Tuple[dict, State]:
    """Runs a single step action in async. See the synchronous version for more details."""
    action.validate_inputs(inputs)
    raw_output = await action.run_and_update(state, **inputs)
    result, new_state = _adjust_single_step_output(raw_output, action.name)
    _validate_result(result, action.name)
    _validate_reducer_writes(action, new_state, action.name)
    return result, _state_update(state, new_state)
@dataclasses.dataclass
class ApplicationGraph(Graph):
    """User-facing representation of the state machine. This has

    #. All the action objects
    #. All the transition objects
    #. The entrypoint action
    """
    # The action object at which the application starts (resolved from the
    # entrypoint name in Application.__init__)
    entrypoint: Action
@dataclasses.dataclass
class ApplicationContext(AbstractContextManager):
    """Application context. This is anything your node might need to know about the application.
    Often used for recursive tracking.

    Note this is also a context manager (allowing you to pass context to sub-applications):
    entering sets the context-local variable, exiting clears it.
    """
    app_id: str  # unique ID of the running application
    partition_key: Optional[str]  # partition key of the application, if any
    sequence_id: Optional[int]  # sequence ID of the step being executed
    tracker: Optional["TrackingClient"]  # tracking client, if one is attached
    @staticmethod
    def get() -> Optional["ApplicationContext"]:
        """Provides the context-local application context.
        You can use this instead of declaring `__context` in an application.
        You really should only be using this if you're wiring through multiple layers of abstraction
        and want to connect two applications.

        :return: The ApplicationContext you'll want to use
        """
        return _application_context.get()
    def __enter__(self) -> "ApplicationContext":
        # Publish this context to the context-local variable (see get())
        _application_context.set(self)
        return self
    def __exit__(self, __exc_type, __exc_value, __traceback):
        # Clear the context-local variable; exceptions are not suppressed
        _application_context.set(None)
# Context-local storage backing ApplicationContext.get() -- set on
# ApplicationContext.__enter__, cleared (to None) on __exit__.
_application_context = contextvars.ContextVar[Optional[ApplicationContext]](
    "application_context", default=None
)
class Application:
    def __init__(
        self,
        graph: Graph,
        state: State,
        partition_key: Optional[str],
        uid: str,
        entrypoint: str,
        sequence_id: Optional[int] = None,
        adapter_set: Optional[LifecycleAdapterSet] = None,
        builder: Optional["ApplicationBuilder"] = None,
        fork_parent_pointer: Optional[burr_types.ParentPointer] = None,
        spawning_parent_pointer: Optional[burr_types.ParentPointer] = None,
        tracker: Optional["TrackingClient"] = None,
    ):
        """Instantiates an Application. This is an internal API -- use the builder!

        :param graph: Graph (actions + transitions) to run
        :param state: State to run with
        :param partition_key: Partition key for the application (optional)
        :param uid: Unique identifier for the application
        :param entrypoint: Name of the action to start at
        :param sequence_id: Sequence ID for the application. Note this will be incremented every run.
            So if this starts at 0, the first one you will see will be 1.
        :param adapter_set: Set of lifecycle adapters; defaults to an empty set
        :param builder: Builder that created this application
        :param fork_parent_pointer: Pointer to the application this one was forked from, if any
        :param spawning_parent_pointer: Pointer to the application that spawned this one, if any
        :param tracker: Tracking client, exposed to actions through the application context
        """
        self._partition_key = partition_key
        self._uid = uid
        self.entrypoint = entrypoint
        self._graph = graph
        self._public_facing_graph = ApplicationGraph(
            actions=graph.actions,
            transitions=graph.transitions,
            entrypoint=graph.get_action(entrypoint),
        )
        self._state = state
        self._adapter_set = adapter_set if adapter_set is not None else LifecycleAdapterSet()
        # TODO -- consider adding global inputs + global input factories to the builder
        self._tracker = tracker
        if sequence_id is not None:
            self._set_sequence_id(sequence_id)
        self._builder = builder
        self._parent_pointer = fork_parent_pointer
        # Factories for injectable ("__"-prefixed) inputs -- consumed by _process_inputs
        self.dependency_factory = {
            "__tracer": functools.partial(
                visibility.tracing.TracerFactory, lifecycle_adapters=self._adapter_set
            ),
            "__context": self._context_factory,
        }
        self._spawning_parent_pointer = spawning_parent_pointer
        # Notify adapters that the application now exists (synchronous hooks only)
        self._adapter_set.call_all_lifecycle_hooks_sync(
            "post_application_create",
            state=self._state,
            application_graph=self._public_facing_graph,
            app_id=self._uid,
            partition_key=self._partition_key,
            parent_pointer=fork_parent_pointer,
            spawning_parent_pointer=spawning_parent_pointer,
        )
# @telemetry.capture_function_usage # todo -- capture usage when we break this up into one that isn't called internally
# This will be doable when we move sequence ID to the beginning of the function https://github.com/DAGWorks-Inc/burr/pull/73
def step(self, inputs: Optional[Dict[str, Any]] = None) -> Optional[Tuple[Action, dict, State]]:
"""Performs a single step, advancing the state machine along.
This returns a tuple of the action that was run, the result of running
the action, and the new state.
Use this if you just want to do something with the state and not rely on generators.
E.G. press forward/backwards, human in the loop, etc... Odds are this is not
the method you want -- you'll want iterate() (if you want to see the state/
results along the way), or run() (if you just want the final state/results).
:param inputs: Inputs to the action -- this is if this action requires an input that is passed in from the outside world
:return: Tuple[Function, dict, State] -- the function that was just ran, the result of running it, and the new state
"""
# we need to increment the sequence before we start computing
# that way if we're replaying from state, we don't get stuck
self._increment_sequence_id()
out = self._step(inputs=inputs, _run_hooks=True)
return out
def _context_factory(self, action: Action, sequence_id: int) -> ApplicationContext:
"""Helper function to create an application context, in the form of the dependency factories we inject to nodes."""
return ApplicationContext(
app_id=self._uid,
tracker=self._tracker,
partition_key=self._partition_key,
sequence_id=sequence_id,
)
    def _step(
        self, inputs: Optional[Dict[str, Any]], _run_hooks: bool = True
    ) -> Optional[Tuple[Action, dict, State]]:
        """Internal-facing version of step. This is the same as step, but with an additional
        parameter to hide hook execution so async can leverage it.

        :param inputs: Inputs to the action, or None for no inputs
        :param _run_hooks: Whether to run the pre/post step lifecycle hooks
        :return: (action, result, new state), or None when there is no next action
        """
        # NOTE(review): self.context presumably yields an ApplicationContext so
        # sub-applications can pick it up -- confirm against the property definition.
        with self.context:
            next_action = self.get_next_action()
            if next_action is None:
                return None
            if inputs is None:
                inputs = {}
            action_inputs = self._process_inputs(inputs, next_action)
            if _run_hooks:
                self._adapter_set.call_all_lifecycle_hooks_sync(
                    "pre_run_step",
                    action=next_action,
                    state=self._state,
                    inputs=action_inputs,
                    sequence_id=self.sequence_id,
                    app_id=self._uid,
                    partition_key=self._partition_key,
                    _ProxyClassHook__tracer=self.dependency_factory["__tracer"],
                )
            exc = None
            result = None
            new_state = self._state
            try:
                if next_action.single_step:
                    # Single-step actions run + reduce in one call
                    result, new_state = _run_single_step_action(
                        next_action, self._state, action_inputs
                    )
                else:
                    # Classic actions: run the function, then apply the reducer
                    result = _run_function(
                        next_action, self._state, action_inputs, name=next_action.name
                    )
                    new_state = _run_reducer(next_action, self._state, result, next_action.name)
                new_state = self._update_internal_state_value(new_state, next_action)
                self._set_state(new_state)
            except Exception as e:
                exc = e
                logger.exception(_format_BASE_ERROR_MESSAGE(next_action, self._state, inputs))
                raise e
            finally:
                # post_run_step always fires, carrying the exception if one occurred
                if _run_hooks:
                    self._adapter_set.call_all_lifecycle_hooks_sync(
                        "post_run_step",
                        app_id=self._uid,
                        partition_key=self._partition_key,
                        action=next_action,
                        state=new_state,
                        result=result,
                        sequence_id=self.sequence_id,
                        exception=exc,
                    )
            return next_action, result, new_state
def reset_to_entrypoint(self) -> None:
"""Resets the state machine to the entrypoint action -- you probably want to consider having a loop
in your graph, but this will do the trick if you need it!"""
self._set_state(self._state.wipe(delete=[PRIOR_STEP]))
def _update_internal_state_value(self, new_state: State, next_action: Action) -> State:
"""Updates the internal state values of the new state."""
new_state = new_state.update(
**{
PRIOR_STEP: next_action.name,
}
)
return new_state
    def _process_inputs(self, inputs: Dict[str, Any], action: Action) -> Dict[str, Any]:
        """Processes inputs, injecting the common inputs and ensuring that all required inputs are present.

        :param inputs: Raw inputs supplied by the caller
        :param action: The action about to be executed
        :return: Only the inputs the action declared, plus any injected factory-provided inputs
        :raises ValueError: if the caller passed "__"-prefixed keys, or a required input is missing
        """
        # "__"-prefixed keys are reserved for injected inputs (__tracer/__context)
        starting_with_double_underscore = {key for key in inputs.keys() if key.startswith("__")}
        if len(starting_with_double_underscore) > 0:
            raise ValueError(
                BASE_ERROR_MESSAGE
                + f"Inputs starting with a double underscore ({starting_with_double_underscore}) "
                f"are reserved for internal use/injected inputs."
                "Please do not use keys"
            )
        inputs = inputs.copy()
        processed_inputs = {}
        required_inputs, optional_inputs = action.optional_and_required_inputs
        # Keep only the inputs the action declared; undeclared ones are skipped (with a debug log)
        for key in list(inputs.keys()):
            if key in required_inputs or key in optional_inputs:
                processed_inputs[key] = inputs.pop(key)
        if len(inputs) > 0 and logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                f"Keys {inputs.keys()} were passed in as inputs to action "
                f"{action.name}, but not declared by the action as an input. "
                f"Action only needs: {required_inputs} (optionally: {optional_inputs}) "
                f"so we're just letting you know some inputs are being skipped."
            )
        missing_inputs = required_inputs - set(processed_inputs.keys())
        additional_inputs = optional_inputs - set(processed_inputs.keys())
        for input_ in missing_inputs | additional_inputs:
            # if we can find it in the dependency factory, we'll use that
            # TODO -- figure out what happens if people attempt to override default factory
            # inputs
            if input_ in self.dependency_factory:
                processed_inputs[input_] = self.dependency_factory[input_](action, self.sequence_id)
                if input_ in missing_inputs:
                    missing_inputs.remove(input_)
        if len(missing_inputs) > 0:
            # Build a fill-in-the-blanks example to make the error actionable
            missing_inputs_dict = {key: "FILL ME IN" for key in missing_inputs}
            missing_inputs_dict.update({key: "..." for key in inputs.keys()})
            addendum = (
                "\nPlease double check the values passed to the keyword argument `inputs` to however you're running "
                "the burr application.\n"
                "e.g.\n"
                f"    app.run(  # or app.step, app.iterate, app.astep, etc.\n"
                f"        halt_...,  # your halt logic\n"
                f"        inputs={missing_inputs_dict}  # <-- this is what you need to adjust\n"
                f"    )"
            )
            raise ValueError(
                f"Action {action.name} is missing required inputs: {missing_inputs}. "
                f"Has inputs: {processed_inputs}. " + addendum
            )
        return processed_inputs
# @telemetry.capture_function_usage
# ditto with step()
    async def astep(self, inputs: Optional[Dict[str, Any]] = None) -> Optional[Tuple[Action, dict, State]]:
        """Asynchronous version of step.

        :param inputs: Inputs to the action -- this is if this action
            requires an input that is passed in from the outside world
        :return: Tuple[Function, dict, State] -- the action that was just ran, the result of running it, and the new state
        """
        # Increment before computing so a replay from state doesn't get stuck (see step())
        self._increment_sequence_id()
        out = await self._astep(inputs=inputs, _run_hooks=True)
        return out
    async def _astep(self, inputs: Optional[Dict[str, Any]], _run_hooks: bool = True):
        """Internal-facing version of astep -- mirrors _step, but awaits async
        actions/hooks. Synchronous actions are delegated to _step.

        :param inputs: Inputs to the action, or None for no inputs
        :param _run_hooks: Whether to run the pre/post step lifecycle hooks
        :return: (action, result, new state), or None when there is no next action
        """
        # we want to increment regardless of failure
        with self.context:
            next_action = self.get_next_action()
            if next_action is None:
                return None
            if inputs is None:
                inputs = {}
            action_inputs = self._process_inputs(inputs, next_action)
            if _run_hooks:
                await self._adapter_set.call_all_lifecycle_hooks_sync_and_async(
                    "pre_run_step",
                    action=next_action,
                    state=self._state,
                    inputs=action_inputs,
                    sequence_id=self.sequence_id,
                    app_id=self._uid,
                    partition_key=self._partition_key,
                )
            exc = None
            result = None
            new_state = self._state
            try:
                if not next_action.is_async():
                    # we can just delegate to the synchronous version, it will block the event loop,
                    # but that's safer than assuming its OK to launch a thread
                    # TODO -- add an option/configuration to launch a thread (yikes, not super safe, but for a pure function
                    # which this is supposed to be its OK).
                    # this delegates hooks to the synchronous version, so we'll call all of them as well
                    return self._step(
                        inputs=action_inputs, _run_hooks=False
                    )  # Skip hooks as we already ran all of them/will run all of them in this function's finally
                if next_action.single_step:
                    # Single-step actions run + reduce in one awaited call
                    result, new_state = await _arun_single_step_action(
                        next_action, self._state, inputs=action_inputs
                    )
                else:
                    # Classic actions: await the function, then apply the reducer
                    result = await _arun_function(
                        next_action, self._state, inputs=action_inputs, name=next_action.name
                    )
                    new_state = _run_reducer(next_action, self._state, result, next_action.name)
                new_state = self._update_internal_state_value(new_state, next_action)
                self._set_state(new_state)
            except Exception as e:
                exc = e
                logger.exception(_format_BASE_ERROR_MESSAGE(next_action, self._state, inputs))
                raise e
            finally:
                # post_run_step always fires, carrying the exception if one occurred
                if _run_hooks:
                    await self._adapter_set.call_all_lifecycle_hooks_sync_and_async(
                        "post_run_step",
                        action=next_action,
                        state=new_state,
                        result=result,
                        sequence_id=self.sequence_id,
                        exception=exc,
                        app_id=self._uid,
                        partition_key=self._partition_key,
                    )
            return next_action, result, new_state
    def _clean_iterate_params(
        self,
        halt_before: Optional[list[str]] = None,
        halt_after: Optional[list[str]] = None,
        inputs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[list[str], list[str], Dict[str, Any]]:
        """Utility function to clean out iterate params so we have less duplication between iterate/aiterate
        and the logic is cleaner later.

        Normalizes all three parameters to empty containers, warning if no
        halt condition at all was supplied (the run could loop forever).
        """
        if halt_before is None and halt_after is None:
            logger.warning(
                "No halt termination specified -- this has the possibility of running forever!"
            )
        if halt_before is None:
            halt_before = []
        if halt_after is None:
            halt_after = []
        if inputs is None:
            inputs = {}
        return halt_before, halt_after, inputs
def _validate_halt_conditions(self, halt_before: list[str], halt_after: list[str]) -> None:
"""Utility function to validate halt conditions"""
missing_actions = set(halt_before + halt_after) - set(
action.name for action in self.graph.actions
)
if len(missing_actions) > 0:
raise ValueError(
BASE_ERROR_MESSAGE
+ f"Halt conditions {missing_actions} are not registered actions. Please ensure that they have been "
f"registered as actions in the application and that you've spelled them correctly!"
f"Valid actions are: {[action.name for action in self.graph.actions]}"
)
def has_next_action(self) -> bool:
"""Returns whether or not there is a next action to run.
:return: True if there is a next action, False otherwise
"""
return self.get_next_action() is not None
def _should_halt_iterate(
self, halt_before: list[str], halt_after: list[str], prior_action: Action
) -> bool:
"""Internal utility function to determine whether or not to halt during iteration"""
if self.has_next_action() and self.get_next_action().name in halt_before:
logger.debug(f"Halting before executing {self.get_next_action().name}")
return True
elif prior_action.name in halt_after:
logger.debug(f"Halting after executing {prior_action.name}")
return True
return False
    def _return_value_iterate(
        self,
        halt_before: list[str],
        halt_after: list[str],
        prior_action: Optional[Action],
        result: Optional[dict],
    ) -> Tuple[Optional[Action], Optional[dict], State]:
        """Utility function to decide what to return for iterate/arun. Note that run() will delegate to the return value of
        iterate, whereas arun cannot delegate to the return value of aiterate (as async generators cannot return a value).
        We put the code centrally to clean up the logic.

        :param halt_before: Actions we were asked to halt before
        :param halt_after: Actions we were asked to halt after
        :param prior_action: The last action that ran (None if nothing ran)
        :param result: The last result produced (None if nothing ran)
        :return: (action, result, state) per the halt semantics below
        """
        # halt_before takes precedence: return the *upcoming* action with no result
        if self.has_next_action() and self.get_next_action().name in halt_before:
            logger.debug(
                f"We have hit halt_before condition with next action: {self.get_next_action().name}. "
                f"Returning: next_action={self.get_next_action()}, None, and state"
            )
            return self.get_next_action(), None, self._state
        if prior_action is not None and prior_action.name in halt_after:
            prior_action_name = prior_action.name if prior_action is not None else None
            logger.debug(
                f"We have hit halt_after condition with prior action: {prior_action_name}. "
                f"Returning: prior_action={prior_action}, result, and state"
            )
            return prior_action, result, self._state
        # Neither condition matched and nothing executed -- undefined behavior; warn loudly
        logger.warning(
            "This is trying to return without having computed a single action -- "
            "we'll end up just returning some Nones. This means that nothing was executed "
            "(E.G. that the state machine had nowhere to go). Either fix the state machine or"
            f"the halt conditions, or both... Halt conditions are: halt_before={halt_before}, halt_after={halt_after}."
            f"Note that this is considered undefined behavior -- if you get here, you should fix!"
        )
        return prior_action, result, self._state
    @telemetry.capture_function_usage
    def iterate(
        self,
        *,
        halt_before: Optional[list[str]] = None,
        halt_after: Optional[list[str]] = None,
        inputs: Optional[Dict[str, Any]] = None,
    ) -> Generator[Tuple[Action, dict, State], None, Tuple[Action, Optional[dict], State]]:
        """Returns a generator that calls step() in a row, enabling you to see the state
        of the system as it updates. Note this returns a generator, and also the final result
        (for convenience).

        Note the nuance with halt_before and halt_after. halt_before conditions will take precedence to halt_after. Furthermore,
        a single iteration will always be executed prior to testing for any halting conditions.

        :param halt_before: The list of actions to halt before execution of. It will halt prior to the execution of the first one it sees.
        :param halt_after: The list of actions to halt after execution of. It will halt after the execution of the first one it sees.
        :param inputs: Inputs to the action -- this is if this action requires an input that is passed in from the outside world.
            Note that this is only used for the first iteration -- subsequent iterations will not use this.
        :return: Each iteration returns the result of running `step`. This generator also returns a tuple of
            [action, result, current state]
        """
        halt_before, halt_after, inputs = self._clean_iterate_params(
            halt_before, halt_after, inputs
        )
        self._validate_halt_conditions(halt_before, halt_after)
        result = None
        prior_action: Optional[Action] = None
        # NOTE(review): `inputs` is passed to every step() call here, though the
        # docstring says first-iteration only -- actions only consume declared
        # inputs, but confirm the intended contract.
        while self.has_next_action():
            # self.step will only return None if there is no next action, so we can rely on tuple unpacking
            prior_action, result, state = self.step(inputs=inputs)
            yield prior_action, result, state
            if self._should_halt_iterate(halt_before, halt_after, prior_action):
                break
        # Generator return value -- retrievable via StopIteration.value (run() relies on this)
        return self._return_value_iterate(halt_before, halt_after, prior_action, result)
    @telemetry.capture_function_usage
    async def aiterate(
        self,
        *,
        halt_before: Optional[list[str]] = None,
        halt_after: Optional[list[str]] = None,
        inputs: Optional[Dict[str, Any]] = None,
    ) -> AsyncGenerator[Tuple[Action, dict, State], None]:
        """Returns a generator that calls step() in a row, enabling you to see the state
        of the system as it updates. This is the asynchronous version -- note that, unlike
        iterate(), it cannot return a final value (async generators are not allowed to
        have return values); use arun() if you only want the final result.

        :param halt_before: The list of actions to halt before execution of. It will halt on the first one.
        :param halt_after: The list of actions to halt after execution of. It will halt on the first one.
        :param inputs: Inputs to the action -- this is if this action requires an input that is passed in from the outside world.
            Note that this is only used for the first iteration -- subsequent iterations will not use this.
        :return: Each iteration returns the result of running `step`. This returns nothing -- it's an async generator which is not
            allowed to have a return value.
        """
        halt_before, halt_after, inputs = self._clean_iterate_params(
            halt_before, halt_after, inputs
        )
        self._validate_halt_conditions(halt_before, halt_after)
        while self.has_next_action():
            # self.step will only return None if there is no next action, so we can rely on tuple unpacking
            prior_action, result, state = await self.astep(inputs=inputs)
            yield prior_action, result, state
            if self._should_halt_iterate(halt_before, halt_after, prior_action):
                break
    @telemetry.capture_function_usage
    def run(
        self,
        *,
        halt_before: Optional[list[str]] = None,
        halt_after: Optional[list[str]] = None,
        inputs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Action, Optional[dict], State]:
        """Runs your application through until completion. Does
        not give access to the state along the way -- if you want that, use iterate().

        :param halt_before: The list of actions to halt before execution of. It will halt on the first one.
        :param halt_after: The list of actions to halt after execution of. It will halt on the first one.
        :param inputs: Inputs to the action -- this is if this action requires an input that is passed in from the outside world.
            Note that this is only used for the first iteration -- subsequent iterations will not use this.
        :return: The final state, and the results of running the actions in the order that they were specified.
        """
        gen = self.iterate(halt_before=halt_before, halt_after=halt_after, inputs=inputs)
        # Drain the generator; its return value (from iterate) arrives on
        # StopIteration.value, which is what we hand back.
        while True:
            try:
                next(gen)
            except StopIteration as e:
                return e.value
@telemetry.capture_function_usage
async def arun(
    self,
    *,
    halt_before: Optional[list[str]] = None,
    halt_after: Optional[list[str]] = None,
    inputs: Optional[Dict[str, Any]] = None,
) -> Tuple[Action, Optional[dict], State]:
    """Runs your application through until completion, using async. Does
    not give access to the state along the way -- if you want that, use iterate().

    :param halt_before: The list of actions to halt before execution of. It will halt on the first one.
    :param halt_after: The list of actions to halt after execution of. It will halt on the first one.
    :param inputs: Inputs to the action -- this is if this action requires an input that is passed in from the outside world
    :return: The final state, and the results of running the actions in the order that they were specified.
    """
    # Note: the annotations above were previously `list[str] = None`, which is an
    # implicit-Optional typing error -- fixed to Optional[...] with no runtime change.
    # Defaults in case aiterate yields nothing (no next action to run).
    prior_action = None
    result = None
    halt_before, halt_after, inputs = self._clean_iterate_params(
        halt_before, halt_after, inputs
    )
    self._validate_halt_conditions(halt_before, halt_after)
    # Drain the async generator; we only need the last (action, result) pair.
    async for prior_action, result, state in self.aiterate(
        halt_before=halt_before, halt_after=halt_after, inputs=inputs
    ):
        pass
    return self._return_value_iterate(halt_before, halt_after, prior_action, result)
@telemetry.capture_function_usage
def stream_result(
    self,
    halt_after: list[str],
    halt_before: Optional[list[str]] = None,
    inputs: Optional[Dict[str, Any]] = None,
) -> Tuple[Action, StreamingResultContainer]:
    """Streams a result out.

    :param halt_after: The list of actions to halt after execution of. It will halt on the first one.
    :param halt_before: The list of actions to halt before execution of. It will halt on the first one. Note that
        if this is met, the streaming result container will be empty (and return None) for the result, having an empty generator.
    :param inputs: Inputs to the action -- this is if this action requires an input that is passed in from the outside world
    :return: A streaming result container, which is a generator that will yield results as they come in, as well as cache/give you the final result,
        and update state accordingly.

    This is meant to be used with streaming actions -- :py:meth:`streaming_action <burr.core.action.streaming_action>`
    or :py:class:`StreamingAction <burr.core.action.StreamingAction>` It returns a
    :py:class:`StreamingResultContainer <burr.core.action.StreamingResultContainer>`, which has two capabilities:

    1. It is a generator that streams out the intermediate results of the action
    2. It has a ``.get()`` method that returns the final result of the action, and the final state.

    If ``.get()`` is called before the generator is exhausted, it will block until the generator is exhausted.

    While this container is meant to work with streaming actions, it can also be used with non-streaming actions. In this case,
    the generator will be empty, and the ``.get()`` method will return the final result and state.

    The rules for halt_before and halt_after are the same as for :py:meth:`iterate <burr.core.application.Application.iterate>`,
    and :py:meth:`run <burr.core.application.Application.run>`. In this case, `halt_before` will indicate a *non* streaming action,
    which will be empty. Thus ``halt_after`` takes precedence -- if it is met, the streaming result container will contain the result of the
    halt_after condition.

    The :py:class:`StreamingResultContainer <burr.core.action.StreamingResultContainer>` is meant as a convenience -- specifically this allows for
    hooks, callbacks, etc... so you can take the control flow and still have state updated afterwards. Hooks/state update will be called after an exception
    is thrown during streaming, or the stream is completed. Note that it is undefined behavior to attempt to execute another action while a stream is in progress.

    To see how this works, let's take the following action (simplified as a single-node workflow) as an example:

    .. code-block:: python

        @streaming_action(reads=[], writes=['response'])
        def streaming_response(state: State, prompt: str) -> Generator[dict, None, Tuple[dict, State]]:
            response = client.chat.completions.create(
                model='gpt-3.5-turbo',
                messages=[{
                    'role': 'user',
                    'content': prompt
                }],
                temperature=0,
            )
            buffer = []
            for chunk in response:
                delta = chunk.choices[0].delta.content
                buffer.append(delta)
                # yield partial results
                yield {'response': delta}, None  # indicate that we are not done by returning a `None` state!
            full_response = ''.join(buffer)
            # return the final result
            yield {'response': full_response}, state.update(response=full_response)

    To use streaming_result, you pass in names of streaming actions (such as the one above) to the halt_after
    parameter:

    .. code-block:: python

        application = ApplicationBuilder().with_actions(streaming_response=streaming_response)...build()

        prompt = "Count to 100, with a comma between each number and no newlines. E.g., 1, 2, 3, ..."
        action, streaming_result = application.stream_result(halt_after='streaming_response', inputs={"prompt": prompt})
        for result in streaming_result:
            print(result['response'])  # one by one

        result, state = streaming_result.get()
        print(result)  # all at once

    Note that if you have multiple halt_after conditions, you can use the ``.action`` attribute to get the action that
    was run.

    .. code-block:: python

        application = ApplicationBuilder().with_actions(
            streaming_response=streaming_response,
            error=error  # another function that outputs an error, streaming
        )...build()

        prompt = "Count to 100, with a comma between each number and no newlines. E.g., 1, 2, 3, ..."
        action, streaming_result = application.stream_result(halt_after='streaming_response', inputs={"prompt": prompt})
        color = "red" if action.name == "error" else "green"
        for result in streaming_result:
            print(format(result['response'], color))  # assumes that error and streaming_response both have the same output shape

    .. code-block:: python

        application = ApplicationBuilder().with_actions(
            streaming_response=streaming_response,
            error=non_streaming_error  # a non-streaming function that outputs an error
        )...build()

        prompt = "Count to 100, with a comma between each number and no newlines. E.g., 1, 2, 3, ..."
        action, streaming_result = application.stream_result(halt_after='streaming_response', inputs={"prompt": prompt})
        color = "red" if action.name == "error" else "green"
        if action.name == "streaming_response":  # can also use the ``.streaming`` attribute of action
            for result in streaming_result:
                print(format(result['response'], color))  # assumes that error and streaming_response both have the same output shape
        else:
            result, state = streaming_result.get()
            print(format(result['response'], color))
    """
    halt_before, halt_after, inputs = self._clean_iterate_params(
        halt_before, halt_after, inputs
    )
    self._validate_halt_conditions(halt_before, halt_after)
    next_action = self.get_next_action()
    if next_action is None:
        raise ValueError(
            f"Cannot stream result -- no next action found! Prior action was: {self._state[PRIOR_STEP]}"
        )
    if next_action.name not in halt_after:
        # fast forward until we get to the action
        # run already handles incrementing sequence IDs, nothing to worry about here
        next_action, results, state = self.run(
            halt_before=halt_after + halt_before, inputs=inputs
        )
        # In this case, we are ready to halt and return an empty generator
        # The results will be None, and the state will be the final state
        # For context, this is specifically for the case in which you want to have
        # multiple terminal points with a unified API, where some are streaming, and some are not.
        if next_action.name in halt_before and next_action.name not in halt_after:
            return next_action, StreamingResultContainer.pass_through(
                results=results, final_state=state
            )
    self._increment_sequence_id()
    # pre_run_step fires here; the matching post_run_step fires either inline
    # (non-streaming path), via `callback` (streaming path, once the container
    # finishes/errors), or in the except block below (failure before streaming).
    self._adapter_set.call_all_lifecycle_hooks_sync(
        "pre_run_step",
        action=next_action,
        state=self._state,
        inputs=inputs,
        sequence_id=self.sequence_id,
        app_id=self._uid,
        partition_key=self._partition_key,
    )
    # we need to track if there's any exceptions that occur during this
    try:

        def process_result(result: dict, state: State) -> Tuple[Dict[str, Any], State]:
            # Commits the stream's final state back onto the application.
            new_state = self._update_internal_state_value(state, next_action)
            self._set_state(new_state)
            return result, new_state

        def callback(
            result: Optional[dict],
            state: State,
            exc: Optional[Exception] = None,
        ):
            # Invoked by the StreamingResultContainer when the stream
            # completes or raises -- this is where post_run_step runs.
            self._adapter_set.call_all_lifecycle_hooks_sync(
                "post_run_step",
                app_id=self._uid,
                partition_key=self._partition_key,
                action=next_action,
                state=state,
                result=result,
                sequence_id=self.sequence_id,
                exception=exc,
            )

        action_inputs = self._process_inputs(inputs, next_action)
        if not next_action.streaming:
            # In this case we are halting at a non-streaming condition
            # This is allowed as we want to maintain a more consistent API
            action, result, state = self._step(inputs=inputs, _run_hooks=False)
            self._adapter_set.call_all_lifecycle_hooks_sync(
                "post_run_step",
                app_id=self._uid,
                partition_key=self._partition_key,
                action=next_action,
                state=self._state,
                result=result,
                sequence_id=self.sequence_id,
                exception=None,
            )
            return action, StreamingResultContainer.pass_through(
                results=result, final_state=state
            )
        if next_action.single_step:
            next_action = cast(SingleStepStreamingAction, next_action)
            generator = _run_single_step_streaming_action(
                next_action, self._state, action_inputs
            )
            return next_action, StreamingResultContainer(
                generator, self._state, process_result, callback
            )
        else:
            next_action = cast(StreamingAction, next_action)
            generator = _run_multi_step_streaming_action(
                next_action, self._state, action_inputs
            )
    except Exception as e:
        # We only want to raise this in the case of an exception
        # otherwise, this will get delegated to the finally
        # block of the streaming result container
        self._adapter_set.call_all_lifecycle_hooks_sync(
            "post_run_step",
            app_id=self._uid,
            partition_key=self._partition_key,
            action=next_action,
            state=self._state,
            result=None,
            sequence_id=self.sequence_id,
            exception=e,
        )
        raise
    # Multi-step streaming path falls through to here (single-step returned above).
    return next_action, StreamingResultContainer(
        generator, self._state, process_result, callback
    )
@telemetry.capture_function_usage
async def astream_result(
    self,
    halt_after: list[str],
    halt_before: Optional[list[str]] = None,
    inputs: Optional[Dict[str, Any]] = None,
) -> Tuple[Action, AsyncStreamingResultContainer]:
    """Streams a result out in an asynchronous manner.

    :param halt_after: The list of actions to halt after execution of. It will halt on the first one.
    :param halt_before: The list of actions to halt before execution of. It will halt on the first one. Note that
        if this is met, the streaming result container will be empty (and return None) for the result, having an empty generator.
    :param inputs: Inputs to the action -- this is if this action requires an input that is passed in from the outside world
    :return: An asynchronous :py:class:`AsyncStreamingResultContainer <burr.core.action.AsyncStreamingResultContainer>`, which is a generator that will yield results as they come in, as well as cache/give you the final result,
        and update state accordingly.

    This is meant to be used with streaming actions -- :py:meth:`streaming_action <burr.core.action.streaming_action>`
    or :py:class:`StreamingAction <burr.core.action.StreamingAction>` It returns a
    :py:class:`StreamingResultContainer <burr.core.action.StreamingResultContainer>`, which has two capabilities:

    1. It is a generator that streams out the intermediate results of the action
    2. It has an async ``.get()`` method that returns the final result of the action, and the final state.

    If ``.get()`` is called before the generator is exhausted, it will block until the generator is exhausted.

    While this container is meant to work with streaming actions, it can also be used with non-streaming actions. In this case,
    the generator will be empty, and the ``.get()`` method will return the final result and state.

    The rules for halt_before and halt_after are the same as for :py:meth:`iterate <burr.core.application.Application.iterate>`,
    and :py:meth:`run <burr.core.application.Application.run>`. In this case, `halt_before` will indicate a *non* streaming action,
    which will be empty. Thus ``halt_after`` takes precedence -- if it is met, the streaming result container will contain the result of the
    halt_after condition.

    The :py:class:`AsyncStreamingResultContainer <burr.core.action.AsyncStreamingResultContainer>` is meant as a convenience -- specifically this allows for
    hooks, callbacks, etc... so you can take the control flow and still have state updated afterwards. Hooks/state update will be called after an exception
    is thrown during streaming, or the stream is completed. Note that it is undefined behavior to attempt to execute another action while a stream is in progress.

    To see how this works, let's take the following action (simplified as a single-node workflow) as an example:

    .. code-block:: python

        client = openai.AsyncClient()

        @streaming_action(reads=[], writes=['response'])
        async def streaming_response(state: State, prompt: str) -> Generator[dict, None, Tuple[dict, State]]:
            response = client.chat.completions.create(
                model='gpt-3.5-turbo',
                messages=[{
                    'role': 'user',
                    'content': prompt
                }],
                temperature=0,
            )
            buffer = []
            async for chunk in response:  # use an async for loop
                delta = chunk.choices[0].delta.content
                buffer.append(delta)
                # yield partial results
                yield {'response': delta}, None  # indicate that we are not done by returning a `None` state!
            # make sure to join with the buffer!
            full_response = ''.join(buffer)
            # yield the final result at the end + the state update
            yield {'response': full_response}, state.update(response=full_response)

    To use astream_result, you pass in names of streaming actions (such as the one above) to the halt_after
    parameter:

    .. code-block:: python

        application = ApplicationBuilder().with_actions(streaming_response=streaming_response)...build()

        prompt = "Count to 100, with a comma between each number and no newlines. E.g., 1, 2, 3, ..."
        action, streaming_result = await application.astream_result(halt_after='streaming_response', inputs={"prompt": prompt})
        async for result in streaming_result:
            print(result['response'])  # one by one

        result, state = await streaming_result.get()
        print(result['response'])  # all at once

    Note that if you have multiple halt_after conditions, you can use the ``.action`` attribute to get the action that
    was run.

    .. code-block:: python

        application = ApplicationBuilder().with_actions(
            streaming_response=streaming_response,
            error=error  # another function that outputs an error, streaming
        )...build()

        prompt = "Count to 100, with a comma between each number and no newlines. E.g., 1, 2, 3, ..."
        action, streaming_result = await application.astream_result(halt_after='streaming_response', inputs={"prompt": prompt})
        color = "red" if action.name == "error" else "green"
        async for result in streaming_result:
            print(format(result['response'], color))  # assumes that error and streaming_response both have the same output shape

    .. code-block:: python

        application = ApplicationBuilder().with_actions(
            streaming_response=streaming_response,
            error=non_streaming_error  # a non-streaming function that outputs an error
        )...build()

        prompt = "Count to 100, with a comma between each number and no newlines. E.g., 1, 2, 3, ..."
        action, streaming_result = await application.astream_result(halt_after='streaming_response', inputs={"prompt": prompt})
        color = "red" if action.name == "error" else "green"
        if action.name == "streaming_response":  # can also use the ``.streaming`` attribute of action
            async for result in streaming_result:
                print(format(result['response'], color))  # assumes that error and streaming_response both have the same output shape
        else:
            result, state = await streaming_result.get()
            print(format(result['response'], color))
    """
    halt_before, halt_after, inputs = self._clean_iterate_params(
        halt_before, halt_after, inputs
    )
    self._validate_halt_conditions(halt_before, halt_after)
    next_action = self.get_next_action()
    if next_action is None:
        raise ValueError(
            f"Cannot stream result -- no next action found! Prior action was: {self._state[PRIOR_STEP]}"
        )
    if next_action.name not in halt_after:
        # fast forward until we get to the action
        # run already handles incrementing sequence IDs, nothing to worry about here
        next_action, results, state = await self.arun(
            halt_before=halt_after + halt_before, inputs=inputs
        )
        # In this case, we are ready to halt and return an empty generator
        # The results will be None, and the state will be the final state
        # For context, this is specifically for the case in which you want to have
        # multiple terminal points with a unified API, where some are streaming, and some are not.
        if next_action.name in halt_before and next_action.name not in halt_after:
            return next_action, AsyncStreamingResultContainer.pass_through(
                results=results, final_state=state
            )
    self._increment_sequence_id()
    await self._adapter_set.call_all_lifecycle_hooks_sync_and_async(
        "pre_run_step",
        action=next_action,
        state=self._state,
        inputs=inputs,
        sequence_id=self.sequence_id,
        app_id=self._uid,
        partition_key=self._partition_key,
    )
    # we need to track if there's any exceptions that occur during this
    try:

        def process_result(result: dict, state: State) -> Tuple[Dict[str, Any], State]:
            # Commits the stream's final state back onto the application.
            new_state = self._update_internal_state_value(state, next_action)
            self._set_state(new_state)
            return result, new_state

        async def callback(
            result: Optional[dict],
            state: State,
            exc: Optional[Exception] = None,
        ):
            # Invoked by the AsyncStreamingResultContainer when the stream
            # completes or raises -- this is where post_run_step runs.
            await self._adapter_set.call_all_lifecycle_hooks_sync_and_async(
                "post_run_step",
                app_id=self._uid,
                partition_key=self._partition_key,
                action=next_action,
                state=state,
                result=result,
                sequence_id=self.sequence_id,
                exception=exc,
            )

        action_inputs = self._process_inputs(inputs, next_action)
        if not next_action.streaming:
            # In this case we are halting at a non-streaming condition
            # This is allowed as we want to maintain a more consistent API
            action, result, state = await self._astep(inputs=inputs, _run_hooks=False)
            await self._adapter_set.call_all_lifecycle_hooks_sync_and_async(
                "post_run_step",
                app_id=self._uid,
                partition_key=self._partition_key,
                action=next_action,
                state=self._state,
                result=result,
                sequence_id=self.sequence_id,
                exception=None,
            )
            return action, AsyncStreamingResultContainer.pass_through(
                results=result, final_state=state
            )
        if next_action.single_step:
            next_action = cast(SingleStepStreamingAction, next_action)
            if not next_action.is_async():
                raise ValueError(
                    f"Action: {next_action.name} is not"
                    "an async action, but is marked as streaming. "
                    "Currently we do not support running synchronous "
                    "streaming actions with astream_result, although we plan to in the future. "
                    "For now, convert the action to async. Please open up an issue if you hit this. "
                )
            generator = _arun_single_step_streaming_action(
                next_action, self._state, action_inputs
            )
            return next_action, AsyncStreamingResultContainer(
                generator, self._state, process_result, callback
            )
        else:
            if not next_action.is_async():
                raise ValueError(
                    f"Action: {next_action.name} is not"
                    "an async action, but is marked as streaming. "
                    "Currently we do not support running synchronous "
                    "streaming actions with astream_result, although we plan to in the future. "
                    "For now, convert the action to async. Please open up an issue if you hit this. "
                )
            next_action = cast(AsyncStreamingAction, next_action)
            generator = _arun_multi_step_streaming_action(
                next_action, self._state, action_inputs
            )
    except Exception as e:
        # We only want to raise this in the case of an exception
        # otherwise, this will get delegated to the finally
        # block of the streaming result container.
        # Fix: use (and await) the sync-and-async dispatch so async post_run_step
        # hooks also observe the failure -- previously this used the sync-only
        # call_all_lifecycle_hooks_sync, inconsistent with every other hook call
        # in this method, silently skipping async hooks on error.
        await self._adapter_set.call_all_lifecycle_hooks_sync_and_async(
            "post_run_step",
            app_id=self._uid,
            partition_key=self._partition_key,
            action=next_action,
            state=self._state,
            result=None,
            sequence_id=self.sequence_id,
            exception=e,
        )
        raise
    # Multi-step streaming path falls through to here (single-step returned above).
    return next_action, AsyncStreamingResultContainer(
        generator, self._state, process_result, callback
    )
@telemetry.capture_function_usage
def visualize(
    self,
    output_file_path: Optional[str] = None,
    include_conditions: bool = False,
    include_state: bool = False,
    view: bool = False,
    engine: Literal["graphviz"] = "graphviz",
    **engine_kwargs: Any,
) -> Optional["graphviz.Digraph"]:  # noqa: F821
    """Visualizes the application graph using graphviz. This will render the graph.

    This is a thin delegation to :py:meth:`Graph.visualize` on ``self.graph``.

    :param output_file_path: The path to save this to, None if you don't want to save. Do not pass an extension
        for graphviz, instead pass `format` in `engine_kwargs` (e.g. `format="png"`)
    :param include_conditions: Whether to include condition strings on the edges (this can get noisy)
    :param include_state: Whether to indicate the action "signature" (reads/writes) on the nodes
    :param view: Whether to bring up a view
    :param engine: The engine to use -- only graphviz is supported for now
    :param engine_kwargs: Additional kwargs to pass to the engine
    :return: The graphviz object
    """
    return self.graph.visualize(
        output_file_path=output_file_path,
        include_conditions=include_conditions,
        include_state=include_state,
        view=view,
        engine=engine,
        **engine_kwargs,
    )
def _set_state(self, new_state: State):
    # Internal setter: replaces the application's (immutable) state object wholesale.
    self._state = new_state
def get_next_action(self) -> Optional[Action]:
    """Resolves the next action to run from the graph, based on the prior
    step recorded in state and the configured entrypoint.

    :return: The next action, or None if there is nothing left to run.
    """
    prior_step = self._state.get(PRIOR_STEP)
    return self._graph.get_next_node(prior_step, self._state, self.entrypoint)
def update_state(self, new_state: State):
    """Updates state -- this is meant to be called if you need to do
    anything with the state. For example:

    1. Reset it (after going through a loop)
    2. Store to some external source/log out

    :param new_state: The state to install as the application's current state
    :return:
    """
    # Same effect as direct assignment -- route through the internal setter.
    self._set_state(new_state)
@property
def state(self) -> State:
    """Gives the state. Recall that state is purely immutable
    -- anything you do with this state will not be persisted unless you
    subsequently call update_state.

    :return: The current state object.
    """
    return self._state
@property
def parent_pointer(self) -> Optional[burr_types.ParentPointer]:
    """Gives the parent pointer of an application (from where it was forked).
    This is None if it was not forked.

    Forking is the process of starting an application off of another.

    :return: The parent pointer object.
    """
    return self._parent_pointer
@property
def spawning_parent_pointer(self) -> Optional[burr_types.ParentPointer]:
    """Gives the parent pointer of an application (from where it was spawned).
    This is None if it was not spawned.

    Spawning is the process of launching an application from within
    a step of another. This is used for recursive tracking.

    :return: The parent pointer object.
    """
    return self._spawning_parent_pointer
@property
def graph(self) -> ApplicationGraph:
    """Application graph object -- if you want to inspect, visualize, etc..
    this is what you want.

    :return: The application graph object
    """
    # Note: this exposes the public-facing graph, not the internal `_graph`.
    return self._public_facing_graph
@property
def sequence_id(self) -> Optional[int]:
    """Gives the sequence ID of the current (next) action.
    This is incremented prior to every step. Any logging, etc... will use the current
    step's sequence ID.

    :return: The sequence ID of the current (next) action, or None if no step
        has been run yet (the key is absent from state until the first increment).
    """
    return self._state.get(SEQUENCE_ID)
@property
def context(self) -> ApplicationContext:
    """Gives the application context.
    This has information you need for the tracker, sequence ID, application, etc...

    :return: Application context, freshly built from the next action and current sequence ID.
    """
    return self._context_factory(self.get_next_action(), self.sequence_id)
def _increment_sequence_id(self):
    """Advances the sequence counter in state: seeds it at 0 on the very
    first step, then bumps it by one on each subsequent step."""
    if SEQUENCE_ID in self._state:
        current = self._state.get(SEQUENCE_ID)
        self._state = self._state.update(**{SEQUENCE_ID: current + 1})
    else:
        self._state = self._state.update(**{SEQUENCE_ID: 0})
def _set_sequence_id(self, sequence_id: int):
    # Overwrites the sequence counter directly (used when restoring/resetting position).
    self._state = self._state.update(**{SEQUENCE_ID: sequence_id})
@property
def uid(self) -> str:
    """Unique ID for the application. This must be unique across *all* applications in a search space.
    This is used by persistence/tracking to ensure that applications can be uniquely identified.
    Every application has this -- if not assigned, it will be randomly generated.

    :return: The unique ID for the application
    """
    return self._uid
@property
def partition_key(self) -> Optional[str]:
    """Partition key for the application. This is designed to add semantic meaning to
    the application, and be leveraged by persistence systems to select/find applications.

    Note this is optional -- if it is not included, you will need to use a persister that
    supports a null partition key.

    :return: The partition key, None if not set
    """
    return self._partition_key
@property
def builder(self) -> Optional["ApplicationBuilder"]:
    """Returns the application builder that was used to build this application.
    Note that this assumes the application was built using the builder;
    otherwise this will be None.

    :return: The application builder
    """
    return self._builder
def _repr_mimebundle_(self, include=None, exclude=None, **kwargs):
    """Attribute read by notebook renderers

    This returns the mimebundle of the `graphviz.Digraph` returned by `self.visualize()`.
    The parameters `include`, `exclude`, and `**kwargs` are required by the protocol,
    but not explicitly used here -- they are forwarded to the Digraph's own renderer.

    ref: https://ipython.readthedocs.io/en/stable/config/integrating.html
    """
    dot = self.visualize(include_conditions=True, include_state=False)
    return dot._repr_mimebundle_(include=include, exclude=exclude, **kwargs)
def _validate_app_id(app_id: Optional[str]):
if app_id is None:
raise ValueError(
"App ID was None. Please ensure that you set an app ID using with_identifiers(app_id=...), or default"
"not setting it and letting the system generate one for you."
)
def _validate_start(start: Optional[str], actions: Set[str]):
    """Validates that an entrypoint was provided and names a known action.

    :param start: Name of the entrypoint action
    :param actions: Set of action names registered on the builder
    :raises ValueError: If the entrypoint is unset or not among the actions
    """
    validation.assert_set(start, "_start", "with_entrypoint")
    if start in actions:
        return
    raise ValueError(
        f"Entrypoint: {start} not found in actions. Please add "
        f"using with_actions({start}=...)"
    )
class ApplicationBuilder:
def __init__(self):
    """Creates a fresh builder with everything unset -- each ``with_*`` /
    ``initialize_from`` call fills in one piece of configuration before build."""
    # Entrypoint action name (set by with_entrypoint, or via a persister)
    self.start: Optional[str] = None
    # Initial state (set by with_state, or via a persister)
    self.state: Optional[State] = None
    self.lifecycle_adapters: List[LifecycleAdapter] = list()
    # Random default app ID -- overridden by with_identifiers(app_id=...)
    self.app_id: str = str(uuid.uuid4())
    self.partition_key: Optional[str] = None
    self.sequence_id: Optional[int] = None
    # State loader configured by initialize_from (None when not loading prior state)
    self.initializer: Optional[BaseStateLoader] = None
    self.use_entrypoint_from_save_state: Optional[bool] = None
    self.default_state: Optional[dict] = None
    # Forking -- starting this application off of another one's persisted state
    self.fork_from_app_id: Optional[str] = None
    self.fork_from_partition_key: Optional[str] = None
    self.fork_from_sequence_id: Optional[int] = None
    # Spawning -- launching this application from within a step of another
    self.spawn_from_app_id: Optional[str] = None
    self.spawn_from_partition_key: Optional[str] = None
    self.spawn_from_sequence_id: Optional[int] = None
    self.loaded_from_fork: bool = False
    # Tracker set by with_tracker (string-specified client or adapter object)
    self.tracker = None
    # Exactly one of graph_builder (with_actions/with_transitions) or
    # prebuilt_graph (with_graph) ends up populated.
    self.graph_builder: Optional[GraphBuilder] = None
    self.prebuilt_graph: Optional[Graph] = None
def with_identifiers(
    self,
    app_id: Optional[str] = None,
    partition_key: Optional[str] = None,
    sequence_id: Optional[int] = None,
) -> "ApplicationBuilder":
    """Assigns various identifiers to the application. This is used for tracking, persistence, etc...

    :param app_id: Application ID -- this will be assigned to a uuid if not set.
    :param partition_key: Partition key -- this is used for disambiguating groups of applications. For instance, a unique user ID, etc...
        This is coupled to persistence, and is used to query for/select application runs.
    :param sequence_id: Sequence ID that we want this to start at. If you're using ``.initialize``, this will be set. Otherwise this is
        solely for resetting/starting at a specified position.
    :return: The application builder for future chaining.
    """
    # Each identifier is only overwritten when explicitly provided, so the
    # defaults set in __init__ (e.g. the random app_id) survive otherwise.
    if app_id is not None:
        self.app_id = app_id
    if partition_key is not None:
        self.partition_key = partition_key
    if sequence_id is not None:
        self.sequence_id = sequence_id
    return self
def with_state(self, **kwargs) -> "ApplicationBuilder":
    """Sets initial values in the state. If you want to load from a prior state,
    you can do so here and pass the values in.

    TODO -- enable passing in a `state` object instead of `**kwargs`

    :param kwargs: Key-value pairs to set in the state
    :return: The application builder for future chaining.
    :raises ValueError: If ``initialize_from`` has already been configured
        (state would then come from the persister, making this ambiguous).
    """
    if self.initializer is not None:
        # Fixed error message: the implicit string join previously rendered
        # "...loading statethe .initialize_from() API..." (missing word).
        raise ValueError(
            BASE_ERROR_MESSAGE + "You cannot set state if you are loading state using "
            "the .initialize_from() API. Either allow the persister to set the "
            "state, or set the state manually."
        )
    # Merge into any state already set by a prior with_state call.
    if self.state is not None:
        self.state = self.state.update(**kwargs)
    else:
        self.state = State(kwargs)
    return self
def with_graph(self, graph: Graph) -> "ApplicationBuilder":
    """Adds a prebuilt graph -- this is an alternative to using the with_actions and with_transitions methods.
    While you will likely use with_actions and with_transitions, you may want this in a few cases:

    1. You want to reuse the same graph object for different applications
    2. You want the logic that constructs the graph to be separate from that which constructs the application
    3. You want to serialize/deserialize a graph object and run it in an application

    :param graph: Graph object built with the :py:class:`GraphBuilder <burr.core.graph.GraphBuilder>`
    :return: The application builder for future chaining.
    :raises ValueError: If with_actions/with_transitions were already used.
    """
    # Mutually exclusive with the incremental graph-building APIs.
    if self.graph_builder is None:
        self.prebuilt_graph = graph
        return self
    raise ValueError(
        BASE_ERROR_MESSAGE
        + "You have already called `with_actions`, or `with_transitions` -- you currently "
        "cannot use the with_graph method along with that. Use `with_graph` or the other methods, not both"
    )
def _ensure_no_prebuilt_graph(self):
    """Guard used by the incremental graph APIs: raises if `with_graph`
    was already called, since the two modes are mutually exclusive."""
    if self.prebuilt_graph is None:
        return self
    raise ValueError(
        BASE_ERROR_MESSAGE + "You have already called `with_graph` -- you currently "
        "cannot use the with_actions, or with_transitions method along with that. "
        "Use `with_graph` or the other methods, not both."
    )
def _initialize_graph_builder(self):
    """Lazily creates the underlying GraphBuilder on first use."""
    if self.graph_builder is not None:
        return
    self.graph_builder = GraphBuilder()
def with_entrypoint(self, action: str) -> "ApplicationBuilder":
    """Adds an entrypoint to the application. This is the action that will be run first.
    This can only be called once.

    :param action: The name of the action to set as the entrypoint
    :return: The application builder for future chaining.
    :raises ValueError: If the entrypoint was already set (by a prior call,
        or by a persister via ``initialize_from``).
    """
    if self.start is not None:
        # Fixed error message: the guard fires both when with_entrypoint is
        # called twice and when a persister already set the entrypoint, but the
        # old message only described the persister case.
        raise ValueError(
            BASE_ERROR_MESSAGE
            + "The entrypoint is already set -- either you have called "
            "with_entrypoint more than once, or you are loading a persister using "
            "the .initialize_from() API. Either allow the persister to set the "
            "entrypoint/provide a default, or set the entrypoint + state manually (once)."
        )
    self.start = action
    return self
def with_actions(
    self, *action_list: Union[Action, Callable], **action_dict: Union[Action, Callable]
) -> "ApplicationBuilder":
    """Adds actions to the application. Positional arguments must already carry a
    name (or be function-based actions, whose function name is used); keyword
    arguments are named by their key. This is the only supported way to add actions.

    :param action_list: Actions to add -- these must have a name or be function-based (in which case we will use the function-name)
    :param action_dict: Actions to add, keyed by name
    :return: The application builder for future chaining.
    """
    # Incremental graph building is incompatible with with_graph.
    self._ensure_no_prebuilt_graph()
    self._initialize_graph_builder()
    updated_builder = self.graph_builder.with_actions(*action_list, **action_dict)
    self.graph_builder = updated_builder
    return self
def with_transitions(
    self,
    *transitions: Union[
        Tuple[Union[str, list[str]], str], Tuple[Union[str, list[str]], str, Condition]
    ],
) -> "ApplicationBuilder":
    """Adds transitions to the application. Transitions are specified as tuples of either:

    1. (from, to, condition)
    2. (from, to) -- condition is set to DEFAULT (which is a fallback)

    Transitions will be evaluated in order of specification -- if one is met, the others will not be evaluated.

    :param transitions: Transitions to add
    :return: The application builder for future chaining.
    """
    # Incremental graph building is incompatible with with_graph.
    self._ensure_no_prebuilt_graph()
    self._initialize_graph_builder()
    updated_builder = self.graph_builder.with_transitions(*transitions)
    self.graph_builder = updated_builder
    return self
def with_hooks(self, *adapters: LifecycleAdapter) -> "ApplicationBuilder":
    """Registers lifecycle adapters on the application so their hooks run at
    the appropriate times -- e.g. to synchronize state out or log results.

    :param adapters: Adapters to add
    :return: The application builder for future chaining.
    """
    for adapter in adapters:
        self.lifecycle_adapters.append(adapter)
    return self
def with_tracker(
self,
tracker: Union[Literal["local"], "TrackingClient"] = "local",
project: str = "default",
params: Dict[str, Any] = None,
):
"""Adds a "tracker" to the application. The tracker specifies
a project name (used for disambiguating groups of tracers), and plugs into the
Burr UI. This can either be:
1. A string (the only supported one right now is "local"), and a set of parameters for a set of supported trackers.
2. A lifecycle adapter object that does tracking (up to you how to implement it).
(1) internally creates a :py:class:`LocalTrackingClient <burr.tracking.client.LocalTrackingClient>` object, and adds it to the lifecycle adapters.
(2) adds the lifecycle adapter to the lifecycle adapters.
:param tracker: Tracker to use. ``local`` creates one, else pass one in.
:param project: Project name -- used if the tracker is string-specified (local).
:param params: Parameters to pass to the tracker if it's string-specified (local).
:return: The application builder for future chaining.
"""
# if it's a lifecycle adapter, just add it
instantiated_tracker = tracker
if isinstance(tracker, str):
if params is None:
params = {}
if tracker == "local":
from burr.tracking.client import LocalTrackingClient
kwargs = {"project": project}
kwargs.update(params)
instantiated_tracker = LocalTrackingClient(**kwargs)
self.lifecycle_adapters.append(instantiated_tracker)
else:
raise ValueError(f"Tracker {tracker}:{project} not supported")
else:
self.lifecycle_adapters.append(instantiated_tracker)
if params is not None:
raise ValueError(
"Params are not supported for object-specified trackers, these are already initialized!"
)
self.tracker = instantiated_tracker
return self
def initialize_from(
self,
initializer: BaseStateLoader,
resume_at_next_action: bool,
default_state: dict,
default_entrypoint: str,
fork_from_app_id: str = None,
fork_from_partition_key: str = None,
fork_from_sequence_id: int = None,
) -> "ApplicationBuilder":
"""Initializes the application we will build from some prior state object.
Note (1) that you can *either* call this or use `with_state` and `with_entrypoint`.
Note (2) if you want to continue a prior application and don't want to fork it into a new application ID,
the values in `.with_identifiers()` will be used to query for prior state.
:param initializer: The persister object to use for initialization. Likely the same one called with ``with_state_persister``.
:param resume_at_next_action: Whether to resume at the next action, or default to the ``default_entrypoint``
:param default_state: The default state to use if it does not exist. This is a dictionary.
:param default_entrypoint: The default entry point to use if it does not exist or you elect not to resume_at_next_action.
:param fork_from_app_id: The app ID to fork from, not to be confused with the current app_id that is set with `.with_identifiers()`. This is used to fork from a prior application run.
:param fork_from_partition_key: The partition key to fork from a prior application. Optional. `fork_from_app_id` required.
:param fork_from_sequence_id: The sequence ID to fork from a prior application run. Optional, defaults to latest. `fork_from_app_id` required.
:return: The application builder for future chaining.
"""
if self.start is not None or self.state is not None:
raise ValueError(
BASE_ERROR_MESSAGE
+ "Cannot call initialize_from if you have already set state or an entrypoint! "
"You can either use the initializer *or* set the state and entrypoint manually."
)
if not fork_from_app_id and (fork_from_partition_key or fork_from_sequence_id):
raise ValueError(
BASE_ERROR_MESSAGE
+ "If you set fork_from_partition_key or fork_from_sequence_id, you must also set fork_from_app_id. "
"See .initialize_from() documentation."
)
self.initializer = initializer
self.resume_at_next_action = resume_at_next_action
self.default_state = default_state
self.start = default_entrypoint
self.fork_from_app_id = fork_from_app_id
self.fork_from_partition_key = fork_from_partition_key
self.fork_from_sequence_id = fork_from_sequence_id
return self
def with_state_persister(
self, persister: Union[BaseStateSaver, LifecycleAdapter], on_every: str = "step"
) -> "ApplicationBuilder":
"""Adds a state persister to the application. This is a way to persist state out to a database, file, etc...
at the specified interval. This is one of two options:
1. [normal mode] A BaseStateSaver object -- this is a utility class that makes it easy to save/load
2. [power-user-mode] A lifecycle adapter -- this is a custom class that you use to save state.
The framework will wrap the BaseStateSaver object in a PersisterHook, which is a post-run.
:param persister: The persister to add
:param on_every: The interval to persist state. Currently only "step" is supported.
:return: The application builder for future chaining.
"""
if on_every != "step":
raise ValueError(f"on_every {on_every} not supported")
if not isinstance(persister, persistence.BaseStateSaver):
self.lifecycle_adapters.append(persister)
else:
self.lifecycle_adapters.append(persistence.PersisterHook(persister))
return self
def with_spawning_parent(
self, app_id: str, sequence_id: int, partition_key: Optional[str] = None
) -> "ApplicationBuilder":
"""Sets the 'spawning' parent application that created this app.
This is used for tracking purposes. Doing this creates a parent/child relationship.
There can be many spawned children from a single sequence ID (just as there can be many forks of an app).
Note the difference between this and forking. Forking allows you to create a new app
where the old one left off. This suggests that this application is wholly contained
within the parent application.
:param app_id: ID of application that spawned this app
:param sequence_id: Sequence ID of the parent app that spawned this app
:param partition_key: Partition key of the parent app that spawned this app
:return: The application builder for future chaining.
"""
self.spawn_from_app_id = app_id
self.spawn_from_sequence_id = sequence_id
self.spawn_from_partition_key = partition_key
return self
def _load_from_persister(self):
"""Loads from the set persister and into this current object.
Mutates:
- self.state
- self.sequence_id
- maybe self.start
"""
if self.fork_from_app_id is not None:
if self.app_id == self.fork_from_app_id:
raise ValueError(
BASE_ERROR_MESSAGE + "Cannot fork and save to the same app_id. "
"Please update the app_id passed in via with_identifiers(), "
"or don't pass in a fork_from_app_id value to `initialize_from()`."
)
_partition_key = self.fork_from_partition_key
_app_id = self.fork_from_app_id
_sequence_id = self.fork_from_sequence_id
else:
# only use the with_identifier values if we're not forking from a previous app
_partition_key = self.partition_key
_app_id = self.app_id
_sequence_id = self.sequence_id
# load state from persister
load_result = self.initializer.load(_partition_key, _app_id, _sequence_id)
if load_result is None:
if self.fork_from_app_id is not None:
logger.warning(
f"{self.initializer.__class__.__name__} returned None while trying to fork from: "
f"partition_key:{_partition_key}, app_id:{_app_id}, "
f"sequence_id:{_sequence_id}. "
"You explicitly requested to fork from a prior application run, but it does not exist. "
"Defaulting to state defaults instead."
)
# there was nothing to load -- use default state
self.state = self.state.update(**self.default_state)
self.sequence_id = None # has to start at None
else:
self.loaded_from_fork = True
if load_result["state"] is None:
raise ValueError(
BASE_ERROR_MESSAGE
+ f"Error: {self.initializer.__class__.__name__} returned {load_result} for "
f"partition_key:{_partition_key}, app_id:{_app_id}, "
f"sequence_id:{_sequence_id}, "
"but value for state was None! This is not allowed. Please return just None in this case, "
"or double check that persisted state can never be a None value."
)
# TODO: capture parent app ID relationship & wire it through
# there was something
last_position = load_result["position"]
self.state = load_result["state"]
self.sequence_id = load_result["sequence_id"]
status = load_result["status"]
if self.resume_at_next_action:
# if we're supposed to resume where we saved from
if status == "completed":
# completed means we set prior step to current to go to next action
self.state = self.state.update(**{PRIOR_STEP: last_position})
else:
# else we failed we just start at that node
self.start = last_position
self.reset_to_entrypoint()
else:
# self.start is already set to the default. We don't need to do anything.
pass
def reset_to_entrypoint(self):
self.state = self.state.wipe(delete=[PRIOR_STEP])
def _get_built_graph(self) -> Graph:
if self.graph_builder is None and self.prebuilt_graph is None:
raise ValueError(
BASE_ERROR_MESSAGE
+ "You must set the graph using with_graph, or use with_entrypoint, with_actions, and with_transitions"
"to build the graph."
)
if self.graph_builder is not None:
return self.graph_builder.build()
return self.prebuilt_graph
@telemetry.capture_function_usage
def build(self) -> Application:
"""Builds the application.
This function is a bit messy as we iron out the exact logic and rigor we want around things.
:return: The application object
"""
_validate_app_id(self.app_id)
if self.state is None:
self.state = State()
if self.initializer:
# sets state, sequence_id, and maybe start
self._load_from_persister()
graph = self._get_built_graph()
_validate_start(self.start, {action.name for action in graph.actions})
return Application(
graph=graph,
state=self.state,
uid=self.app_id,
partition_key=self.partition_key,
sequence_id=self.sequence_id,
entrypoint=self.start,
adapter_set=LifecycleAdapterSet(*self.lifecycle_adapters),
builder=self,
fork_parent_pointer=burr_types.ParentPointer(
app_id=self.fork_from_app_id,
partition_key=self.fork_from_partition_key,
sequence_id=self.fork_from_sequence_id,
)
if self.loaded_from_fork
else None,
tracker=self.tracker,
spawning_parent_pointer=burr_types.ParentPointer(
app_id=self.spawn_from_app_id,
partition_key=self.spawn_from_partition_key,
sequence_id=self.spawn_from_sequence_id,
)
if self.spawn_from_app_id is not None
else None,
)