blob: 6d7a838aac64c965c7a269e194b241eea57d7b08 [file]
import pydantic
from langchain_core import documents
from burr import core
from burr.core import State, action, expr, persistence, state
from burr.tracking import client as tracking_client
@action(reads=[], writes=["dict"])
def basic_action(state: State, user_input: str) -> tuple[dict, State]:
    """Write a fixed sample dict, plus the caller's input, to state key ``dict``.

    The payload mixes an int, a str, a bool and ``None`` alongside
    ``user_input``.

    Returns:
        A ``(result, new_state)`` pair, where ``result`` is
        ``{"dict": payload}`` and ``new_state`` has ``dict`` set to the same
        payload.
    """
    payload = {
        "foo": 1,
        "bar": "2",
        "bool": True,
        "None": None,
        "input": user_input,
    }
    updated_state = state.update(dict=payload)
    return {"dict": payload}, updated_state
class PydanticField(pydantic.BaseModel):
    """Small pydantic model stored in state by ``pydantic_action``."""

    # Integer field; defaults to 0 when not supplied.
    f1: int = pydantic.Field(default=0)
    # Boolean field; defaults to False when not supplied.
    f2: bool = pydantic.Field(default=False)
@action(reads=["dict"], writes=["pydantic_field"])
def pydantic_action(state: State) -> tuple[dict, State]:
    """Build a ``PydanticField`` from the ``dict`` entry in state.

    ``f1`` is taken from ``state["dict"]["foo"]`` and ``f2`` from
    ``state["dict"]["bool"]``.

    Returns:
        A ``(result, new_state)`` pair; both carry the model under the key
        ``pydantic_field``.
    """
    source = state["dict"]
    model = PydanticField(f1=source["foo"], f2=source["bool"])
    result = {"pydantic_field": model}
    return result, state.update(**result)
@action(reads=["pydantic_field"], writes=["lc_doc"])
def langchain_action(state: State) -> tuple[dict, State]:
    """Render the pydantic model in state into a langchain ``Document``.

    The document's ``page_content`` is ``"foo: <f1>, bar: <f2>"`` using the
    ``pydantic_field`` entry in state.

    Returns:
        A ``(result, new_state)`` pair; both carry the document under the key
        ``lc_doc``.
    """
    model = state["pydantic_field"]
    text = f"foo: {model.f1}, bar: {model.f2}"
    doc = documents.Document(page_content=text)
    return {"lc_doc": doc}, state.update(lc_doc=doc)
@action(reads=["lc_doc"], writes=[])
def terminal_action(state: State) -> tuple[dict, State]:
    """Expose the stored document's text as the final output.

    Returns:
        A ``(result, state)`` pair where ``result`` is
        ``{"output": <page_content of state["lc_doc"]>}`` and ``state`` is
        returned unmodified.
    """
    final_doc = state["lc_doc"]
    result = {"output": final_doc.page_content}
    return result, state
def build_application(sqllite_persister, tracker, partition_key, app_id):
    """Assemble the four-action test application.

    Args:
        sqllite_persister: Optional persister. When truthy it is used both as
            the source for ``initialize_from`` and as the state persister.
        tracker: Optional tracker. It is used as the ``initialize_from``
            source when ``sqllite_persister`` is falsy, and is attached via
            ``with_tracker`` when truthy.
        partition_key: Partition key passed to ``with_identifiers``.
        app_id: Application id passed to ``with_identifiers``.

    Returns:
        The result of calling ``build()`` on the configured builder.
    """
    # Prefer the sqlite persister as the source to initialize from; otherwise
    # fall back to the tracker.
    initializer = sqllite_persister if sqllite_persister else tracker

    # Default state carries a Document so the custom field serde is exercised.
    initial_state = {
        "custom_field": documents.Document(
            page_content="this is a custom field to serialize"
        )
    }

    transitions = (
        ("basic_action", "terminal_action", expr("dict['foo'] == 0")),
        ("basic_action", "pydantic_action"),
        ("pydantic_action", "langchain_action"),
        ("langchain_action", "terminal_action"),
    )

    builder = core.ApplicationBuilder()
    builder = builder.with_actions(
        basic_action, pydantic_action, langchain_action, terminal_action
    )
    builder = builder.with_transitions(*transitions)
    builder = builder.with_identifiers(partition_key=partition_key, app_id=app_id)
    builder = builder.initialize_from(
        initializer,
        resume_at_next_action=True,
        default_state=initial_state,
        default_entrypoint="basic_action",
    )

    # NOTE(review): the return values of the two calls below are discarded,
    # so this relies on the builder being configured in place; confirm
    # against the builder API.
    if sqllite_persister:
        builder.with_state_persister(sqllite_persister)
    if tracker:
        builder.with_tracker(tracker)

    return builder.build()
