blob: ec0511d8b00338518677c4c1a08d2641c4b48894 [file]
import pydantic
from langchain_core import documents
from burr import core
from burr.core import State, action, expr, serde, state
from burr.tracking import client as tracking_client
# Say we have a custom class that we want to serialize
# and deserialize using custom serde.
class CustomClass(object):
    """Custom class we'll use to test custom serde.

    Wraps a single ``value``; two instances compare equal when their values are equal.
    """

    # Defining __eq__ already makes instances unhashable; say so explicitly.
    __hash__ = None

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Required for asserts etc. to work (e.g. comparing state before/after a serde round trip).
        # For foreign types, return NotImplemented instead of raising AttributeError on `other.value`.
        if not isinstance(other, CustomClass):
            return NotImplemented
        return self.value == other.value

    def __repr__(self):
        return f"{type(self).__name__}(value={self.value!r})"
# create a custom serializer for CustomClass and register it with burr's serde dispatch
@serde.serialize.register(CustomClass)
def serialize_customclass(value: CustomClass, **kwargs) -> dict:
    """Turn a CustomClass into a plain dictionary.

    :param value: the CustomClass instance to serialize.
    :param kwargs: additional keyword arguments (unused here).
    :return: dict holding the serde.KEY tag and the encoded value.
    """
    # The tag must match the name registered for the deserializer ("CustomClass").
    tag = serde.KEY
    # Any encoding works, as long as the deserializer can undo it.
    encoded = f"[value=={value.value}]"
    return {tag: "CustomClass", "value": encoded}
# create custom deserializer for the custom class & register it
@serde.deserializer.register("CustomClass")
def deserialize_customclass(value: dict, **kwargs) -> CustomClass:
    """Deserializes the value produced by ``serialize_customclass`` back into a CustomClass.

    :param value: the serialized dict; its ``"value"`` entry has the form ``"[value==<payload>]"``.
    :param kwargs: additional keyword arguments (unused here).
    :return: CustomClass wrapping ``<payload>``. The payload is always a ``str``, because the
        serializer formats ``value.value`` into a string.
    """
    encoded = value["value"]
    # Split on the FIRST "==" only. Otherwise a payload that itself contains "==" (e.g. "a==b")
    # would be truncated. [:-1] drops the closing "]" written by the serializer.
    payload = encoded.split("==", 1)[1][:-1]
    return CustomClass(value=payload)
# custom serde for the "custom_field" state field: serializer half
def my_field_serializer(value: documents.Document, **kwargs) -> dict:
    """Encode a Document's page content as a string tagged with a ``serialized::`` marker.

    :param value: the Document stored in ``custom_field``.
    :param kwargs: additional keyword arguments (unused here).
    :return: ``{"value": "serialized::<page_content>"}``.
    """
    tagged = f"serialized::{value.page_content}"
    return dict(value=tagged)
def my_field_deserializer(value: dict, **kwargs) -> documents.Document:
    """Rebuild the Document written by ``my_field_serializer``.

    :param value: dict whose ``"value"`` entry is ``"serialized::<page_content>"``.
    :param kwargs: additional keyword arguments (unused here).
    :return: Document with the original page content.
    """
    serde_value = value["value"]
    # Strip only the leading marker. str.replace would also delete any "serialized::"
    # that happens to occur inside the page content itself.
    return documents.Document(page_content=serde_value.removeprefix("serialized::"))


# Attach the field-level serde pair to the "custom_field" state key.
state.register_field_serde("custom_field", my_field_serializer, my_field_deserializer)
# --- define the actions
@action(reads=[], writes=["dict"])
def basic_action(state: State, user_input: str) -> tuple[dict, State]:
    """Store a small dict of mixed values (including a CustomClass) under the "dict" key.

    :param state: current application state (nothing is read from it).
    :param user_input: run input; echoed into the dict under "input".
    :return: the dict as the result, and the state updated with it.
    """
    payload = {"foo": 1}
    payload["bar"] = CustomClass("example value")
    payload["bool"] = True
    payload["None"] = None
    payload["input"] = user_input
    result = {"dict": payload}
    return result, state.update(dict=payload)
