| { |
| "cells": [ |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "1ce738e6-14b4-4892-9f2d-a1db94cfb29a", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "!pip install \"burr[start]\"" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "2118a2f4-d969-4771-81d3-156b432d1dc8", |
| "metadata": {}, |
| "source": [ |
| "# Streaming applications\n", |
| "\n", |
| "This notebook shows how to work with a streaming response in Burr, served through FastAPI.\n", |
| "The implementation is in [application.py](application.py).\n", |
| "\n", |
| "This notebook covers only the streaming side. For the full FastAPI integration, see:\n", |
| "- The [Burr code](./application.py) -- imported and used here\n", |
| "- The [backend FastAPI server](./server.py) for the streaming output using SSE\n", |
| "- The [frontend TypeScript code](https://github.com/dagworks-inc/burr/blob/main/telemetry/ui/src/examples/StreamingChatbot.tsx) that renders and interacts with the stream\n", |
| "\n", |
| "You can view this demo in the Burr UI by running Burr:\n", |
| "\n", |
| "```bash\n", |
| "burr\n", |
| "```\n", |
| "\n", |
| "This opens a browser at [http://localhost:7241](http://localhost:7241).\n", |
| "\n", |
| "Then navigate to the [streaming example](http://localhost:7241/demos/streaming-chatbot)." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 18, |
| "id": "5b53de32-86af-475f-8976-e88a08986e34", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "from application import application as streaming_application\n", |
| "from application import TERMINAL_ACTIONS\n", |
| "import pprint" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "675a4245-816f-4d43-b412-307aee83db8e", |
| "metadata": {}, |
| "source": [ |
| "# The application\n", |
| "\n", |
| "The application is a simple chatbot proxy with a few different modes. It can either decide that a prompt is \"unsafe\" (here, simply meaning it contains the word \"unsafe\"; in practice this check would typically go to a dedicated model),\n", |
| "or do one of the following:\n", |
| "\n", |
| "1. Generate code\n", |
| "2. Answer a question\n", |
| "3. Generate a poem\n", |
| "4. Prompt for more\n", |
| "\n", |
| "It uses an LLM to decide which of these to do, and streams text back using Burr's async streaming. Read more about how that is implemented [here](https://burr.dagworks.io/concepts/streaming-actions/).\n", |
| "\n", |
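| "The delta-then-final-result pattern can be sketched without Burr or an LLM. The snippet below is a minimal, library-free illustration, not Burr's API: `stream_words` and `main` are hypothetical names, and the dictionary shapes simply mirror the output shown later in this notebook.\n", |
| "\n", |
| "```python\n", |
| "import asyncio\n", |
| "from typing import AsyncGenerator\n", |
| "\n", |
| "\n", |
| "async def stream_words(text: str) -> AsyncGenerator[dict, None]:\n", |
| "    # Hypothetical stand-in for an LLM stream: one delta per word.\n", |
| "    for word in text.split():\n", |
| "        await asyncio.sleep(0)  # yield control, as a network read would\n", |
| "        yield {'delta': word + ' '}\n", |
| "\n", |
| "\n", |
| "async def main() -> dict:\n", |
| "    buffer = []\n", |
| "    async for item in stream_words('hello streaming world'):\n", |
| "        print(item['delta'], end='')  # partial results, as they arrive\n", |
| "        buffer.append(item['delta'])\n", |
| "    # The full response is assembled only after the stream ends.\n", |
| "    content = ''.join(buffer).strip()\n", |
| "    return {'response': {'content': content, 'role': 'assistant', 'type': 'text'}}\n", |
| "\n", |
| "\n", |
| "result = asyncio.run(main())  # inside a notebook, use `await main()` instead\n", |
| "```\n", |
| "\n", |
| "Keeping the partial results (`delta`) separate from the final dictionary lets a caller either render text incrementally or wait for the complete response.\n", |
| "\n", |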
| "Note that, even though not every response actually streams (e.g., the unsafe response, which is hardcoded), all of them are modeled as streaming to make interaction with the app simpler." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 7, |
| "id": "81799877-8f0f-4572-9a96-f6ce6c430d9a", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "image/svg+xml": [ |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", |
| "<!-- Generated by graphviz version 8.1.0 (20230707.0739)\n", |
| " -->\n", |
| "<!-- Pages: 1 -->\n", |
| "<svg width=\"553pt\" height=\"304pt\"\n", |
| " viewBox=\"0.00 0.00 552.78 304.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 300)\">\n", |
| "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-300 548.78,-300 548.78,4 -4,4\"/>\n", |
| "<!-- prompt -->\n", |
| "<g id=\"node1\" class=\"node\">\n", |
| "<title>prompt</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M164.91,-231C164.91,-231 133.16,-231 133.16,-231 127.16,-231 121.16,-225 121.16,-219 121.16,-219 121.16,-207 121.16,-207 121.16,-201 127.16,-195 133.16,-195 133.16,-195 164.91,-195 164.91,-195 170.91,-195 176.91,-201 176.91,-207 176.91,-207 176.91,-219 176.91,-219 176.91,-225 170.91,-231 164.91,-231\"/>\n", |
| "<text text-anchor=\"middle\" x=\"149.03\" y=\"-207.95\" font-family=\"Times,serif\" font-size=\"14.00\">prompt</text>\n", |
| "</g>\n", |
| "<!-- check_safety -->\n", |
| "<g id=\"node3\" class=\"node\">\n", |
| "<title>check_safety</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M216.28,-166C216.28,-166 153.78,-166 153.78,-166 147.78,-166 141.78,-160 141.78,-154 141.78,-154 141.78,-142 141.78,-142 141.78,-136 147.78,-130 153.78,-130 153.78,-130 216.28,-130 216.28,-130 222.28,-130 228.28,-136 228.28,-142 228.28,-142 228.28,-154 228.28,-154 228.28,-160 222.28,-166 216.28,-166\"/>\n", |
| "<text text-anchor=\"middle\" x=\"185.03\" y=\"-142.95\" font-family=\"Times,serif\" font-size=\"14.00\">check_safety</text>\n", |
| "</g>\n", |
| "<!-- prompt->check_safety -->\n", |
| "<g id=\"edge5\" class=\"edge\">\n", |
| "<title>prompt->check_safety</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M158.87,-194.78C162.28,-188.81 166.19,-181.97 169.93,-175.43\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"173.28,-177.61 175.21,-167.19 167.21,-174.14 173.28,-177.61\"/>\n", |
| "</g>\n", |
| "<!-- input__prompt -->\n", |
| "<g id=\"node2\" class=\"node\">\n", |
| "<title>input__prompt</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"149.03\" cy=\"-278\" rx=\"62.1\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"149.03\" y=\"-272.95\" font-family=\"Times,serif\" font-size=\"14.00\">input: prompt</text>\n", |
| "</g>\n", |
| "<!-- input__prompt->prompt -->\n", |
| "<g id=\"edge1\" class=\"edge\">\n", |
| "<title>input__prompt->prompt</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M149.03,-259.78C149.03,-254.23 149.03,-247.92 149.03,-241.8\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"152.53,-242.19 149.03,-232.19 145.53,-242.19 152.53,-242.19\"/>\n", |
| "</g>\n", |
| "<!-- unsafe_response -->\n", |
| "<g id=\"node4\" class=\"node\">\n", |
| "<title>unsafe_response</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M154.66,-101C154.66,-101 73.41,-101 73.41,-101 67.41,-101 61.41,-95 61.41,-89 61.41,-89 61.41,-77 61.41,-77 61.41,-71 67.41,-65 73.41,-65 73.41,-65 154.66,-65 154.66,-65 160.66,-65 166.66,-71 166.66,-77 166.66,-77 166.66,-89 166.66,-89 166.66,-95 160.66,-101 154.66,-101\"/>\n", |
| "<text text-anchor=\"middle\" x=\"114.03\" y=\"-77.95\" font-family=\"Times,serif\" font-size=\"14.00\">unsafe_response</text>\n", |
| "</g>\n", |
| "<!-- check_safety->unsafe_response -->\n", |
| "<g id=\"edge7\" class=\"edge\">\n", |
| "<title>check_safety->unsafe_response</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M165.63,-129.78C158.2,-123.19 149.56,-115.53 141.51,-108.39\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"144.21,-106.21 134.41,-102.19 139.57,-111.45 144.21,-106.21\"/>\n", |
| "</g>\n", |
| "<!-- decide_mode -->\n", |
| "<g id=\"node5\" class=\"node\">\n", |
| "<title>decide_mode</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M396.41,-101C396.41,-101 331.66,-101 331.66,-101 325.66,-101 319.66,-95 319.66,-89 319.66,-89 319.66,-77 319.66,-77 319.66,-71 325.66,-65 331.66,-65 331.66,-65 396.41,-65 396.41,-65 402.41,-65 408.41,-71 408.41,-77 408.41,-77 408.41,-89 408.41,-89 408.41,-95 402.41,-101 396.41,-101\"/>\n", |
| "<text text-anchor=\"middle\" x=\"364.03\" y=\"-77.95\" font-family=\"Times,serif\" font-size=\"14.00\">decide_mode</text>\n", |
| "</g>\n", |
| "<!-- check_safety->decide_mode -->\n", |
| "<g id=\"edge6\" class=\"edge\">\n", |
| "<title>check_safety->decide_mode</title>\n", |
| "<path fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" d=\"M228.36,-131.75C252.66,-123.2 283.21,-112.45 309.16,-103.31\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"309.96,-106.39 318.23,-99.77 307.64,-99.79 309.96,-106.39\"/>\n", |
| "</g>\n", |
| "<!-- unsafe_response->prompt -->\n", |
| "<g id=\"edge16\" class=\"edge\">\n", |
| "<title>unsafe_response->prompt</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M117.5,-101.42C120.96,-118.11 126.66,-143.96 133.03,-166 134.76,-171.99 136.82,-178.35 138.88,-184.36\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"135.89,-185.4 142.51,-193.67 142.49,-183.07 135.89,-185.4\"/>\n", |
| "</g>\n", |
| "<!-- generate_code -->\n", |
| "<g id=\"node6\" class=\"node\">\n", |
| "<title>generate_code</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M160.03,-36C160.03,-36 90.03,-36 90.03,-36 84.03,-36 78.03,-30 78.03,-24 78.03,-24 78.03,-12 78.03,-12 78.03,-6 84.03,0 90.03,0 90.03,0 160.03,0 160.03,0 166.03,0 172.03,-6 172.03,-12 172.03,-12 172.03,-24 172.03,-24 172.03,-30 166.03,-36 160.03,-36\"/>\n", |
| "<text text-anchor=\"middle\" x=\"125.03\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">generate_code</text>\n", |
| "</g>\n", |
| "<!-- decide_mode->generate_code -->\n", |
| "<g id=\"edge8\" class=\"edge\">\n", |
| "<title>decide_mode->generate_code</title>\n", |
| "<path fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" d=\"M319.3,-67.34C316.51,-66.52 313.74,-65.73 311.03,-65 257.02,-50.35 240.12,-50.93 182.7,-36.32\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"183.88,-32.75 173.32,-33.63 182.12,-39.53 183.88,-32.75\"/>\n", |
| "</g>\n", |
| "<!-- answer_question -->\n", |
| "<g id=\"node8\" class=\"node\">\n", |
| "<title>answer_question</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M284.41,-36C284.41,-36 201.66,-36 201.66,-36 195.66,-36 189.66,-30 189.66,-24 189.66,-24 189.66,-12 189.66,-12 189.66,-6 195.66,0 201.66,0 201.66,0 284.41,0 284.41,0 290.41,0 296.41,-6 296.41,-12 296.41,-12 296.41,-24 296.41,-24 296.41,-30 290.41,-36 284.41,-36\"/>\n", |
| "<text text-anchor=\"middle\" x=\"243.03\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">answer_question</text>\n", |
| "</g>\n", |
| "<!-- decide_mode->answer_question -->\n", |
| "<g id=\"edge9\" class=\"edge\">\n", |
| "<title>decide_mode->answer_question</title>\n", |
| "<path fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" d=\"M330.65,-64.62C316.76,-57.39 300.47,-48.91 285.73,-41.23\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"287.79,-37.84 277.31,-36.33 284.56,-44.05 287.79,-37.84\"/>\n", |
| "</g>\n", |
| "<!-- generate_poem -->\n", |
| "<g id=\"node9\" class=\"node\">\n", |
| "<title>generate_poem</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M401.66,-36C401.66,-36 326.41,-36 326.41,-36 320.41,-36 314.41,-30 314.41,-24 314.41,-24 314.41,-12 314.41,-12 314.41,-6 320.41,0 326.41,0 326.41,0 401.66,0 401.66,0 407.66,0 413.66,-6 413.66,-12 413.66,-12 413.66,-24 413.66,-24 413.66,-30 407.66,-36 401.66,-36\"/>\n", |
| "<text text-anchor=\"middle\" x=\"364.03\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">generate_poem</text>\n", |
| "</g>\n", |
| "<!-- decide_mode->generate_poem -->\n", |
| "<g id=\"edge10\" class=\"edge\">\n", |
| "<title>decide_mode->generate_poem</title>\n", |
| "<path fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" d=\"M364.03,-64.78C364.03,-59.23 364.03,-52.92 364.03,-46.8\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"367.53,-47.19 364.03,-37.19 360.53,-47.19 367.53,-47.19\"/>\n", |
| "</g>\n", |
| "<!-- prompt_for_more -->\n", |
| "<g id=\"node10\" class=\"node\">\n", |
| "<title>prompt_for_more</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M532.78,-36C532.78,-36 443.28,-36 443.28,-36 437.28,-36 431.28,-30 431.28,-24 431.28,-24 431.28,-12 431.28,-12 431.28,-6 437.28,0 443.28,0 443.28,0 532.78,0 532.78,0 538.78,0 544.78,-6 544.78,-12 544.78,-12 544.78,-24 544.78,-24 544.78,-30 538.78,-36 532.78,-36\"/>\n", |
| "<text text-anchor=\"middle\" x=\"488.03\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">prompt_for_more</text>\n", |
| "</g>\n", |
| "<!-- decide_mode->prompt_for_more -->\n", |
| "<g id=\"edge11\" class=\"edge\">\n", |
| "<title>decide_mode->prompt_for_more</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M398.24,-64.62C412.47,-57.39 429.17,-48.91 444.28,-41.23\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"445.6,-43.98 452.93,-36.33 442.43,-37.73 445.6,-43.98\"/>\n", |
| "</g>\n", |
| "<!-- generate_code->prompt -->\n", |
| "<g id=\"edge14\" class=\"edge\">\n", |
| "<title>generate_code->prompt</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M77.7,-27.27C55.95,-33.7 32.13,-45.04 19.03,-65 -5.59,-102.53 -6.88,-129.35 19.03,-166 39.52,-194.99 79.9,-205.83 110.24,-209.82\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"109.59,-213.38 119.91,-211 110.36,-206.42 109.59,-213.38\"/>\n", |
| "</g>\n", |
| "<!-- input__model -->\n", |
| "<g id=\"node7\" class=\"node\">\n", |
| "<title>input__model</title>\n", |
| "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"243.03\" cy=\"-83\" rx=\"58.52\" ry=\"18\"/>\n", |
| "<text text-anchor=\"middle\" x=\"243.03\" y=\"-77.95\" font-family=\"Times,serif\" font-size=\"14.00\">input: model</text>\n", |
| "</g>\n", |
| "<!-- input__model->generate_code -->\n", |
| "<g id=\"edge2\" class=\"edge\">\n", |
| "<title>input__model->generate_code</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M214.77,-66.91C200.54,-59.31 183,-49.95 167.22,-41.52\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"169.13,-38.04 158.66,-36.42 165.83,-44.22 169.13,-38.04\"/>\n", |
| "</g>\n", |
| "<!-- input__model->answer_question -->\n", |
| "<g id=\"edge3\" class=\"edge\">\n", |
| "<title>input__model->answer_question</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M243.03,-64.78C243.03,-59.23 243.03,-52.92 243.03,-46.8\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"246.53,-47.19 243.03,-37.19 239.53,-47.19 246.53,-47.19\"/>\n", |
| "</g>\n", |
| "<!-- input__model->generate_poem -->\n", |
| "<g id=\"edge4\" class=\"edge\">\n", |
| "<title>input__model->generate_poem</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M272.01,-66.91C286.74,-59.24 304.93,-49.77 321.23,-41.28\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"322.32,-44.14 329.58,-36.42 319.09,-37.93 322.32,-44.14\"/>\n", |
| "</g>\n", |
| "<!-- answer_question->prompt -->\n", |
| "<g id=\"edge12\" class=\"edge\">\n", |
| "<title>answer_question->prompt</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M189.3,-33.9C186.51,-34.62 183.74,-35.32 181.03,-36 124.01,-50.22 88.75,-19.12 52.03,-65 19.18,-106.06 74.25,-158.3 113.94,-188.17\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"111.5,-191.45 121.63,-194.54 115.64,-185.81 111.5,-191.45\"/>\n", |
| "</g>\n", |
| "<!-- generate_poem->prompt -->\n", |
| "<g id=\"edge13\" class=\"edge\">\n", |
| "<title>generate_poem->prompt</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M392.62,-36.42C402.23,-43.91 411.82,-53.58 417.03,-65 423.67,-79.56 425.92,-87.7 417.03,-101 365.41,-178.26 248.97,-201.81 187.98,-208.94\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"187.83,-205.54 178.26,-210.09 188.58,-212.5 187.83,-205.54\"/>\n", |
| "</g>\n", |
| "<!-- prompt_for_more->prompt -->\n", |
| "<g id=\"edge15\" class=\"edge\">\n", |
| "<title>prompt_for_more->prompt</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M481.02,-36.29C467.23,-67.87 433.39,-134.5 382.03,-166 321.38,-203.2 236.63,-211.15 187.83,-212.38\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"188.05,-208.89 178.1,-212.56 188.16,-215.89 188.05,-208.89\"/>\n", |
| "</g>\n", |
| "</g>\n", |
| "</svg>\n" |
| ], |
| "text/plain": [ |
| "<graphviz.graphs.Digraph at 0x121ae3490>" |
| ] |
| }, |
| "execution_count": 7, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "app = streaming_application()\n", |
| "app.visualize()" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "6f57c117-8951-4f98-86eb-ea1eac16347b", |
| "metadata": {}, |
| "source": [ |
| "# Calling the application\n", |
| "\n", |
| "With async streaming, we get back an `AsyncStreamingResultContainer`. This allows us to get partial results streaming in, while also allowing us to get the full result.\n", |
| "In the following cell, we print each delta as it streams in, then retrieve the final result." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 19, |
| "id": "eeb23b59-e207-47b1-b500-b60d431396f8", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "Alexander and Aaron, a duo renowned,\n", |
| "Their story in history forever bound,\n", |
| "In a duel they met,\n", |
| "A fate they couldn't forget,\n", |
| "A tragic end to a rivalry unsewn, unground.\n", |
| "\n", |
| "{'response': {'content': 'Alexander and Aaron, a duo renowned,\\n'\n", |
| " 'Their story in history forever bound,\\n'\n", |
| " 'In a duel they met,\\n'\n", |
| " \"A fate they couldn't forget,\\n\"\n", |
| " 'A tragic end to a rivalry unsewn, unground.',\n", |
| " 'role': 'assistant',\n", |
| " 'type': 'text'}}\n" |
| ] |
| } |
| ], |
| "source": [ |
| "action, streaming_container = await app.astream_result(\n", |
| " halt_after=TERMINAL_ACTIONS, inputs={\"prompt\": \"Please generate a limerick about Alexander Hamilton and Aaron Burr\"}\n", |
| ")\n", |
| "# Stream results in\n", |
| "async for item in streaming_container:\n", |
| " print(item['delta'], end=\"\")\n", |
| "\n", |
| "# Or just get the final result\n", |
| "result, state = await streaming_container.get()\n", |
| "print(\"\\n\")\n", |
| "pprint.pprint(result)" |
| ] |
| } |
| ], |
| "metadata": { |
| "kernelspec": { |
| "display_name": "Python 3 (ipykernel)", |
| "language": "python", |
| "name": "python3" |
| }, |
| "language_info": { |
| "codemirror_mode": { |
| "name": "ipython", |
| "version": 3 |
| }, |
| "file_extension": ".py", |
| "mimetype": "text/x-python", |
| "name": "python", |
| "nbconvert_exporter": "python", |
| "pygments_lexer": "ipython3", |
| "version": "3.11.6" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 5 |
| } |