def register_custom_serde_for_lc_document(field_name: str):
    """Register a custom serializer/deserializer pair for a state field.

    The serializer wraps a langchain ``Document`` as
    ``{"value": "serialized::<page_content>"}``; the deserializer reverses
    that by stripping the leading marker and rebuilding the ``Document``.

    Args:
        field_name: Name of the state field the serde pair is registered for.
    """
    marker = "serialized::"

    def my_field_serializer(value: documents.Document, **kwargs) -> dict:
        """Serialize a Document to a dict holding the marked page content."""
        serde_value = f"{marker}{value.page_content}"
        return {"value": serde_value}

    def my_field_deserializer(value: dict, **kwargs) -> documents.Document:
        """Rebuild a Document from the dict produced by the serializer."""
        serde_value = value["value"]
        # Strip only the leading marker. Using ``replace`` here would also
        # delete any "serialized::" inside the original page content, which
        # would make the round trip lossy.
        return documents.Document(page_content=serde_value.removeprefix(marker))

    state.register_field_serde(field_name, my_field_serializer, my_field_deserializer)
def test_whole_application_tracker(tmp_path):
    """Step through the application, rebuilding it from the tracker each time.

    Each step starts from a freshly built application, so state has to be
    serialized and deserialized between steps for the run to progress.
    """
    tracker = tracking_client.LocalTrackingClient("integration-test", tmp_path)
    app_id = "integration-test"
    partition_key = ""

    # step 1 -- the only step that needs an input
    app = build_application(None, tracker, partition_key, app_id)
    register_custom_serde_for_lc_document("custom_field")
    action1, _, state1 = app.step(inputs={"user_input": "hello"})
    # check custom serde
    expected_serialized = {
        "__PRIOR_STEP": "basic_action",
        "__SEQUENCE_ID": 0,
        "custom_field": {"value": "serialized::this is a custom field to serialize"},
        "dict": {"None": None, "bar": "2", "bool": True, "foo": 1, "input": "hello"},
    }
    assert state1.serialize() == expected_serialized
    assert action1.name == "basic_action"

    # steps 2-4 -- rebuild the application before every step
    later_states = []
    for expected_action in ("pydantic_action", "langchain_action", "terminal_action"):
        app = build_application(None, tracker, partition_key, app_id)
        stepped_action, _, stepped_state = app.step()
        assert stepped_action.name == expected_action
        later_states.append(stepped_state)
    state2, state3, state4 = later_states

    # assert that state is basically the same across different steps
    expected_dict = {"foo": 1, "bar": "2", "bool": True, "None": None, "input": "hello"}
    assert state1["dict"] == expected_dict
    assert state1["dict"] == state4["dict"]
    assert state2["pydantic_field"].f1 == 1
    assert state2["pydantic_field"].f2 is True
    assert state2["pydantic_field"] == state3["pydantic_field"]
    assert state3["lc_doc"].page_content == "foo: 1, bar: True"
    assert state3["lc_doc"] == state4["lc_doc"]

    # assert that tracker has things in it too
    final_tracker_state = tracker.load(partition_key, app_id=app_id)
    for key, tracked_value in final_tracker_state["state"].items():
        assert tracked_value == state4[key]
def test_whole_application_sqllite(tmp_path):
    """Step through the application, rebuilding it from SQLite each time.

    Each step starts from a freshly built application, so state has to be
    serialized and deserialized between steps for the run to progress.
    """
    db_path = tmp_path / "test.db"
    sqllite_persister = persistence.SQLLitePersister(db_path)
    sqllite_persister.initialize()
    app_id = "integration-test"
    partition_key = ""

    # step 1 -- the only step that needs an input
    app = build_application(sqllite_persister, None, partition_key, app_id)
    register_custom_serde_for_lc_document("custom_field")
    action1, _, state1 = app.step(inputs={"user_input": "hello"})
    # check custom serde
    expected_serialized = {
        "__PRIOR_STEP": "basic_action",
        "__SEQUENCE_ID": 0,
        "custom_field": {"value": "serialized::this is a custom field to serialize"},
        "dict": {"None": None, "bar": "2", "bool": True, "foo": 1, "input": "hello"},
    }
    assert state1.serialize() == expected_serialized
    # check actions
    assert action1.name == "basic_action"

    # steps 2-4 -- rebuild the application before every step
    later_states = []
    for expected_action in ("pydantic_action", "langchain_action", "terminal_action"):
        app = build_application(sqllite_persister, None, partition_key, app_id)
        stepped_action, _, stepped_state = app.step()
        assert stepped_action.name == expected_action
        later_states.append(stepped_state)
    state2, state3, state4 = later_states

    # assert that state is basically the same across different steps
    expected_dict = {"foo": 1, "bar": "2", "bool": True, "None": None, "input": "hello"}
    assert state1["dict"] == expected_dict
    assert state1["dict"] == state4["dict"]
    assert state2["pydantic_field"].f1 == 1
    assert state2["pydantic_field"].f2 is True
    assert state2["pydantic_field"] == state3["pydantic_field"]
    assert state3["lc_doc"].page_content == "foo: 1, bar: True"
    assert state3["lc_doc"] == state4["lc_doc"]

    # the persisted state must match the state returned by the final step
    final_sqllite_state = sqllite_persister.load("", app_id=app_id)
    assert final_sqllite_state["state"] == state4
    assert sqllite_persister.list_app_ids(partition_key="") == ["integration-test"]