class PydanticField(pydantic.BaseModel):
    """A pydantic model kept in state; burr handles serializing custom pydantic fields."""

    f1: int = pydantic.Field(default=0)
    f2: bool = pydantic.Field(default=False)
@action(reads=["dict"], writes=["pydantic_field"])
def pydantic_action(state: State) -> tuple[dict, State]:
    """Build a PydanticField from the "foo" and "bool" entries of state["dict"]."""
    source = state["dict"]
    model = PydanticField(f1=source["foo"], f2=source["bool"])
    result = {"pydantic_field": model}
    return result, state.update(pydantic_field=model)
@action(reads=["pydantic_field"], writes=["lc_doc"])
def langchain_action(state: State) -> tuple[dict, State]:
    """Render the PydanticField's values into a langchain Document stored as "lc_doc"."""
    model = state["pydantic_field"]
    content = f"foo: {model.f1}, bar: {model.f2}"
    doc = documents.Document(page_content=content)
    result = {"lc_doc": doc}
    return result, state.update(lc_doc=doc)
@action(reads=["lc_doc"], writes=[])
def terminal_action(state: State) -> tuple[dict, State]:
    """Expose the Document's page content as "output"; the state is returned unchanged."""
    page = state["lc_doc"].page_content
    output = {"output": page}
    return output, state
# build the application
def build_application(partition_key, app_id):
    """Assemble the serde example application.

    The app is tracked by a LocalTrackingClient for project "serde-example". It is initialized
    from that tracker (presumably loading any persisted state for these identifiers — confirm
    against burr's ``initialize_from`` docs). It falls back to a default state holding
    ``custom_field`` and to ``basic_action`` as the entrypoint.

    :param partition_key: partition key identifier for the app.
    :param app_id: application id.
    :return: the built application.
    """
    tracker = tracking_client.LocalTrackingClient("serde-example")
    transitions = (
        # With basic_action as written, dict['foo'] is always 1, so this branch is not taken.
        ("basic_action", "terminal_action", expr("dict['foo'] == 0")),
        ("basic_action", "pydantic_action"),
        ("pydantic_action", "langchain_action"),
        ("langchain_action", "terminal_action"),
    )
    starting_state = {
        "custom_field": documents.Document(page_content="this is a custom field to serialize")
    }

    builder = core.ApplicationBuilder()
    builder = builder.with_actions(basic_action, pydantic_action, langchain_action, terminal_action)
    builder = builder.with_transitions(*transitions)
    builder = builder.with_identifiers(partition_key=partition_key, app_id=app_id)
    builder = builder.initialize_from(
        tracker,
        resume_at_next_action=True,
        default_state=starting_state,
        default_entrypoint="basic_action",
    )
    builder = builder.with_tracker(tracker)
    return builder.build()
if __name__ == "__main__":
    # Note: if you try to use the test case creation command line on the result of
    # running this application from this file, it will fail. That is because the serde
    # functions above were registered while running as the `__main__` module. Either
    # move the serde registrations into a separate module that can be imported
    # directly, or run the code below from another file (e.g. run.py).
    import pprint
    import uuid

    # build
    app = build_application("client-123", str(uuid.uuid4()))
    app.visualize(
        output_file_path="statemachine", include_conditions=True, include_state=True, format="png"
    )
    # run -- these names deliberately avoid rebinding the module-level `action`
    # decorator and `state` module imported at the top of the file
    last_action, result, final_state = app.run(
        halt_after=["terminal_action"], inputs={"user_input": "hello world"}
    )
    # serialize
    serialized_state = final_state.serialize()
    pprint.pprint(serialized_state)
    # deserialize
    deserialized_state = State.deserialize(serialized_state)
    # Check that the state is unchanged by the round trip. This is an explicit check
    # because a bare assert is stripped when running under `python -O`.
    if final_state.get_all() != deserialized_state.get_all():
        raise RuntimeError("state changed after a serialize/deserialize round trip")