{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "af165506-cc16-485c-8875-f1ec55dd94b5",
"metadata": {},
"outputs": [],
"source": [
"!pip install burr[start] langchain_core langchain_community pydantic\n",
"# we only need these to show case the automatic serialization and deserialization that Burr has."
]
},
{
"cell_type": "markdown",
"id": "944a988132d19779",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"# Create a custom serializer and deserializer\n",
"\n",
"In this notebook you'll find an explanation of how to create a custom serializer and deserializer for state that Burr captures.\n",
"\n",
"Specifically there are four things to grok:\n",
"\n",
"1. Burr has some default type based serializers and deserializers that it uses to serialize and deserialize state. \n",
"2. You can register type based serializers and deserializers with Burr.\n",
"3. You can register state field level serializers and deserializers with Burr.\n",
"4. Burr assumes that the end result of the serialization and deserialization is a dictionary that can be JSON serialized and deserialized."
]
},
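{
"cell_type": "markdown",
"id": "serde-json-roundtrip-note",
"metadata": {},
"source": [
"Point 4 can be checked without Burr. The cell below is a minimal sketch that uses only the standard library `json` module. The dictionary `example_serialized_state` is a hypothetical stand-in, shaped like the output a serializer is expected to produce (the `__burr_serde__` entry mirrors the key that shows up in the serialized state printed later in this notebook). It shows that such a dictionary survives a JSON round trip, while a raw Python object does not."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "serde-json-roundtrip-code",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"# Hypothetical serialized state: only dicts, lists, strings, numbers, booleans and None.\n",
"example_serialized_state = {\n",
"    \"counter\": 1,\n",
"    \"flag\": True,\n",
"    \"nothing\": None,\n",
"    \"custom\": {\"__burr_serde__\": \"CustomClass\", \"value\": \"[value==example value]\"},\n",
"}\n",
"\n",
"# json.dumps raises TypeError if any value is not JSON serializable.\n",
"as_json = json.dumps(example_serialized_state)\n",
"round_tripped = json.loads(as_json)\n",
"assert round_tripped == example_serialized_state\n",
"\n",
"# A raw Python object cannot be dumped directly, which is why it needs a serializer first.\n",
"try:\n",
"    json.dumps({\"custom\": object()})\n",
"    raw_object_is_json = True\n",
"except TypeError:\n",
"    raw_object_is_json = False\n",
"\n",
"print(raw_object_is_json)"
]
},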
{
"cell_type": "code",
"execution_count": 1,
"id": "a727cdc01303cd39",
"metadata": {
"ExecuteTime": {
"end_time": "2024-06-03T21:26:39.555182Z",
"start_time": "2024-06-03T21:26:39.549366Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"# import what we need\n",
"import pprint\n",
"import uuid\n",
"import pydantic\n",
"from langchain_core import documents\n",
"\n",
"from burr import core\n",
"from burr.core import State, action, expr, state, serde\n",
"from burr.tracking import client as tracking_client\n"
]
},
{
"cell_type": "markdown",
"id": "9a8cc495ca2cde46",
"metadata": {},
"source": [
"# Building a custom serializer and deserializer for a class type\n",
"Say we have have a custom class that we want to serialize and deserialize using custom serde.\n",
"\n",
"With Burr we just register a type based serializer and deserializer for the class.\n",
"\n",
"1. We first define the class we want to serialize and deserialize.\n",
"2. Then we define a serializer and deserializer for the class -- using the `@serde.serialize.register` and `@serde.deserializer.register` decorators to register the serializer and deserializer with Burr for that class.\n",
"\n",
"Note the function signatures, return types, and the decorators used."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7072f07c9a52cadc",
"metadata": {
"ExecuteTime": {
"end_time": "2024-06-03T21:26:40.139393Z",
"start_time": "2024-06-03T21:26:40.133785Z"
}
},
"outputs": [],
"source": [
"class CustomClass(object):\n",
" \"\"\"Custom class we'll use to test custom serde\"\"\"\n",
"\n",
" def __init__(self, value):\n",
" self.value = value\n",
"\n",
" def __eq__(self, other):\n",
" # required for asserts etc to work\n",
" return self.value == other.value\n",
"\n",
"\n",
"# create custom serializer for the custom class & register it\n",
"@serde.serialize.register(CustomClass)\n",
"def serialize_customclass(value: CustomClass, **kwargs) -> dict:\n",
" \"\"\"Serializes the custom class however we want\n",
"\n",
" :param value: the value to serialize.\n",
" :param kwargs:\n",
" :return: dictionary of serde.KEY and value\n",
" \"\"\"\n",
" return {\n",
" serde.KEY: \"CustomClass\", # this has to map to the value regisered for the deserializer below\n",
" \"value\": f\"[value=={value.value}]\", # serialize the value however we want\n",
" }\n",
"\n",
"\n",
"# create custom deserializer for the custom class & register it\n",
"@serde.deserializer.register(\"CustomClass\")\n",
"def deserialize_customclass(value: dict, **kwargs) -> CustomClass:\n",
" \"\"\"Deserializes the value using whatever meands we want\n",
"\n",
" :param value: the value to deserialize from.\n",
" :param kwargs:\n",
" :return: CustomClass\n",
" \"\"\"\n",
" # deserialize the value however we want\n",
" return CustomClass(value=value[\"value\"].split(\"==\")[1][0:-1])"
]
},
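{
"cell_type": "markdown",
"id": "serde-registry-sketch-note",
"metadata": {},
"source": [
"To see why the key returned by the serializer has to match the key the deserializer is registered under, the next cell sketches a tiny type-keyed registry using only the standard library. It is a stand-in for illustration, not Burr's implementation: `KEY`, `to_dict`, `from_dict`, `register_deserializer` and `Point` are hypothetical names invented for this sketch. Serialization dispatches on the Python type of the value, while deserialization can only dispatch on the string stored in the dictionary, so the two must agree."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "serde-registry-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import functools\n",
"\n",
"# Hypothetical marker key for this sketch; Burr exposes its own as `serde.KEY`.\n",
"KEY = \"__example_serde_key__\"\n",
"\n",
"\n",
"@functools.singledispatch\n",
"def to_dict(value, **kwargs):\n",
"    \"\"\"Fallback: assume the value is already JSON friendly and return it unchanged.\"\"\"\n",
"    return value\n",
"\n",
"\n",
"_DESERIALIZERS = {}  # maps the string stored under KEY to a deserializer function\n",
"\n",
"\n",
"def register_deserializer(key):\n",
"    \"\"\"Decorator that registers a deserializer under the given string key.\"\"\"\n",
"    def decorator(fn):\n",
"        _DESERIALIZERS[key] = fn\n",
"        return fn\n",
"    return decorator\n",
"\n",
"\n",
"def from_dict(value, **kwargs):\n",
"    \"\"\"Look up the deserializer by the key stored in the dictionary, if there is one.\"\"\"\n",
"    if isinstance(value, dict) and KEY in value:\n",
"        return _DESERIALIZERS[value[KEY]](value, **kwargs)\n",
"    return value\n",
"\n",
"\n",
"class Point:\n",
"    def __init__(self, x, y):\n",
"        self.x, self.y = x, y\n",
"\n",
"\n",
"@to_dict.register(Point)\n",
"def _point_to_dict(value: Point, **kwargs) -> dict:\n",
"    return {KEY: \"Point\", \"x\": value.x, \"y\": value.y}\n",
"\n",
"\n",
"@register_deserializer(\"Point\")  # must equal the string written under KEY above\n",
"def _point_from_dict(value: dict, **kwargs) -> Point:\n",
"    return Point(value[\"x\"], value[\"y\"])\n",
"\n",
"\n",
"payload = to_dict(Point(1, 2))\n",
"restored = from_dict(payload)\n",
"print(payload, restored.x, restored.y)"
]
},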
{
"cell_type": "markdown",
"id": "3080453eaf53979d",
"metadata": {},
"source": [
"# Field level serialization and deserialization\n",
"If we have to do something different from the usual type based serialization and deserialization, we can register field level serializers and deserializers.\n",
"\n",
"For some odd case you might have a field that needs to be serialized and deserialized differently but it's the same type as other fields so you can't just register a type based serializer and deserializer.\n",
"\n",
"This follows a similar approach to the above, except that we register the functions via `state.register_field_serde()`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e73d9d47fa95f3d0",
"metadata": {
"ExecuteTime": {
"end_time": "2024-06-03T21:26:40.408777Z",
"start_time": "2024-06-03T21:26:40.404637Z"
}
},
"outputs": [],
"source": [
"# create serializer\n",
"def my_field_serializer(value: documents.Document, **kwargs) -> dict:\n",
" \"\"\"Say we want to serialize langchain documents differently\"\"\"\n",
" serde_value = f\"serialized::{value.page_content}\"\n",
" return {\"value\": serde_value}\n",
"\n",
"# create deserializer\n",
"def my_field_deserializer(value: dict, **kwargs) -> documents.Document:\n",
" \"\"\"Say we want to deserialize langchain documents differently\"\"\"\n",
" serde_value = value[\"value\"]\n",
" return documents.Document(page_content=serde_value.replace(\"serialized::\", \"\"))\n",
"\n",
"# register the global field level serializer and deserializer\n",
"state.register_field_serde(\"custom_field\", my_field_serializer, my_field_deserializer)\n"
]
},
{
"cell_type": "markdown",
"id": "fbe6610566454f86",
"metadata": {},
"source": [
"# Define the actions\n",
"Let's now defin the application by starting with actions.\n",
"\n",
"This example shows that pydantic, and langchain documents are automatically serialized and deserialized by Burr."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a45e4d2e384eefc9",
"metadata": {
"ExecuteTime": {
"end_time": "2024-06-03T21:26:40.725841Z",
"start_time": "2024-06-03T21:26:40.721766Z"
}
},
"outputs": [],
"source": [
"@action(reads=[], writes=[\"dict\"])\n",
"def basic_action(state: State, user_input: str) -> tuple[dict, State]:\n",
" \"\"\"Action putting a bunch of things in state.\"\"\"\n",
" v = {\"foo\": 1, \"bar\": CustomClass(\"example value\"), \"bool\": True, \"None\": None, \"input\": user_input}\n",
" return {\"dict\": v}, state.update(dict=v)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ef21b3e77ce76325",
"metadata": {
"ExecuteTime": {
"end_time": "2024-06-03T21:26:40.989748Z",
"start_time": "2024-06-03T21:26:40.983982Z"
}
},
"outputs": [],
"source": [
"class PydanticField(pydantic.BaseModel):\n",
" \"\"\"burr handles serializing custom pydantic fields\"\"\"\n",
" f1: int = 0\n",
" f2: bool = False\n",
"\n",
"\n",
"@action(reads=[\"dict\"], writes=[\"pydantic_field\"])\n",
"def pydantic_action(state: State) -> tuple[dict, State]:\n",
" \"\"\"Action adding a pydantic field to state.\"\"\"\n",
" v = PydanticField(f1=state[\"dict\"][\"foo\"], f2=state[\"dict\"][\"bool\"])\n",
" return {\"pydantic_field\": v}, state.update(pydantic_field=v)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "af6944b0b1318060",
"metadata": {
"ExecuteTime": {
"end_time": "2024-06-03T21:26:41.259136Z",
"start_time": "2024-06-03T21:26:41.255104Z"
}
},
"outputs": [],
"source": [
"@action(reads=[\"pydantic_field\"], writes=[\"lc_doc\"])\n",
"def langchain_action(state: State) -> tuple[dict, State]:\n",
" \"\"\"Action adding a langchain document to state.\"\"\"\n",
" v = documents.Document(\n",
" page_content=f\"foo: {state['pydantic_field'].f1}, bar: {state['pydantic_field'].f2}\"\n",
" )\n",
" return {\"lc_doc\": v}, state.update(lc_doc=v)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "9c53fb3fd6d2e1db",
"metadata": {
"ExecuteTime": {
"end_time": "2024-06-03T21:26:41.546076Z",
"start_time": "2024-06-03T21:26:41.542739Z"
}
},
"outputs": [],
"source": [
"@action(reads=[\"lc_doc\"], writes=[])\n",
"def terminal_action(state: State) -> tuple[dict, State]:\n",
" \"\"\"terminal action\"\"\"\n",
" return {\"output\": state[\"lc_doc\"].page_content}, state\n"
]
},
{
"cell_type": "markdown",
"id": "f0edd1299b33c16",
"metadata": {},
"source": [
"# Build the application"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "610b7e88581dedd9",
"metadata": {
"ExecuteTime": {
"end_time": "2024-06-03T21:30:38.475680Z",
"start_time": "2024-06-03T21:30:38.457744Z"
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 10.0.1 (20240210.2158)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"215pt\" height=\"328pt\"\n",
" viewBox=\"0.00 0.00 214.62 327.50\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 323.5)\">\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-323.5 210.62,-323.5 210.62,4 -4,4\"/>\n",
"<!-- basic_action -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>basic_action</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M125.88,-252.5C125.88,-252.5 66.38,-252.5 66.38,-252.5 60.38,-252.5 54.38,-246.5 54.38,-240.5 54.38,-240.5 54.38,-228.5 54.38,-228.5 54.38,-222.5 60.38,-216.5 66.38,-216.5 66.38,-216.5 125.88,-216.5 125.88,-216.5 131.88,-216.5 137.88,-222.5 137.88,-228.5 137.88,-228.5 137.88,-240.5 137.88,-240.5 137.88,-246.5 131.88,-252.5 125.88,-252.5\"/>\n",
"<text text-anchor=\"middle\" x=\"96.12\" y=\"-229.45\" font-family=\"Times,serif\" font-size=\"14.00\">basic_action</text>\n",
"</g>\n",
"<!-- pydantic_action -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>pydantic_action</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M95.25,-185.5C95.25,-185.5 17,-185.5 17,-185.5 11,-185.5 5,-179.5 5,-173.5 5,-173.5 5,-161.5 5,-161.5 5,-155.5 11,-149.5 17,-149.5 17,-149.5 95.25,-149.5 95.25,-149.5 101.25,-149.5 107.25,-155.5 107.25,-161.5 107.25,-161.5 107.25,-173.5 107.25,-173.5 107.25,-179.5 101.25,-185.5 95.25,-185.5\"/>\n",
"<text text-anchor=\"middle\" x=\"56.12\" y=\"-162.45\" font-family=\"Times,serif\" font-size=\"14.00\">pydantic_action</text>\n",
"</g>\n",
"<!-- basic_action&#45;&gt;pydantic_action -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>basic_action&#45;&gt;pydantic_action</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M85.41,-216.08C81.55,-209.82 77.1,-202.59 72.86,-195.69\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"75.89,-193.94 67.67,-187.26 69.93,-197.61 75.89,-193.94\"/>\n",
"</g>\n",
"<!-- terminal_action -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>terminal_action</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M133.88,-36C133.88,-36 56.38,-36 56.38,-36 50.38,-36 44.38,-30 44.38,-24 44.38,-24 44.38,-12 44.38,-12 44.38,-6 50.38,0 56.38,0 56.38,0 133.88,0 133.88,0 139.88,0 145.88,-6 145.88,-12 145.88,-12 145.88,-24 145.88,-24 145.88,-30 139.88,-36 133.88,-36\"/>\n",
"<text text-anchor=\"middle\" x=\"95.12\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">terminal_action</text>\n",
"</g>\n",
"<!-- basic_action&#45;&gt;terminal_action -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>basic_action&#45;&gt;terminal_action</title>\n",
"<path fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" d=\"M105.21,-216.24C109.4,-207.32 113.94,-196.11 116.12,-185.5 126.77,-133.92 128.36,-118.46 117.12,-67 115.62,-60.11 113.04,-53.04 110.16,-46.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"113.32,-45 105.83,-37.51 107.01,-48.04 113.32,-45\"/>\n",
"<text text-anchor=\"middle\" x=\"165.38\" y=\"-121.2\" font-family=\"Times,serif\" font-size=\"14.00\">dict[&#39;foo&#39;] == 0</text>\n",
"</g>\n",
"<!-- input__user_input -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>input__user_input</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"96.12\" cy=\"-301.5\" rx=\"73.87\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"96.12\" y=\"-296.45\" font-family=\"Times,serif\" font-size=\"14.00\">input: user_input</text>\n",
"</g>\n",
"<!-- input__user_input&#45;&gt;basic_action -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>input__user_input&#45;&gt;basic_action</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M96.12,-283.08C96.12,-277.25 96.12,-270.59 96.12,-264.14\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"99.63,-264.48 96.13,-254.48 92.63,-264.48 99.63,-264.48\"/>\n",
"</g>\n",
"<!-- langchain_action -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>langchain_action</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M96.25,-103C96.25,-103 12,-103 12,-103 6,-103 0,-97 0,-91 0,-91 0,-79 0,-79 0,-73 6,-67 12,-67 12,-67 96.25,-67 96.25,-67 102.25,-67 108.25,-73 108.25,-79 108.25,-79 108.25,-91 108.25,-91 108.25,-97 102.25,-103 96.25,-103\"/>\n",
"<text text-anchor=\"middle\" x=\"54.12\" y=\"-79.95\" font-family=\"Times,serif\" font-size=\"14.00\">langchain_action</text>\n",
"</g>\n",
"<!-- pydantic_action&#45;&gt;langchain_action -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>pydantic_action&#45;&gt;langchain_action</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M55.69,-149.03C55.44,-139.01 55.12,-126.18 54.84,-114.71\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"58.34,-114.86 54.6,-104.95 51.35,-115.03 58.34,-114.86\"/>\n",
"</g>\n",
"<!-- langchain_action&#45;&gt;terminal_action -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>langchain_action&#45;&gt;terminal_action</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M65.11,-66.58C69.06,-60.32 73.62,-53.09 77.97,-46.19\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"80.93,-48.07 83.3,-37.75 75,-44.34 80.93,-48.07\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x1229b27a0>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tracker = tracking_client.LocalTrackingClient(\"serde-example\")\n",
"app = (\n",
" core.ApplicationBuilder()\n",
" .with_actions(basic_action, pydantic_action, langchain_action, terminal_action)\n",
" .with_transitions(\n",
" (\"basic_action\", \"terminal_action\", expr(\"dict['foo'] == 0\")),\n",
" (\"basic_action\", \"pydantic_action\"),\n",
" (\"pydantic_action\", \"langchain_action\"),\n",
" (\"langchain_action\", \"terminal_action\"),\n",
" )\n",
" .with_identifiers(partition_key=\"user-1234\", app_id=str(uuid.uuid4()))\n",
" .initialize_from(\n",
" tracker,\n",
" resume_at_next_action=True,\n",
" default_state={\n",
" # this will use the custom field level serializer and deserializer\n",
" \"custom_field\": documents.Document(\n",
" page_content=\"this is a custom field to serialize\"\n",
" )\n",
" },\n",
" default_entrypoint=\"basic_action\",\n",
" )\n",
" .with_tracker(tracker)\n",
" .build()\n",
")\n",
"app.visualize(include_conditions=True)"
]
},
{
"cell_type": "markdown",
"id": "706603ee574c1f74",
"metadata": {},
"source": [
"# Run the application\n",
"Let's run it and then check the state to see how the custom serialization and deserialization worked."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f732761dc3f9406b",
"metadata": {
"ExecuteTime": {
"end_time": "2024-06-03T21:30:39.720570Z",
"start_time": "2024-06-03T21:30:39.706240Z"
}
},
"outputs": [],
"source": [
"action, result, state = app.run(halt_after=[\"terminal_action\"], inputs={\"user_input\": \"hello world\"})"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "d393715f68255166",
"metadata": {
"ExecuteTime": {
"end_time": "2024-06-03T21:30:40.362605Z",
"start_time": "2024-06-03T21:30:40.356828Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'__PRIOR_STEP': 'terminal_action',\n",
" '__SEQUENCE_ID': 3,\n",
" 'custom_field': {'value': 'serialized::this is a custom field to serialize'},\n",
" 'dict': {'None': None,\n",
" 'bar': {'__burr_serde__': 'CustomClass',\n",
" 'value': '[value==example value]'},\n",
" 'bool': True,\n",
" 'foo': 1,\n",
" 'input': 'hello world'},\n",
" 'lc_doc': {'__burr_serde__': 'lc_document',\n",
" 'id': ['langchain', 'schema', 'document', 'Document'],\n",
" 'kwargs': {'page_content': 'foo: 1, bar: True', 'type': 'Document'},\n",
" 'lc': 1,\n",
" 'type': 'constructor'},\n",
" 'pydantic_field': {'__burr_serde__': 'pydantic',\n",
" '__pydantic_class': '__main__.PydanticField',\n",
" 'f1': 1,\n",
" 'f2': True}}\n"
]
}
],
"source": [
"# serialize\n",
"serialized_state = state.serialize()\n",
"pprint.pprint(serialized_state)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "707e0d00b19860a9",
"metadata": {
"ExecuteTime": {
"end_time": "2024-06-03T21:30:41.074787Z",
"start_time": "2024-06-03T21:30:41.069760Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/stefankrawczyk/.pyenv/versions/3.10.4/envs/burr-py310/lib/python3.10/site-packages/langchain_core/_api/beta_decorator.py:87: LangChainBetaWarning: The function `load` is in beta. It is actively being worked on, so the API may change.\n",
" warn_beta(\n"
]
}
],
"source": [
"# deserialize\n",
"deserialized_state = State.deserialize(serialized_state)\n",
"# assert that the state is the same after serialization and deserialization\n",
"assert state.get_all() == deserialized_state.get_all()"
]
},
{
"cell_type": "markdown",
"id": "679100f9851ea956",
"metadata": {},
"source": [
"# Go view the trace in the tracking client\n",
"You'll see that after each action the state is correctly serialized as we intended.\n",
"\n",
"[http://localhost:7241](http://localhost:7241)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cb395188a89e62db",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}