{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "1ce738e6-14b4-4892-9f2d-a1db94cfb29a",
"metadata": {},
"outputs": [],
"source": [
"!pip install burr[start]"
]
},
{
"cell_type": "markdown",
"id": "2118a2f4-d969-4771-81d3-156b432d1dc8",
"metadata": {},
"source": [
"# Streaming applications\n",
"\n",
"This shows how one goes about working with a streaming response with Burr, using FastAPI.\n",
"The code for implementation is in [application.py](application.py).\n",
"\n",
"This notebook only shows the streaming side. To check out FastAPI in Burr, check out\n",
"- The [Burr code](./application.py) -- imported and used here\n",
"- The [backend FastAPI server](./server.py) for the streaming output using SSE\n",
"- The [frontend typescript code](https://github.com/dagworks-inc/burr/blob/main/telemetry/ui/src/examples/StreamingChatbot.tsx) that renders and interacts with the stream\n",
"\n",
"You can view this demo in your app by running Burr:\n",
"\n",
"```bash\n",
"burr \n",
"```\n",
"\n",
"This will open a browser on [http://localhost:7241](http://localhost:7241)\n",
"\n",
"Then navigate to the [streaming example](http://localhost:7241/demos/streaming-chatbot)."
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "5b53de32-86af-475f-8976-e88a08986e34",
"metadata": {},
"outputs": [],
"source": [
"from application import application as streaming_application\n",
"from application import TERMINAL_ACTIONS\n",
"import pprint"
]
},
{
"cell_type": "markdown",
"id": "675a4245-816f-4d43-b412-307aee83db8e",
"metadata": {},
"source": [
"# The application\n",
"\n",
"The application we created will be a simple chatbot proxy. It has a few diffrent modes -- it can either decide a prompt is \"unsafe\" (in this case meaning that it has the word \"unsafe\" in it, but this would typically go to specific model),\n",
"or do one of the following:\n",
"\n",
"1. Generate code\n",
"2. Answer a question\n",
"3. Generate a poem\n",
"4. Prompt for more\n",
"\n",
"It will use an LLM to decide which to do. It streams back text using async streaming in Burr. Read more about how that is implemented [here](https://burr.dagworks.io/concepts/streaming-actions/).\n",
"\n",
"Note that, even though not every response is streaming (E.G. unsafe response, which is hardcoded), they are modeled as streaming to make interaction with the app simpler."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "81799877-8f0f-4572-9a96-f6ce6c430d9a",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 8.1.0 (20230707.0739)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"553pt\" height=\"304pt\"\n",
" viewBox=\"0.00 0.00 552.78 304.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 300)\">\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-300 548.78,-300 548.78,4 -4,4\"/>\n",
"<!-- prompt -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>prompt</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M164.91,-231C164.91,-231 133.16,-231 133.16,-231 127.16,-231 121.16,-225 121.16,-219 121.16,-219 121.16,-207 121.16,-207 121.16,-201 127.16,-195 133.16,-195 133.16,-195 164.91,-195 164.91,-195 170.91,-195 176.91,-201 176.91,-207 176.91,-207 176.91,-219 176.91,-219 176.91,-225 170.91,-231 164.91,-231\"/>\n",
"<text text-anchor=\"middle\" x=\"149.03\" y=\"-207.95\" font-family=\"Times,serif\" font-size=\"14.00\">prompt</text>\n",
"</g>\n",
"<!-- check_safety -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>check_safety</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M216.28,-166C216.28,-166 153.78,-166 153.78,-166 147.78,-166 141.78,-160 141.78,-154 141.78,-154 141.78,-142 141.78,-142 141.78,-136 147.78,-130 153.78,-130 153.78,-130 216.28,-130 216.28,-130 222.28,-130 228.28,-136 228.28,-142 228.28,-142 228.28,-154 228.28,-154 228.28,-160 222.28,-166 216.28,-166\"/>\n",
"<text text-anchor=\"middle\" x=\"185.03\" y=\"-142.95\" font-family=\"Times,serif\" font-size=\"14.00\">check_safety</text>\n",
"</g>\n",
"<!-- prompt&#45;&gt;check_safety -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>prompt&#45;&gt;check_safety</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M158.87,-194.78C162.28,-188.81 166.19,-181.97 169.93,-175.43\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"173.28,-177.61 175.21,-167.19 167.21,-174.14 173.28,-177.61\"/>\n",
"</g>\n",
"<!-- input__prompt -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>input__prompt</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"149.03\" cy=\"-278\" rx=\"62.1\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"149.03\" y=\"-272.95\" font-family=\"Times,serif\" font-size=\"14.00\">input: prompt</text>\n",
"</g>\n",
"<!-- input__prompt&#45;&gt;prompt -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>input__prompt&#45;&gt;prompt</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M149.03,-259.78C149.03,-254.23 149.03,-247.92 149.03,-241.8\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"152.53,-242.19 149.03,-232.19 145.53,-242.19 152.53,-242.19\"/>\n",
"</g>\n",
"<!-- unsafe_response -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>unsafe_response</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M154.66,-101C154.66,-101 73.41,-101 73.41,-101 67.41,-101 61.41,-95 61.41,-89 61.41,-89 61.41,-77 61.41,-77 61.41,-71 67.41,-65 73.41,-65 73.41,-65 154.66,-65 154.66,-65 160.66,-65 166.66,-71 166.66,-77 166.66,-77 166.66,-89 166.66,-89 166.66,-95 160.66,-101 154.66,-101\"/>\n",
"<text text-anchor=\"middle\" x=\"114.03\" y=\"-77.95\" font-family=\"Times,serif\" font-size=\"14.00\">unsafe_response</text>\n",
"</g>\n",
"<!-- check_safety&#45;&gt;unsafe_response -->\n",
"<g id=\"edge7\" class=\"edge\">\n",
"<title>check_safety&#45;&gt;unsafe_response</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M165.63,-129.78C158.2,-123.19 149.56,-115.53 141.51,-108.39\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"144.21,-106.21 134.41,-102.19 139.57,-111.45 144.21,-106.21\"/>\n",
"</g>\n",
"<!-- decide_mode -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>decide_mode</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M396.41,-101C396.41,-101 331.66,-101 331.66,-101 325.66,-101 319.66,-95 319.66,-89 319.66,-89 319.66,-77 319.66,-77 319.66,-71 325.66,-65 331.66,-65 331.66,-65 396.41,-65 396.41,-65 402.41,-65 408.41,-71 408.41,-77 408.41,-77 408.41,-89 408.41,-89 408.41,-95 402.41,-101 396.41,-101\"/>\n",
"<text text-anchor=\"middle\" x=\"364.03\" y=\"-77.95\" font-family=\"Times,serif\" font-size=\"14.00\">decide_mode</text>\n",
"</g>\n",
"<!-- check_safety&#45;&gt;decide_mode -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>check_safety&#45;&gt;decide_mode</title>\n",
"<path fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" d=\"M228.36,-131.75C252.66,-123.2 283.21,-112.45 309.16,-103.31\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"309.96,-106.39 318.23,-99.77 307.64,-99.79 309.96,-106.39\"/>\n",
"</g>\n",
"<!-- unsafe_response&#45;&gt;prompt -->\n",
"<g id=\"edge16\" class=\"edge\">\n",
"<title>unsafe_response&#45;&gt;prompt</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M117.5,-101.42C120.96,-118.11 126.66,-143.96 133.03,-166 134.76,-171.99 136.82,-178.35 138.88,-184.36\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"135.89,-185.4 142.51,-193.67 142.49,-183.07 135.89,-185.4\"/>\n",
"</g>\n",
"<!-- generate_code -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>generate_code</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M160.03,-36C160.03,-36 90.03,-36 90.03,-36 84.03,-36 78.03,-30 78.03,-24 78.03,-24 78.03,-12 78.03,-12 78.03,-6 84.03,0 90.03,0 90.03,0 160.03,0 160.03,0 166.03,0 172.03,-6 172.03,-12 172.03,-12 172.03,-24 172.03,-24 172.03,-30 166.03,-36 160.03,-36\"/>\n",
"<text text-anchor=\"middle\" x=\"125.03\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">generate_code</text>\n",
"</g>\n",
"<!-- decide_mode&#45;&gt;generate_code -->\n",
"<g id=\"edge8\" class=\"edge\">\n",
"<title>decide_mode&#45;&gt;generate_code</title>\n",
"<path fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" d=\"M319.3,-67.34C316.51,-66.52 313.74,-65.73 311.03,-65 257.02,-50.35 240.12,-50.93 182.7,-36.32\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"183.88,-32.75 173.32,-33.63 182.12,-39.53 183.88,-32.75\"/>\n",
"</g>\n",
"<!-- answer_question -->\n",
"<g id=\"node8\" class=\"node\">\n",
"<title>answer_question</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M284.41,-36C284.41,-36 201.66,-36 201.66,-36 195.66,-36 189.66,-30 189.66,-24 189.66,-24 189.66,-12 189.66,-12 189.66,-6 195.66,0 201.66,0 201.66,0 284.41,0 284.41,0 290.41,0 296.41,-6 296.41,-12 296.41,-12 296.41,-24 296.41,-24 296.41,-30 290.41,-36 284.41,-36\"/>\n",
"<text text-anchor=\"middle\" x=\"243.03\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">answer_question</text>\n",
"</g>\n",
"<!-- decide_mode&#45;&gt;answer_question -->\n",
"<g id=\"edge9\" class=\"edge\">\n",
"<title>decide_mode&#45;&gt;answer_question</title>\n",
"<path fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" d=\"M330.65,-64.62C316.76,-57.39 300.47,-48.91 285.73,-41.23\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"287.79,-37.84 277.31,-36.33 284.56,-44.05 287.79,-37.84\"/>\n",
"</g>\n",
"<!-- generate_poem -->\n",
"<g id=\"node9\" class=\"node\">\n",
"<title>generate_poem</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M401.66,-36C401.66,-36 326.41,-36 326.41,-36 320.41,-36 314.41,-30 314.41,-24 314.41,-24 314.41,-12 314.41,-12 314.41,-6 320.41,0 326.41,0 326.41,0 401.66,0 401.66,0 407.66,0 413.66,-6 413.66,-12 413.66,-12 413.66,-24 413.66,-24 413.66,-30 407.66,-36 401.66,-36\"/>\n",
"<text text-anchor=\"middle\" x=\"364.03\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">generate_poem</text>\n",
"</g>\n",
"<!-- decide_mode&#45;&gt;generate_poem -->\n",
"<g id=\"edge10\" class=\"edge\">\n",
"<title>decide_mode&#45;&gt;generate_poem</title>\n",
"<path fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" d=\"M364.03,-64.78C364.03,-59.23 364.03,-52.92 364.03,-46.8\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"367.53,-47.19 364.03,-37.19 360.53,-47.19 367.53,-47.19\"/>\n",
"</g>\n",
"<!-- prompt_for_more -->\n",
"<g id=\"node10\" class=\"node\">\n",
"<title>prompt_for_more</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M532.78,-36C532.78,-36 443.28,-36 443.28,-36 437.28,-36 431.28,-30 431.28,-24 431.28,-24 431.28,-12 431.28,-12 431.28,-6 437.28,0 443.28,0 443.28,0 532.78,0 532.78,0 538.78,0 544.78,-6 544.78,-12 544.78,-12 544.78,-24 544.78,-24 544.78,-30 538.78,-36 532.78,-36\"/>\n",
"<text text-anchor=\"middle\" x=\"488.03\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">prompt_for_more</text>\n",
"</g>\n",
"<!-- decide_mode&#45;&gt;prompt_for_more -->\n",
"<g id=\"edge11\" class=\"edge\">\n",
"<title>decide_mode&#45;&gt;prompt_for_more</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M398.24,-64.62C412.47,-57.39 429.17,-48.91 444.28,-41.23\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"445.6,-43.98 452.93,-36.33 442.43,-37.73 445.6,-43.98\"/>\n",
"</g>\n",
"<!-- generate_code&#45;&gt;prompt -->\n",
"<g id=\"edge14\" class=\"edge\">\n",
"<title>generate_code&#45;&gt;prompt</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M77.7,-27.27C55.95,-33.7 32.13,-45.04 19.03,-65 -5.59,-102.53 -6.88,-129.35 19.03,-166 39.52,-194.99 79.9,-205.83 110.24,-209.82\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"109.59,-213.38 119.91,-211 110.36,-206.42 109.59,-213.38\"/>\n",
"</g>\n",
"<!-- input__model -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>input__model</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"243.03\" cy=\"-83\" rx=\"58.52\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"243.03\" y=\"-77.95\" font-family=\"Times,serif\" font-size=\"14.00\">input: model</text>\n",
"</g>\n",
"<!-- input__model&#45;&gt;generate_code -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>input__model&#45;&gt;generate_code</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M214.77,-66.91C200.54,-59.31 183,-49.95 167.22,-41.52\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"169.13,-38.04 158.66,-36.42 165.83,-44.22 169.13,-38.04\"/>\n",
"</g>\n",
"<!-- input__model&#45;&gt;answer_question -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>input__model&#45;&gt;answer_question</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M243.03,-64.78C243.03,-59.23 243.03,-52.92 243.03,-46.8\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"246.53,-47.19 243.03,-37.19 239.53,-47.19 246.53,-47.19\"/>\n",
"</g>\n",
"<!-- input__model&#45;&gt;generate_poem -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>input__model&#45;&gt;generate_poem</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M272.01,-66.91C286.74,-59.24 304.93,-49.77 321.23,-41.28\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"322.32,-44.14 329.58,-36.42 319.09,-37.93 322.32,-44.14\"/>\n",
"</g>\n",
"<!-- answer_question&#45;&gt;prompt -->\n",
"<g id=\"edge12\" class=\"edge\">\n",
"<title>answer_question&#45;&gt;prompt</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M189.3,-33.9C186.51,-34.62 183.74,-35.32 181.03,-36 124.01,-50.22 88.75,-19.12 52.03,-65 19.18,-106.06 74.25,-158.3 113.94,-188.17\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"111.5,-191.45 121.63,-194.54 115.64,-185.81 111.5,-191.45\"/>\n",
"</g>\n",
"<!-- generate_poem&#45;&gt;prompt -->\n",
"<g id=\"edge13\" class=\"edge\">\n",
"<title>generate_poem&#45;&gt;prompt</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M392.62,-36.42C402.23,-43.91 411.82,-53.58 417.03,-65 423.67,-79.56 425.92,-87.7 417.03,-101 365.41,-178.26 248.97,-201.81 187.98,-208.94\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"187.83,-205.54 178.26,-210.09 188.58,-212.5 187.83,-205.54\"/>\n",
"</g>\n",
"<!-- prompt_for_more&#45;&gt;prompt -->\n",
"<g id=\"edge15\" class=\"edge\">\n",
"<title>prompt_for_more&#45;&gt;prompt</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M481.02,-36.29C467.23,-67.87 433.39,-134.5 382.03,-166 321.38,-203.2 236.63,-211.15 187.83,-212.38\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"188.05,-208.89 178.1,-212.56 188.16,-215.89 188.05,-208.89\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x121ae3490>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"app = streaming_application()\n",
"app.visualize()"
]
},
{
"cell_type": "markdown",
"id": "6f57c117-8951-4f98-86eb-ea1eac16347b",
"metadata": {},
"source": [
"# Calling the application\n",
"\n",
"With async streaming, we get back an `AsyncStreamingResultContainer`. This allows us to get partial results streaming in, while also allowing us to get the full result.\n",
"In the following case, we just "
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "eeb23b59-e207-47b1-b500-b60d431396f8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Alexander and Aaron, a duo renowned,\n",
"Their story in history forever bound,\n",
"In a duel they met,\n",
"A fate they couldn't forget,\n",
"A tragic end to a rivalry unsewn, unground.\n",
"\n",
"{'response': {'content': 'Alexander and Aaron, a duo renowned,\\n'\n",
" 'Their story in history forever bound,\\n'\n",
" 'In a duel they met,\\n'\n",
" \"A fate they couldn't forget,\\n\"\n",
" 'A tragic end to a rivalry unsewn, unground.',\n",
" 'role': 'assistant',\n",
" 'type': 'text'}}\n"
]
}
],
"source": [
"action, streaming_container = await app.astream_result(\n",
" halt_after=TERMINAL_ACTIONS, inputs={\"prompt\": \"Please generate a limerick about Alexander Hamilton and Aaron Burr\"}\n",
")\n",
"# Stream results in\n",
"async for item in streaming_container:\n",
" print(item['delta'], end=\"\")\n",
"\n",
"# Or just get the final result\n",
"result, state = await streaming_container.get()\n",
"print(\"\\n\")\n",
"pprint.pprint(result)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}