| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "id": "f4b744ec-ce8d-4e6b-b818-d86f6a869028", |
| "metadata": {}, |
| "source": [ |
| "<a target=\"_blank\" href=\"https://colab.research.google.com/github/dagworks-inc/burr/blob/main/examples/parallelism/notebook.ipynb\">\n", |
| " <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n", |
| "</a> \n", |
| "or <a target=\"_blank\" href=\"https://www.github.com/dagworks-inc/burr/tree/main/examples/parallelism/notebook.ipynb\">view source</a>\n", |
| "\n", |
| "For a video walkthrough of this notebook <a href=\"https://youtu.be/G7lw63IBSmY\">click here</a>." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "5bbca346-710f-42ad-80dc-7e7d2c2f57ee", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# install Burr and the other dependencies this notebook needs\n", |
| "%pip install \"burr[start,opentelemetry]\" opentelemetry-instrumentation-openai openai anthropic" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "7a4c338b-90dc-4866-9cc9-5c9af38f9537", |
| "metadata": {}, |
| "source": [ |
| "# Parallelism (& hierarchy)\n", |
| "\n", |
| "Burr is all about thinking and modeling things as a \"graph\" or \"flowchart\". This just so happens to be a great way to model agents. This modeling also lends itself well to modeling \"hierarchy/recursion\" and \"parallelism\", which are key in building more complex interactions and agent systems.\n", |
| "\n", |
| "In this notebook we're going to go over how to run things in parallel, from a single \"action\" up to whole Burr sub-applications (a.k.a. \"sub-agents\"). For full documentation on parallelism see [this page](https://burr.dagworks.io/concepts/parallelism/); if you're familiar with the [`map-reduce`](https://en.wikipedia.org/wiki/MapReduce) pattern, then you'll feel right at home here.\n", |
| "\n", |
| "We will start simple and show how to write a simple Burr application that compares different LLMs for the same prompt & context. We'll then extend what we parallelize to be a whole burr sub-application/agent. \n", |
| "\n", |
| "To start:\n", |
| "1. to help build your mental model, we will first run things in parallel without Burr's parallelism or recursion constructs.\n", |
| "2. we will then show Burr's [\"recursion/hierarchy\" capabilities](https://burr.dagworks.io/concepts/recursion/).\n", |
| "3. we will then show how Burr's parallelism constructs simplify things.\n", |
| "4. we will then show how Burr's parallelism constructs enable you to more easily model more complex behavior.\n", |
| "\n", |
| "Throughout we will show how the self-hostable Burr UI helps us capture telemetry and understand what's going on." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "8e1b0f79-81ac-48d8-ac61-a175fa6bd8b9", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# set your OpenAI API key\n", |
| "import os\n", |
| "os.environ[\"OPENAI_API_KEY\"] = \"...\"" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "10459bdd-2154-43c6-84d2-7bee4825ab9b", |
| "metadata": {}, |
| "source": [ |
| "# Imports\n", |
| "What we need for this whole notebook" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 1, |
| "id": "e95aa6fe-1eaa-490d-b7fd-ff956306bcc5", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "from concurrent.futures import ThreadPoolExecutor\n", |
| "from typing import Callable, Generator, List, Dict, Any\n", |
| "\n", |
| "# for displaying in the notebook\n", |
| "from IPython.display import HTML, IFrame\n", |
| "import pprint\n", |
| "import openai\n", |
| "\n", |
| "# burr imports\n", |
| "from burr.core import action, State, ApplicationBuilder, ApplicationContext, Action\n", |
| "from burr.core.parallelism import MapStates, RunnableGraph\n", |
| "\n", |
| "# instrumentation using opentelemetry\n", |
| "from opentelemetry.instrumentation.openai import OpenAIInstrumentor\n", |
| "OpenAIInstrumentor().instrument()" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 3, |
| "id": "e59f1418-68f8-47c2-8034-4b928211acbe", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "The burr.integrations.notebook extension is already loaded. To reload it, use:\n", |
| " %reload_ext burr.integrations.notebook\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "\n", |
| " <iframe\n", |
| " width=\"100%\"\n", |
| " height=\"400\"\n", |
| " src=\"http://127.0.0.1:7241\"\n", |
| " frameborder=\"0\"\n", |
| " allowfullscreen\n", |
| " \n", |
| " ></iframe>\n", |
| " " |
| ], |
| "text/plain": [ |
| "<IPython.lib.display.IFrame at 0x10f518310>" |
| ] |
| }, |
| "metadata": {}, |
| "output_type": "display_data" |
| } |
| ], |
| "source": [ |
| "# let's load up some jupyter UI magics\n", |
| "%load_ext burr.integrations.notebook\n", |
| "%burr_ui \n", |
| "# this starts the UI in a background process" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "77bc7268-26ec-462c-9219-c7eede3bd325", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# if in Google Colab we need to expose the port\n", |
| "from google.colab import output\n", |
| "output.serve_kernel_port_as_window(7241) # can expose as a new tab -- take note of the URL and provide it below." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "f3e74f3e-db80-436c-ac67-10078e7965de", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# can expose as an iframe directly\n", |
| "output.serve_kernel_port_as_iframe(7241) " |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 4, |
| "id": "4778dbb2-75a0-4490-86b8-7be206832d92", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "URL_FOR_UI = \"http://localhost:7241\" # this is the default -- replace with the Google Colab domain (without trailing /)\n", |
| "# e.g. URL_FOR_UI = \"https://cg2cjmb1mmu-496ff2e9c6d22116-7241-colab.googleusercontent.com\"" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "948a51e7-75e8-4d47-a23f-c24a89e8be60", |
| "metadata": {}, |
| "source": [ |
| "# Actions\n", |
| "\n", |
| "In this next code block we'll set up the base actions that we'll use to construct our Burr applications. That is, the things we'll want to run in parallel.\n", |
| "\n", |
| "Our initial Burr application will have a simple (3-action/node) structure:\n", |
| "\n", |
| "1. `process_input` -- pulls data in from the user at runtime\n", |
| "2. `run_llms` -- runs the different LLM models in parallel (not defined below)\n", |
| "3. `join_outputs` -- writes a map of LLM -> result\n", |
| "\n", |
| "We will first define actions (1) and (3), then redefine (2) according to whether we're using recursion/hierarchy or parallelism." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 5, |
| "id": "0998576a-c8cb-45a2-94ce-d8ce326e3be7", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "def _query_llm(model: str, prompt: str) -> str:\n", |
| "    \"\"\"Simple function to query our LLM -- we use OpenAI here / swap for your favorite model\"\"\"\n", |
| "    client = openai.Client()\n", |
| "    response = client.chat.completions.create(\n", |
| "        model=model,\n", |
| "        messages=[{\"role\": \"user\", \"content\": prompt}]\n", |
| "    )\n", |
| "    return response.choices[0].message.content\n", |
| "\n", |
| "\n", |
| "@action(reads=[], writes=[\"prompt\", \"models\"])\n", |
| "def process_input(state: State, prompt: str, models: List[str]) -> State:\n", |
| "    \"\"\"First Action (node) in our graph -- will take the prompt input and write to state\"\"\"\n", |
| "    return state.update(prompt=prompt, models=models)\n", |
| "\n", |
| "\n", |
| "@action(reads=[\"responses\", \"models\"], writes=[\"all_responses\"])\n", |
| "def join_outputs(state: State) -> State:\n", |
| "    \"\"\"Final action (node), just joins in a dictionary\"\"\"\n", |
| "    joined_results = {}\n", |
| "    for response, model in zip(state[\"responses\"], state[\"models\"]):\n", |
| "        joined_results[model] = response\n", |
| "    return state.update(all_responses=joined_results)" |
| ] |
| }, |
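| { |
| "cell_type": "markdown", |
| "id": "join-by-position-sketch", |
| "metadata": {}, |
| "source": [ |
| "`join_outputs` pairs each response with a model by position, so it expects `responses` to be a list in the same order as `models`. Here is a minimal sketch of that join in plain Python. The model names and responses are made-up placeholders, not real LLM output:\n", |
| "\n", |
| "```python\n", |
| "# Hypothetical placeholder data: these are not real model names or outputs.\n", |
| "models = [\"model-a\", \"model-b\"]\n", |
| "responses = [\"answer from a\", \"answer from b\"]\n", |
| "\n", |
| "# zip pairs items by position, so responses[i] must belong to models[i]\n", |
| "all_responses = dict(zip(models, responses))\n", |
| "print(all_responses)\n", |
| "```\n", |
| "\n", |
| "Iterating over a dict yields its keys rather than its values, so whatever fills `responses` should produce an ordered list, not a dict keyed by model name." |
| ] |
| }, |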
| { |
| "cell_type": "markdown", |
| "id": "971e9a62-543a-4866-854c-60bbe89039ec", |
| "metadata": {}, |
| "source": [ |
| "# Approach #1 -- doing this manually\n", |
| "\n", |
| "Nothing stops you from running multiple queries in the same Burr Action. In this approach we define the middle action so that it fans the queries out itself and waits for the results.\n", |
| "\n", |
| "This is the \"manual\" approach -- how you might achieve parallelism without Burr. In it, we're going to do the following:\n", |
| "\n", |
| "1. Define a function that queries a single LLM\n", |
| "2. Run multiple variants of it in parallel\n", |
| "3. Join the result\n", |
| "\n", |
| "This is a useful approach as it is simple and looks like standard Python code. That said, it lacks visibility -- you have no way to know what the individual LLM calls are doing.\n", |
| "\n", |
| "First, we'll define the middle action. Next, we'll define the application. Then we can run!" |
| ] |
| }, |
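| { |
| "cell_type": "markdown", |
| "id": "manual-fanout-sketch", |
| "metadata": {}, |
| "source": [ |
| "To see the fan-out/join pattern from the steps above in isolation, here is a minimal sketch that uses only the standard library. `fake_query` and `query_all` are hypothetical names introduced for illustration; `fake_query` stands in for the real LLM call so the example runs without an API key:\n", |
| "\n", |
| "```python\n", |
| "from concurrent.futures import ThreadPoolExecutor\n", |
| "from typing import List\n", |
| "\n", |
| "\n", |
| "def fake_query(model: str, prompt: str) -> str:\n", |
| "    \"\"\"Hypothetical stand-in for a real LLM call (not part of Burr or OpenAI).\"\"\"\n", |
| "    return f\"{model} says: {prompt.upper()}\"\n", |
| "\n", |
| "\n", |
| "def query_all(models: List[str], prompt: str) -> List[str]:\n", |
| "    \"\"\"Run one query per model in a thread pool and return the results in model order.\"\"\"\n", |
| "    with ThreadPoolExecutor() as executor:\n", |
| "        # executor.map returns results in input order, whichever call finishes first\n", |
| "        return list(executor.map(lambda m: fake_query(m, prompt), models))\n", |
| "\n", |
| "\n", |
| "results = query_all([\"model-a\", \"model-b\"], \"hi\")\n", |
| "print(results)\n", |
| "```\n", |
| "\n", |
| "Because the results come back in the same order as the inputs, they can be zipped against the list of models afterwards." |
| ] |
| }, |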
| { |
| "cell_type": "code", |
| "execution_count": 6, |
| "id": "27ee761e-f7e8-4da7-82f2-349653cc6a7e", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "image/svg+xml": [ |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", |
| "<!-- Generated by graphviz version 12.0.0 (20240704.0754)\n", |
| " -->\n", |
| "<!-- Pages: 1 -->\n", |
| "<svg width=\"230pt\" height=\"241pt\"\n", |
| " viewBox=\"0.00 0.00 229.85 241.40\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 237.4)\">\n", |
| "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-237.4 225.85,-237.4 225.85,4 -4,4\"/>\n", |
| "<!-- process_input -->\n", |
| "<g id=\"node1\" class=\"node\">\n", |
| "<title>process_input</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M152.1,-167.8C152.1,-167.8 70.5,-167.8 70.5,-167.8 64.5,-167.8 58.5,-161.8 58.5,-155.8 58.5,-155.8 58.5,-143.2 58.5,-143.2 58.5,-137.2 64.5,-131.2 70.5,-131.2 70.5,-131.2 152.1,-131.2 152.1,-131.2 158.1,-131.2 164.1,-137.2 164.1,-143.2 164.1,-143.2 164.1,-155.8 164.1,-155.8 164.1,-161.8 158.1,-167.8 152.1,-167.8\"/>\n", |
| "<text text-anchor=\"middle\" x=\"111.3\" y=\"-143.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">process_input</text>\n", |
| "</g>\n", |
| "<!-- query_multiple_models -->\n", |
| "<g id=\"node4\" class=\"node\">\n", |
| "<title>query_multiple_models</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M180.22,-102.2C180.22,-102.2 42.38,-102.2 42.38,-102.2 36.38,-102.2 30.38,-96.2 30.38,-90.2 30.38,-90.2 30.38,-77.6 30.38,-77.6 30.38,-71.6 36.38,-65.6 42.38,-65.6 42.38,-65.6 180.22,-65.6 180.22,-65.6 186.22,-65.6 192.22,-71.6 192.22,-77.6 192.22,-77.6 192.22,-90.2 192.22,-90.2 192.22,-96.2 186.22,-102.2 180.22,-102.2\"/>\n", |
| "<text text-anchor=\"middle\" x=\"111.3\" y=\"-78.1\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">query_multiple_models</text>\n", |
| "</g>\n", |
| "<!-- process_input->query_multiple_models -->\n", |
| "<g id=\"edge3\" class=\"edge\">\n", |
| "<title>process_input->query_multiple_models</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M111.3,-130.78C111.3,-125.55 111.3,-119.67 111.3,-113.92\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"114.8,-114.03 111.3,-104.03 107.8,-114.03 114.8,-114.03\"/>\n", |
| "</g>\n", |
| "<!-- input__models -->\n", |
| "<g id=\"node2\" class=\"node\">\n", |
| "<title>input__models</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"102.6,-233.4 0,-233.4 0,-196.8 102.6,-196.8 102.6,-233.4\"/>\n", |
| "<text text-anchor=\"middle\" x=\"51.3\" y=\"-209.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input: models</text>\n", |
| "</g>\n", |
| "<!-- input__models->process_input -->\n", |
| "<g id=\"edge1\" class=\"edge\">\n", |
| "<title>input__models->process_input</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M68.01,-196.38C73.86,-190.18 80.56,-183.08 86.91,-176.35\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"89.32,-178.9 93.64,-169.22 84.23,-174.09 89.32,-178.9\"/>\n", |
| "</g>\n", |
| "<!-- input__prompt -->\n", |
| "<g id=\"node3\" class=\"node\">\n", |
| "<title>input__prompt</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"221.85,-233.4 120.75,-233.4 120.75,-196.8 221.85,-196.8 221.85,-233.4\"/>\n", |
| "<text text-anchor=\"middle\" x=\"171.3\" y=\"-209.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input: prompt</text>\n", |
| "</g>\n", |
| "<!-- input__prompt->process_input -->\n", |
| "<g id=\"edge2\" class=\"edge\">\n", |
| "<title>input__prompt->process_input</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M154.59,-196.38C148.74,-190.18 142.04,-183.08 135.69,-176.35\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"138.37,-174.09 128.96,-169.22 133.28,-178.9 138.37,-174.09\"/>\n", |
| "</g>\n", |
| "<!-- join_outputs -->\n", |
| "<g id=\"node5\" class=\"node\">\n", |
| "<title>join_outputs</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M146.47,-36.6C146.47,-36.6 76.12,-36.6 76.12,-36.6 70.12,-36.6 64.12,-30.6 64.12,-24.6 64.12,-24.6 64.12,-12 64.12,-12 64.12,-6 70.12,0 76.12,0 76.12,0 146.47,0 146.47,0 152.47,0 158.47,-6 158.47,-12 158.47,-12 158.47,-24.6 158.47,-24.6 158.47,-30.6 152.47,-36.6 146.47,-36.6\"/>\n", |
| "<text text-anchor=\"middle\" x=\"111.3\" y=\"-12.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">join_outputs</text>\n", |
| "</g>\n", |
| "<!-- query_multiple_models->join_outputs -->\n", |
| "<g id=\"edge4\" class=\"edge\">\n", |
| "<title>query_multiple_models->join_outputs</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M111.3,-65.18C111.3,-59.95 111.3,-54.07 111.3,-48.32\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"114.8,-48.43 111.3,-38.43 107.8,-48.43 114.8,-48.43\"/>\n", |
| "</g>\n", |
| "</g>\n", |
| "</svg>\n" |
| ], |
| "text/plain": [ |
| "<burr.core.application.Application at 0x10f43b490>" |
| ] |
| }, |
| "execution_count": 6, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "@action(reads=[\"models\", \"prompt\"], writes=[\"responses\"])\n", |
| "def query_multiple_models(state: State) -> State:\n", |
| "    \"\"\"Query multiple models in parallel and store the results in the state.\"\"\"\n", |
| "    models: List[str] = state[\"models\"]\n", |
| "    prompt: str = state[\"prompt\"]\n", |
| "\n", |
| "    def query_model(model: str) -> str:\n", |
| "        return _query_llm(model, prompt)\n", |
| "\n", |
| "    with ThreadPoolExecutor() as executor:\n", |
| "        # executor.map yields results in the order of `models`,\n", |
| "        # which is the order join_outputs zips the responses against\n", |
| "        responses = list(executor.map(query_model, models))\n", |
| "\n", |
| "    return state.update(responses=responses)\n", |
| "\n", |
| "app = (\n", |
| "    ApplicationBuilder()\n", |
| "    .with_actions(\n", |
| "        process_input, # defined in the Actions cell above\n", |
| "        query_multiple_models,\n", |
| "        join_outputs\n", |
| "    ).with_entrypoint(\"process_input\")\n", |
| "    .with_tracker(project=\"parallelism_example\", use_otel_tracing=True)\n", |
| "    .with_transitions(\n", |
| "        (\"process_input\", \"query_multiple_models\"),\n", |
| "        (\"query_multiple_models\", \"join_outputs\")\n", |
| "    )\n", |
| "    .build()\n", |
| ")\n", |
| "app" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 7, |
| "id": "c99f7a75-1410-4aa6-b60d-16a592c3ae5f", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "text/html": [ |
| "\n", |
| " <iframe\n", |
| " width=\"100%\"\n", |
| " height=\"700px\"\n", |
| " src=\"http://localhost:7241/project/parallelism_example/null/77783ee1-dcf5-4991-9b27-c92313fea7a6\"\n", |
| " frameborder=\"0\"\n", |
| " allowfullscreen\n", |
| " \n", |
| " ></iframe>\n", |
| " " |
| ], |
| "text/plain": [ |
| "<IPython.lib.display.IFrame at 0x129ffca30>" |
| ] |
| }, |
| "execution_count": 7, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "# setup UI to see the run - hit the live button to see it run live\n", |
| "IFrame(f'{URL_FOR_UI}/project/parallelism_example/null/{app.uid}', width=\"100%\", height=\"700px\")" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 8, |
| "id": "ab7e1041-fd7e-438e-8a78-aa468830e2f6", |
| "metadata": { |
| "scrolled": true |
| }, |
| "outputs": [], |
| "source": [ |
| "# let's run the application\n", |
| "action_, _, state = app.run(\n", |
| " inputs={\"prompt\": \"what is the meaning of life?\", \"models\" : [\"gpt-4\", \"gpt-4-turbo\", \"gpt-3.5-turbo\"]},\n", |
| " halt_after=[\"join_outputs\"]\n", |
| ")" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "becfab1b-f89d-49a0-b1af-2f01359bfefa", |
| "metadata": {}, |
| "source": [ |
| "# Approach #2 -- using Burr within Burr, i.e. recursion/hierarchy\n", |
| "\n", |
| "Burr allows you to create Burr applications within Burr applications and wire tracking through to the UI so you can visualize sub-applications. Here we represent each model's LLM call as its own (single-node) application. This feature is built for more complex sub-application shapes and is overkill for our toy example, but it works just as well with a single-action (node) application. You could easily add more actions to the sub-application to model more complex operations and hierarchy than we do here. The Burr sub-applications/graphs can also be used independently elsewhere.\n", |
| "\n", |
| "For more on the code constructs in this section see Burr's [recursive/hierarchy tracking capabilities](https://burr.dagworks.io/concepts/recursion/)." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 9, |
| "id": "14a89b97-ae8f-4ae7-b1fc-a91182d5658d", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "image/svg+xml": [ |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", |
| "<!-- Generated by graphviz version 12.0.0 (20240704.0754)\n", |
| " -->\n", |
| "<!-- Pages: 1 -->\n", |
| "<svg width=\"230pt\" height=\"241pt\"\n", |
| " viewBox=\"0.00 0.00 229.85 241.40\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 237.4)\">\n", |
| "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-237.4 225.85,-237.4 225.85,4 -4,4\"/>\n", |
| "<!-- process_input -->\n", |
| "<g id=\"node1\" class=\"node\">\n", |
| "<title>process_input</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M152.1,-167.8C152.1,-167.8 70.5,-167.8 70.5,-167.8 64.5,-167.8 58.5,-161.8 58.5,-155.8 58.5,-155.8 58.5,-143.2 58.5,-143.2 58.5,-137.2 64.5,-131.2 70.5,-131.2 70.5,-131.2 152.1,-131.2 152.1,-131.2 158.1,-131.2 164.1,-137.2 164.1,-143.2 164.1,-143.2 164.1,-155.8 164.1,-155.8 164.1,-161.8 158.1,-167.8 152.1,-167.8\"/>\n", |
| "<text text-anchor=\"middle\" x=\"111.3\" y=\"-143.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">process_input</text>\n", |
| "</g>\n", |
| "<!-- query_multiple_models -->\n", |
| "<g id=\"node4\" class=\"node\">\n", |
| "<title>query_multiple_models</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M180.22,-102.2C180.22,-102.2 42.38,-102.2 42.38,-102.2 36.38,-102.2 30.38,-96.2 30.38,-90.2 30.38,-90.2 30.38,-77.6 30.38,-77.6 30.38,-71.6 36.38,-65.6 42.38,-65.6 42.38,-65.6 180.22,-65.6 180.22,-65.6 186.22,-65.6 192.22,-71.6 192.22,-77.6 192.22,-77.6 192.22,-90.2 192.22,-90.2 192.22,-96.2 186.22,-102.2 180.22,-102.2\"/>\n", |
| "<text text-anchor=\"middle\" x=\"111.3\" y=\"-78.1\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">query_multiple_models</text>\n", |
| "</g>\n", |
| "<!-- process_input->query_multiple_models -->\n", |
| "<g id=\"edge3\" class=\"edge\">\n", |
| "<title>process_input->query_multiple_models</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M111.3,-130.78C111.3,-125.55 111.3,-119.67 111.3,-113.92\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"114.8,-114.03 111.3,-104.03 107.8,-114.03 114.8,-114.03\"/>\n", |
| "</g>\n", |
| "<!-- input__models -->\n", |
| "<g id=\"node2\" class=\"node\">\n", |
| "<title>input__models</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"102.6,-233.4 0,-233.4 0,-196.8 102.6,-196.8 102.6,-233.4\"/>\n", |
| "<text text-anchor=\"middle\" x=\"51.3\" y=\"-209.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input: models</text>\n", |
| "</g>\n", |
| "<!-- input__models->process_input -->\n", |
| "<g id=\"edge1\" class=\"edge\">\n", |
| "<title>input__models->process_input</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M68.01,-196.38C73.86,-190.18 80.56,-183.08 86.91,-176.35\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"89.32,-178.9 93.64,-169.22 84.23,-174.09 89.32,-178.9\"/>\n", |
| "</g>\n", |
| "<!-- input__prompt -->\n", |
| "<g id=\"node3\" class=\"node\">\n", |
| "<title>input__prompt</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"221.85,-233.4 120.75,-233.4 120.75,-196.8 221.85,-196.8 221.85,-233.4\"/>\n", |
| "<text text-anchor=\"middle\" x=\"171.3\" y=\"-209.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input: prompt</text>\n", |
| "</g>\n", |
| "<!-- input__prompt->process_input -->\n", |
| "<g id=\"edge2\" class=\"edge\">\n", |
| "<title>input__prompt->process_input</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M154.59,-196.38C148.74,-190.18 142.04,-183.08 135.69,-176.35\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"138.37,-174.09 128.96,-169.22 133.28,-178.9 138.37,-174.09\"/>\n", |
| "</g>\n", |
| "<!-- join_outputs -->\n", |
| "<g id=\"node5\" class=\"node\">\n", |
| "<title>join_outputs</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M146.47,-36.6C146.47,-36.6 76.12,-36.6 76.12,-36.6 70.12,-36.6 64.12,-30.6 64.12,-24.6 64.12,-24.6 64.12,-12 64.12,-12 64.12,-6 70.12,0 76.12,0 76.12,0 146.47,0 146.47,0 152.47,0 158.47,-6 158.47,-12 158.47,-12 158.47,-24.6 158.47,-24.6 158.47,-30.6 152.47,-36.6 146.47,-36.6\"/>\n", |
| "<text text-anchor=\"middle\" x=\"111.3\" y=\"-12.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">join_outputs</text>\n", |
| "</g>\n", |
| "<!-- query_multiple_models->join_outputs -->\n", |
| "<g id=\"edge4\" class=\"edge\">\n", |
| "<title>query_multiple_models->join_outputs</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M111.3,-65.18C111.3,-59.95 111.3,-54.07 111.3,-48.32\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"114.8,-48.43 111.3,-38.43 107.8,-48.43 114.8,-48.43\"/>\n", |
| "</g>\n", |
| "</g>\n", |
| "</svg>\n" |
| ], |
| "text/plain": [ |
| "<burr.core.application.Application at 0x129ffef80>" |
| ] |
| }, |
| "execution_count": 9, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "@action(reads=[\"model\", \"prompt\"], writes=[\"response\"])\n", |
| "def inner_query_model(state: State) -> State:\n", |
| "    return state.update(response=_query_llm(state[\"model\"], state[\"prompt\"]))\n", |
| "\n", |
| "@action(reads=[\"models\", \"prompt\"], writes=[\"responses\"])\n", |
| "def query_multiple_models(state: State, __context: ApplicationContext) -> State:\n", |
| "    \"\"\"Query multiple models in parallel and store the results in the state.\"\"\"\n", |
| "    sub_apps = []\n", |
| "    for model in state[\"models\"]:\n", |
| "        sub_app = (  # it's just a single action here, but it could be more!\n", |
| "            ApplicationBuilder()\n", |
| "            .with_actions(query_model=inner_query_model)\n", |
| "            .with_state(model=model, prompt=state[\"prompt\"])\n", |
| "            .with_tracker(project=\"parallelism_example\", use_otel_tracing=True)\n", |
| "            .with_spawning_parent(  # we link the parent and sub-application together\n", |
| "                app_id=__context.app_id,\n", |
| "                sequence_id=__context.sequence_id\n", |
| "            )\n", |
| "            .with_entrypoint(\"query_model\")\n", |
| "            .build()\n", |
| "        )\n", |
| "        sub_apps.append(sub_app)\n", |
| "\n", |
| "    with ThreadPoolExecutor() as executor:\n", |
| "        # submit in model order so the responses line up with state[\"models\"]\n", |
| "        futures = [executor.submit(sub_app.run, halt_after=[\"query_model\"]) for sub_app in sub_apps]\n", |
| "        # run() returns (action, result, state); pull the response out of each final state\n", |
| "        responses = [future.result()[2][\"response\"] for future in futures]\n", |
| "\n", |
| "    return state.update(responses=responses)\n", |
| "\n", |
| "app_recursion = (\n", |
| "    ApplicationBuilder()\n", |
| "    .with_actions(\n", |
| "        process_input,\n", |
| "        query_multiple_models,\n", |
| "        join_outputs\n", |
| "    ).with_entrypoint(\"process_input\")\n", |
| "    .with_tracker(project=\"parallelism_example\", use_otel_tracing=True)\n", |
| "    .with_transitions(\n", |
| "        (\"process_input\", \"query_multiple_models\"),\n", |
| "        (\"query_multiple_models\", \"join_outputs\")\n", |
| "    )\n", |
| "    .build()\n", |
| ")\n", |
| "app_recursion" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 10, |
| "id": "adb02469-9dd6-404e-8a9f-0e50fe21123f", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "text/html": [ |
| "\n", |
| " <iframe\n", |
| " width=\"100%\"\n", |
| " height=\"700px\"\n", |
| " src=\"http://localhost:7241/project/parallelism_example/null/3cfc4776-5081-4428-aa9a-0e8fa3d3f005\"\n", |
| " frameborder=\"0\"\n", |
| " allowfullscreen\n", |
| " \n", |
| " ></iframe>\n", |
| " " |
| ], |
| "text/plain": [ |
| "<IPython.lib.display.IFrame at 0x12a2416c0>" |
| ] |
| }, |
| "execution_count": 10, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "# hit the \"live button\" to see things run live\n", |
| "IFrame(f'{URL_FOR_UI}/project/parallelism_example/null/{app_recursion.uid}', width=\"100%\", height=\"700px\")" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 11, |
| "id": "ff28d579-084f-463f-9761-b9e85a6d1460", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# let's run the application\n", |
| "action_, _, state = app_recursion.run(\n", |
| " inputs={\"prompt\": \"what is the meaning of life?\", \"models\" : [\"gpt-4\", \"gpt-4-turbo\", \"gpt-3.5-turbo\"]},\n", |
| " halt_after=[\"join_outputs\"]\n", |
| ")" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "f8c09d92-045a-4285-840a-f0099fe8b7cf", |
| "metadata": {}, |
| "source": [ |
| "# Approach #3 -- using Burr's Parallel constructs\n", |
| "\n", |
| "We can use Burr's parallel constructs to make running the different LLMs in parallel even simpler. To do so we define a class that extends `MapStates` (we can also map over actions with [`MapActions`](https://burr.dagworks.io/reference/parallelism/#burr.core.parallelism.MapActions) and over both states & actions with [`MapActionsAndStates`](https://burr.dagworks.io/reference/parallelism/#burr.core.parallelism.MapActionsAndStates)). This class enables us to \"map\" over state -- in this case, it returns the same action for each model in the `models` field, so the models run in parallel. Underneath, Burr will create the Burr Applications and run them in parallel.\n", |
| "\n", |
| "For more on the code constructs in this section see Burr's [parallel map/reduce concepts](https://burr.dagworks.io/concepts/parallelism/)." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 12, |
| "id": "ce90216e-5470-4bdb-9d32-c9b7c3e64eb3", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "image/svg+xml": [ |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", |
| "<!-- Generated by graphviz version 12.0.0 (20240704.0754)\n", |
| " -->\n", |
| "<!-- Pages: 1 -->\n", |
| "<svg width=\"230pt\" height=\"241pt\"\n", |
| " viewBox=\"0.00 0.00 229.85 241.40\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 237.4)\">\n", |
| "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-237.4 225.85,-237.4 225.85,4 -4,4\"/>\n", |
| "<!-- process_input -->\n", |
| "<g id=\"node1\" class=\"node\">\n", |
| "<title>process_input</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M152.1,-167.8C152.1,-167.8 70.5,-167.8 70.5,-167.8 64.5,-167.8 58.5,-161.8 58.5,-155.8 58.5,-155.8 58.5,-143.2 58.5,-143.2 58.5,-137.2 64.5,-131.2 70.5,-131.2 70.5,-131.2 152.1,-131.2 152.1,-131.2 158.1,-131.2 164.1,-137.2 164.1,-143.2 164.1,-143.2 164.1,-155.8 164.1,-155.8 164.1,-161.8 158.1,-167.8 152.1,-167.8\"/>\n", |
| "<text text-anchor=\"middle\" x=\"111.3\" y=\"-143.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">process_input</text>\n", |
| "</g>\n", |
| "<!-- query_multiple_models -->\n", |
| "<g id=\"node5\" class=\"node\">\n", |
| "<title>query_multiple_models</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M180.22,-102.2C180.22,-102.2 42.38,-102.2 42.38,-102.2 36.38,-102.2 30.38,-96.2 30.38,-90.2 30.38,-90.2 30.38,-77.6 30.38,-77.6 30.38,-71.6 36.38,-65.6 42.38,-65.6 42.38,-65.6 180.22,-65.6 180.22,-65.6 186.22,-65.6 192.22,-71.6 192.22,-77.6 192.22,-77.6 192.22,-90.2 192.22,-90.2 192.22,-96.2 186.22,-102.2 180.22,-102.2\"/>\n", |
| "<text text-anchor=\"middle\" x=\"111.3\" y=\"-78.1\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">query_multiple_models</text>\n", |
| "</g>\n", |
| "<!-- process_input->query_multiple_models -->\n", |
| "<g id=\"edge3\" class=\"edge\">\n", |
| "<title>process_input->query_multiple_models</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M111.3,-130.78C111.3,-125.55 111.3,-119.67 111.3,-113.92\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"114.8,-114.03 111.3,-104.03 107.8,-114.03 114.8,-114.03\"/>\n", |
| "</g>\n", |
| "<!-- input__models -->\n", |
| "<g id=\"node2\" class=\"node\">\n", |
| "<title>input__models</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"102.6,-233.4 0,-233.4 0,-196.8 102.6,-196.8 102.6,-233.4\"/>\n", |
| "<text text-anchor=\"middle\" x=\"51.3\" y=\"-209.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input: models</text>\n", |
| "</g>\n", |
| "<!-- input__models->process_input -->\n", |
| "<g id=\"edge1\" class=\"edge\">\n", |
| "<title>input__models->process_input</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M68.01,-196.38C73.86,-190.18 80.56,-183.08 86.91,-176.35\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"89.32,-178.9 93.64,-169.22 84.23,-174.09 89.32,-178.9\"/>\n", |
| "</g>\n", |
| "<!-- input__prompt -->\n", |
| "<g id=\"node3\" class=\"node\">\n", |
| "<title>input__prompt</title>\n", |
| "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"221.85,-233.4 120.75,-233.4 120.75,-196.8 221.85,-196.8 221.85,-233.4\"/>\n", |
| "<text text-anchor=\"middle\" x=\"171.3\" y=\"-209.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input: prompt</text>\n", |
| "</g>\n", |
| "<!-- input__prompt->process_input -->\n", |
| "<g id=\"edge2\" class=\"edge\">\n", |
| "<title>input__prompt->process_input</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M154.59,-196.38C148.74,-190.18 142.04,-183.08 135.69,-176.35\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"138.37,-174.09 128.96,-169.22 133.28,-178.9 138.37,-174.09\"/>\n", |
| "</g>\n", |
| "<!-- join_outputs -->\n", |
| "<g id=\"node4\" class=\"node\">\n", |
| "<title>join_outputs</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M146.47,-36.6C146.47,-36.6 76.12,-36.6 76.12,-36.6 70.12,-36.6 64.12,-30.6 64.12,-24.6 64.12,-24.6 64.12,-12 64.12,-12 64.12,-6 70.12,0 76.12,0 76.12,0 146.47,0 146.47,0 152.47,0 158.47,-6 158.47,-12 158.47,-12 158.47,-24.6 158.47,-24.6 158.47,-30.6 152.47,-36.6 146.47,-36.6\"/>\n", |
| "<text text-anchor=\"middle\" x=\"111.3\" y=\"-12.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">join_outputs</text>\n", |
| "</g>\n", |
| "<!-- query_multiple_models->join_outputs -->\n", |
| "<g id=\"edge4\" class=\"edge\">\n", |
| "<title>query_multiple_models->join_outputs</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M111.3,-65.18C111.3,-59.95 111.3,-54.07 111.3,-48.32\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"114.8,-48.43 111.3,-38.43 107.8,-48.43 114.8,-48.43\"/>\n", |
| "</g>\n", |
| "</g>\n", |
| "</svg>\n" |
| ], |
| "text/plain": [ |
| "<burr.core.application.Application at 0x12a2950f0>" |
| ] |
| }, |
| "execution_count": 12, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "@action(reads=[\"model\", \"prompt\"], writes=[\"response\"])\n", |
| "def inner_query_model(state: State) -> State:\n", |
| " return state.update(response=_query_llm(state[\"model\"], state[\"prompt\"]))\n", |
| " \n", |
| "class RunOverMultiplePromptsAction(MapStates):\n", |
| " \"\"\"Our parallel action - defined as a class.\"\"\"\n", |
| " def states(\n", |
| "        self, state: State, context: ApplicationContext, inputs: Dict[str, Any]\n", |
| " ) -> Generator[State, None, None]:\n", |
| " \"\"\"Generator to map over values in state\"\"\"\n", |
| " for model in state[\"models\"]:\n", |
| " yield state.update(model=model)\n", |
| "\n", |
| " def action(self, state: State, inputs: Dict[str, Any]) -> Action:\n", |
| " \"\"\"The single action we want to map over\"\"\"\n", |
| " return inner_query_model\n", |
| "\n", |
| " def reduce(self, state: State, results: Generator[State, None, None]) -> State:\n", |
| " \"\"\"The reduction step - allows us to customize how state is updated\"\"\"\n", |
| " responses = {\n", |
| " model: output_state[\"response\"]\n", |
| " for model, output_state in zip(state[\"models\"], results)\n", |
| " }\n", |
| " return state.update(responses=responses)\n", |
| " \n", |
| " @property\n", |
| " def reads(self) -> List[str]:\n", |
| " \"\"\"Like the function @action we need to specify what is read from state\"\"\"\n", |
| " return [\"models\", \"prompt\"]\n", |
| "\n", |
| " @property\n", |
| " def writes(self) -> List[str]:\n", |
| " \"\"\"Like the function @action we need to specify what is written to state\"\"\"\n", |
| " return [\"responses\"]\n", |
| "\n", |
| "app_parallelism = (\n", |
| " ApplicationBuilder()\n", |
| " .with_actions(\n", |
| " process_input,\n", |
| " join_outputs,\n", |
| " query_multiple_models=RunOverMultiplePromptsAction(),\n", |
| " ).with_entrypoint(\"process_input\")\n", |
| " .with_tracker(project=\"parallelism_example\", use_otel_tracing=True)\n", |
| " .with_transitions(\n", |
| " (\"process_input\", \"query_multiple_models\"),\n", |
| " (\"query_multiple_models\", \"join_outputs\")\n", |
| " )\n", |
| " .build()\n", |
| ")\n", |
| "app_parallelism" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 13, |
| "id": "7ece3113-cf60-47b1-b375-735070c70b67", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "text/html": [ |
| "\n", |
| " <iframe\n", |
| " width=\"100%\"\n", |
| " height=\"700px\"\n", |
| " src=\"http://localhost:7241/project/parallelism_example/null/8f87303a-1cd2-4345-9143-69c7409a23e0\"\n", |
| " frameborder=\"0\"\n", |
| " allowfullscreen\n", |
| " \n", |
| " ></iframe>\n", |
| " " |
| ], |
| "text/plain": [ |
| "<IPython.lib.display.IFrame at 0x12a297dc0>" |
| ] |
| }, |
| "execution_count": 13, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "# hit the \"live button\" to see things run live\n", |
| "IFrame(f'{URL_FOR_UI}/project/parallelism_example/null/{app_parallelism.uid}', width=\"100%\", height=\"700px\")" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 14, |
| "id": "816cf0d7-1f89-4145-a5e3-c2878cec12e1", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# let's run the application\n", |
| "action_, _, state = app_parallelism.run(\n", |
| " inputs={\"prompt\": \"what is the meaning of life?\", \"models\" : [\"gpt-4\", \"gpt-4-turbo\", \"gpt-3.5-turbo\"]},\n", |
| " halt_after=[\"join_outputs\"]\n", |
| ")" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "f36fa02b-a789-43d9-8161-447273209b0e", |
| "metadata": {}, |
| "source": [ |
| "# #4 - Building more complex Burr Applications\n", |
| "\n", |
| "Above we walked through a progression of ways to tackle parallelism with Burr. Here we sketch some more complex examples. Note: see the [docs on executors](https://burr.dagworks.io/reference/application/#burr.core.application.ApplicationBuilder.with_parallel_executor) (and the roadmap) for what is currently supported.\n", |
| "\n", |
| "For example, you might want to take the same input and map it over several Burr applications/agents in parallel, e.g. for medical diagnosis.\n", |
| "Or you might want to map over values in state and run the same application/agent on each one, e.g. for web interactions.\n", |
| "Lastly, you might want to take the cartesian product of values in state and actions, e.g. for hyperparameter tuning." |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "22218d0f-e3c9-42ab-a948-87343b3ea58c", |
| "metadata": {}, |
| "source": [ |
| "## Running parallel medical diagnoses\n", |
| "Say you get a transcript and images of a medical situation and want to test several hypotheses in parallel. How would you do that?\n", |
| "\n", |
| "With Burr's parallel constructs, this is a \"map over actions\", where each \"action\" can be a simple action like the ones above or a full RAG agent. Below we sketch what the code would look like, using placeholder functions for the core logic." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 15, |
| "id": "1969fe52-ba92-42f6-ae39-2252aff30d51", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "image/svg+xml": [ |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", |
| "<!-- Generated by graphviz version 12.0.0 (20240704.0754)\n", |
| " -->\n", |
| "<!-- Pages: 1 -->\n", |
| "<svg width=\"150pt\" height=\"176pt\"\n", |
| " viewBox=\"0.00 0.00 149.60 175.80\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 171.8)\">\n", |
| "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-171.8 145.6,-171.8 145.6,4 -4,4\"/>\n", |
| "<!-- create_transcript -->\n", |
| "<g id=\"node1\" class=\"node\">\n", |
| "<title>create_transcript</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M119.85,-167.8C119.85,-167.8 21.75,-167.8 21.75,-167.8 15.75,-167.8 9.75,-161.8 9.75,-155.8 9.75,-155.8 9.75,-143.2 9.75,-143.2 9.75,-137.2 15.75,-131.2 21.75,-131.2 21.75,-131.2 119.85,-131.2 119.85,-131.2 125.85,-131.2 131.85,-137.2 131.85,-143.2 131.85,-143.2 131.85,-155.8 131.85,-155.8 131.85,-161.8 125.85,-167.8 119.85,-167.8\"/>\n", |
| "<text text-anchor=\"middle\" x=\"70.8\" y=\"-143.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">create_transcript</text>\n", |
| "</g>\n", |
| "<!-- parallel_hypotheses -->\n", |
| "<g id=\"node2\" class=\"node\">\n", |
| "<title>parallel_hypotheses</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M129.6,-102.2C129.6,-102.2 12,-102.2 12,-102.2 6,-102.2 0,-96.2 0,-90.2 0,-90.2 0,-77.6 0,-77.6 0,-71.6 6,-65.6 12,-65.6 12,-65.6 129.6,-65.6 129.6,-65.6 135.6,-65.6 141.6,-71.6 141.6,-77.6 141.6,-77.6 141.6,-90.2 141.6,-90.2 141.6,-96.2 135.6,-102.2 129.6,-102.2\"/>\n", |
| "<text text-anchor=\"middle\" x=\"70.8\" y=\"-78.1\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">parallel_hypotheses</text>\n", |
| "</g>\n", |
| "<!-- create_transcript->parallel_hypotheses -->\n", |
| "<g id=\"edge1\" class=\"edge\">\n", |
| "<title>create_transcript->parallel_hypotheses</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M70.8,-130.78C70.8,-125.55 70.8,-119.67 70.8,-113.92\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"74.3,-114.03 70.8,-104.03 67.3,-114.03 74.3,-114.03\"/>\n", |
| "</g>\n", |
| "<!-- diagnose -->\n", |
| "<g id=\"node3\" class=\"node\">\n", |
| "<title>diagnose</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M96.97,-36.6C96.97,-36.6 44.62,-36.6 44.62,-36.6 38.62,-36.6 32.62,-30.6 32.62,-24.6 32.62,-24.6 32.62,-12 32.62,-12 32.62,-6 38.62,0 44.62,0 44.62,0 96.97,0 96.97,0 102.97,0 108.97,-6 108.97,-12 108.97,-12 108.97,-24.6 108.97,-24.6 108.97,-30.6 102.97,-36.6 96.97,-36.6\"/>\n", |
| "<text text-anchor=\"middle\" x=\"70.8\" y=\"-12.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">diagnose</text>\n", |
| "</g>\n", |
| "<!-- parallel_hypotheses->diagnose -->\n", |
| "<g id=\"edge2\" class=\"edge\">\n", |
| "<title>parallel_hypotheses->diagnose</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M70.8,-65.18C70.8,-59.95 70.8,-54.07 70.8,-48.32\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"74.3,-48.43 70.8,-38.43 67.3,-48.43 74.3,-48.43\"/>\n", |
| "</g>\n", |
| "</g>\n", |
| "</svg>\n" |
| ], |
| "text/plain": [ |
| "<burr.core.application.Application at 0x12a264700>" |
| ] |
| }, |
| "execution_count": 15, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "from burr.core import action, state, when\n", |
| "from burr.core.graph import GraphBuilder\n", |
| "from burr.core.parallelism import MapActions, RunnableGraph\n", |
| "from typing import Callable, Generator, List\n", |
| "\n", |
| "# actions in our sub-graphs / sub-application / agent\n", |
| "@action(reads=[\"transcript\"], writes=[\"query\", \"hypothesis\"])\n", |
| "def hypothesis_1(state: State) -> State:\n", |
| " _transcript = state[\"transcript\"]\n", |
| "    # reads the transcript and applies hypothesis 1\n", |
| " return state.update(query=..., hypothesis=1)\n", |
| "\n", |
| "@action(reads=[\"transcript\"], writes=[\"query\", \"hypothesis\"])\n", |
| "def hypothesis_2(state: State) -> State:\n", |
| " _transcript = state[\"transcript\"]\n", |
| "    # reads the transcript and applies hypothesis 2\n", |
| "    return state.update(query=..., hypothesis=2)\n", |
| "\n", |
| "@action(reads=[\"query\"], writes=[\"context\"])\n", |
| "def query(state: State) -> State:\n", |
| " _query = state[\"query\"]\n", |
| " # does query against a datastore & gets context\n", |
| " return state.update(context=...)\n", |
| "\n", |
| "@action(reads=[\"context\"], writes=[\"response\"])\n", |
| "def llm_response(state:State) -> State:\n", |
| " _context = state[\"context\"]\n", |
| " # does LLM call using context, i.e. the AG part of RAG.\n", |
| " return state.update(response=...)\n", |
| "\n", |
| "# graph 1 - or agent 1\n", |
| "hypothesis_graph_1 = RunnableGraph(\n", |
| " graph=(\n", |
| " GraphBuilder()\n", |
| " .with_actions(\n", |
| " hypothesis_1,\n", |
| " query,\n", |
| " llm_response\n", |
| " )\n", |
| " .with_transitions(\n", |
| " (\"hypothesis_1\", \"query\"),\n", |
| " (\"query\", \"llm_response\"),\n", |
| " )\n", |
| " .build()\n", |
| " ),\n", |
| " entrypoint=\"hypothesis_1\",\n", |
| " halt_after=[\"llm_response\"],\n", |
| ")\n", |
| "\n", |
| "# graph 2 - or agent 2\n", |
| "hypothesis_graph_2 = RunnableGraph(\n", |
| " graph=(\n", |
| " GraphBuilder()\n", |
| " .with_actions(\n", |
| " hypothesis_2,\n", |
| " query,\n", |
| " llm_response\n", |
| " )\n", |
| " .with_transitions(\n", |
| " (\"hypothesis_2\", \"query\"),\n", |
| " (\"query\", \"llm_response\"),\n", |
| " )\n", |
| " .build()\n", |
| " ),\n", |
| " entrypoint=\"hypothesis_2\",\n", |
| " halt_after=[\"llm_response\"],\n", |
| ")\n", |
| "\n", |
| "class RunMultipleGraphsAction(MapActions):\n", |
| " \"\"\"Our parallel action that will map over graphs\"\"\"\n", |
| "\n", |
| "    def actions(self, state: State, context: ApplicationContext, inputs: Dict[str, Any]\n", |
| " ) -> Generator[Action | Callable | RunnableGraph, None, None]:\n", |
| " \"\"\"We hard code the actions here, but they could be dynamic\"\"\"\n", |
| " for graph_action in [\n", |
| " hypothesis_graph_1.with_name(\"hypothesis_1\"),\n", |
| " hypothesis_graph_2.with_name(\"hypothesis_2\"),\n", |
| " ]:\n", |
| " yield graph_action\n", |
| "\n", |
| " def state(self, state: State, inputs: Dict[str, Any]) -> State:\n", |
| " return state.update(transcript=\"this could be passed in or could be set here\")\n", |
| "\n", |
| " def reduce(self, state: State, states: Generator[State, None, None]) -> State:\n", |
| " # we aggregate here\n", |
| " all_diagnoses = []\n", |
| " for sub_state in states:\n", |
| "            # each sub-graph ends at llm_response, which writes the \"response\" key\n", |
| "            all_diagnoses.append((sub_state[\"response\"], sub_state[\"hypothesis\"]))\n", |
| " return state.update(all_diagnoses=all_diagnoses)\n", |
| " \n", |
| " @property\n", |
| " def reads(self) -> List[str]:\n", |
| " return [\"transcript\"]\n", |
| "\n", |
| " @property\n", |
| " def writes(self) -> List[str]:\n", |
| " return [\"all_diagnoses\"]\n", |
| "\n", |
| "@action(reads=[\"audio\"], writes=[\"transcript\"])\n", |
| "def create_transcript(state: State) -> State:\n", |
| " # create transcript\n", |
| " return state.update(transcript=...)\n", |
| "\n", |
| "@action(reads=[\"all_diagnoses\"], writes=[\"diagnosis\"])\n", |
| "def diagnose(state:State) -> State:\n", |
| " # ... pick the best one\n", |
| " return state.update(diagnosis=...)\n", |
| "\n", |
| "map_actions_app = (\n", |
| " ApplicationBuilder()\n", |
| " .with_actions(\n", |
| " create_transcript=create_transcript,\n", |
| " parallel_hypotheses=RunMultipleGraphsAction(),\n", |
| " diagnose=diagnose\n", |
| " ).with_entrypoint(\"create_transcript\")\n", |
| " .with_tracker(project=\"parallelism_mapactions\", use_otel_tracing=True)\n", |
| " .with_transitions(\n", |
| " (\"create_transcript\", \"parallel_hypotheses\"),\n", |
| " (\"parallel_hypotheses\", \"diagnose\"),\n", |
| " )\n", |
| " .build()\n", |
| ")\n", |
| "map_actions_app\n", |
| "\n", |
| "# To run it you'd do something like:\n", |
| "# action, result, state = map_actions_app.run(\n", |
| "# halt_after=[\"diagnose\"], inputs={...}\n", |
| "# )" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "7f861d43-144d-402b-bb36-129d3aab8c08", |
| "metadata": {}, |
| "source": [ |
| "## Running parallel web interactions\n", |
| "\n", |
| "Given a list of websites (or tasks), wouldn't it be great to process them in parallel? Or perhaps you want to run experiments over different inputs to your Burr application. Mapping over state is one way to parallelize that while keeping the original Burr application atomic and independently usable." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 16, |
| "id": "c881b815-b9bf-472d-9104-84ffa2fda63f", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "image/svg+xml": [ |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", |
| "<!-- Generated by graphviz version 12.0.0 (20240704.0754)\n", |
| " -->\n", |
| "<!-- Pages: 1 -->\n", |
| "<svg width=\"162pt\" height=\"176pt\"\n", |
| " viewBox=\"0.00 0.00 162.35 175.80\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 171.8)\">\n", |
| "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-171.8 158.35,-171.8 158.35,4 -4,4\"/>\n", |
| "<!-- load_urls -->\n", |
| "<g id=\"node1\" class=\"node\">\n", |
| "<title>load_urls</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M103.35,-167.8C103.35,-167.8 51,-167.8 51,-167.8 45,-167.8 39,-161.8 39,-155.8 39,-155.8 39,-143.2 39,-143.2 39,-137.2 45,-131.2 51,-131.2 51,-131.2 103.35,-131.2 103.35,-131.2 109.35,-131.2 115.35,-137.2 115.35,-143.2 115.35,-143.2 115.35,-155.8 115.35,-155.8 115.35,-161.8 109.35,-167.8 103.35,-167.8\"/>\n", |
| "<text text-anchor=\"middle\" x=\"77.17\" y=\"-143.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">load_urls</text>\n", |
| "</g>\n", |
| "<!-- parallel_url_processor -->\n", |
| "<g id=\"node2\" class=\"node\">\n", |
| "<title>parallel_url_processor</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M142.35,-102.2C142.35,-102.2 12,-102.2 12,-102.2 6,-102.2 0,-96.2 0,-90.2 0,-90.2 0,-77.6 0,-77.6 0,-71.6 6,-65.6 12,-65.6 12,-65.6 142.35,-65.6 142.35,-65.6 148.35,-65.6 154.35,-71.6 154.35,-77.6 154.35,-77.6 154.35,-90.2 154.35,-90.2 154.35,-96.2 148.35,-102.2 142.35,-102.2\"/>\n", |
| "<text text-anchor=\"middle\" x=\"77.17\" y=\"-78.1\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">parallel_url_processor</text>\n", |
| "</g>\n", |
| "<!-- load_urls->parallel_url_processor -->\n", |
| "<g id=\"edge1\" class=\"edge\">\n", |
| "<title>load_urls->parallel_url_processor</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M77.17,-130.78C77.17,-125.55 77.17,-119.67 77.17,-113.92\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"80.68,-114.03 77.18,-104.03 73.68,-114.03 80.68,-114.03\"/>\n", |
| "</g>\n", |
| "<!-- create_response -->\n", |
| "<g id=\"node3\" class=\"node\">\n", |
| "<title>create_response</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M126.22,-36.6C126.22,-36.6 28.12,-36.6 28.12,-36.6 22.12,-36.6 16.12,-30.6 16.12,-24.6 16.12,-24.6 16.12,-12 16.12,-12 16.12,-6 22.12,0 28.12,0 28.12,0 126.22,0 126.22,0 132.22,0 138.22,-6 138.22,-12 138.22,-12 138.22,-24.6 138.22,-24.6 138.22,-30.6 132.22,-36.6 126.22,-36.6\"/>\n", |
| "<text text-anchor=\"middle\" x=\"77.17\" y=\"-12.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">create_response</text>\n", |
| "</g>\n", |
| "<!-- parallel_url_processor->create_response -->\n", |
| "<g id=\"edge2\" class=\"edge\">\n", |
| "<title>parallel_url_processor->create_response</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M77.17,-65.18C77.17,-59.95 77.17,-54.07 77.17,-48.32\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"80.68,-48.43 77.18,-38.43 73.68,-48.43 80.68,-48.43\"/>\n", |
| "</g>\n", |
| "</g>\n", |
| "</svg>\n" |
| ], |
| "text/plain": [ |
| "<burr.core.application.Application at 0x12a22aa40>" |
| ] |
| }, |
| "execution_count": 16, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "from burr.core import action, state, when\n", |
| "from burr.core.graph import GraphBuilder\n", |
| "from burr.core.parallelism import MapStates, RunnableGraph\n", |
| "from typing import Callable, Generator, List\n", |
| "\n", |
| "# actions in our sub-graphs / sub-application / agent\n", |
| "@action(reads=[\"url\"], writes=[\"data\"])\n", |
| "def crawler(state: State) -> State:\n", |
| " _url = state[\"url\"]\n", |
| " # goes to the URL and extracts something...\n", |
| " return state.update(data=...)\n", |
| "\n", |
| "@action(reads=[\"data\"], writes=[\"summary\"])\n", |
| "def summarize(state: State) -> State:\n", |
| " _data = state[\"data\"]\n", |
| " # summarizes the data\n", |
| " return state.update(summary=...)\n", |
| "\n", |
| "@action(reads=[\"summary\", \"url\"], writes=[\"classification\"])\n", |
| "def classify_and_store(state: State) -> State:\n", |
| " _context = state[\"summary\"]\n", |
| " _url = state[\"url\"]\n", |
| " # store data somewhere? or classify it? ... use your imagination here\n", |
| " return state.update(classification=...)\n", |
| "\n", |
| "# the sub-graph (or agent) that processes a single URL\n", |
| "url_processor = RunnableGraph(\n", |
| " graph=(\n", |
| " GraphBuilder()\n", |
| " .with_actions(\n", |
| " crawler,\n", |
| " summarize,\n", |
| " classify_and_store\n", |
| " )\n", |
| " .with_transitions(\n", |
| " (\"crawler\", \"summarize\"),\n", |
| " (\"summarize\", \"classify_and_store\"),\n", |
| " )\n", |
| " .build()\n", |
| " ),\n", |
| " entrypoint=\"crawler\",\n", |
| " halt_after=[\"classify_and_store\"],\n", |
| ")\n", |
| "\n", |
| "\n", |
| "class MapOverURLsAction(MapStates):\n", |
| " \"\"\"Our parallel action that will map over state values\"\"\"\n", |
| "\n", |
| "    def action(self, state: State, inputs: Dict[str, Any]) -> Action | Callable | RunnableGraph:\n", |
| " \"\"\"Return the one action.\"\"\"\n", |
| " return url_processor\n", |
| "\n", |
| " def states(\n", |
| " self, state: State, context: ApplicationContext, inputs: Dict[str, Any]\n", |
| " ) -> Generator[State, None, None]:\n", |
| " \"\"\"Generator to map over values in state\"\"\"\n", |
| " for url in state[\"urls\"]:\n", |
| " yield state.update(url=url)\n", |
| "\n", |
| " def reduce(self, state: State, states: Generator[State, None, None]) -> State:\n", |
| " # we aggregate here\n", |
| " all_results = []\n", |
| " for sub_state in states:\n", |
| " all_results.append((sub_state[\"classification\"], sub_state[\"url\"]))\n", |
| " return state.update(all_results=all_results)\n", |
| " \n", |
| " @property\n", |
| " def reads(self) -> List[str]:\n", |
| " return [\"urls\"]\n", |
| "\n", |
| " @property\n", |
| " def writes(self) -> List[str]:\n", |
| " return [\"all_results\"]\n", |
| "\n", |
| "@action(reads=[\"...\"], writes=[\"urls\"])\n", |
| "def load_urls(state: State) -> State:\n", |
| "    # load the URLs to process\n", |
| " return state.update(urls=...)\n", |
| "\n", |
| "@action(reads=[\"all_results\"], writes=[\"response\"])\n", |
| "def create_response(state:State) -> State:\n", |
| " # ... return results ...\n", |
| " return state.update(response=...)\n", |
| "\n", |
| "map_states_app = (\n", |
| " ApplicationBuilder()\n", |
| " .with_actions(\n", |
| " load_urls=load_urls,\n", |
| " parallel_url_processor=MapOverURLsAction(),\n", |
| " create_response=create_response\n", |
| " ).with_entrypoint(\"load_urls\")\n", |
| " .with_tracker(project=\"parallelism_mapstates\", use_otel_tracing=True)\n", |
| " .with_transitions(\n", |
| " (\"load_urls\", \"parallel_url_processor\"),\n", |
| " (\"parallel_url_processor\", \"create_response\"),\n", |
| " )\n", |
| " .build()\n", |
| ")\n", |
| "map_states_app\n", |
| "\n", |
| "# To run it you'd do something like:\n", |
| "# action, result, state = map_states_app.run(\n", |
| "# halt_after=[\"create_response\"], inputs={...}\n", |
| "# )\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "f1029bff-8055-49b8-9e16-dc9bbf737466", |
| "metadata": {}, |
| "source": [ |
| "## Running \"hyperparameter tuning\": i.e. experiments over parameters of an agent (e.g. prompt), RAG systems (e.g. embeddings), etc.\n", |
| "\n", |
| "Optimization, hyperparameter tuning, or more simply figuring out which prompts, LLM, and parameters to use is an exercise that can be parallelized. With this functionality, you can run different sets of parameters / inputs over different Burr applications. You can then use the results to make more informed decisions about which parameters or conditions lead to better or worse outcomes.\n", |
| "\n", |
| "Below we sketch how you could use `MapActionsAndStates` to achieve this with actions (which are trivial to swap out for whole Burr graphs)." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 17, |
| "id": "66f78c91-09db-49df-9c20-6f084783dc7a", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "image/svg+xml": [ |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", |
| "<!-- Generated by graphviz version 12.0.0 (20240704.0754)\n", |
| " -->\n", |
| "<!-- Pages: 1 -->\n", |
| "<svg width=\"250pt\" height=\"180pt\"\n", |
| " viewBox=\"0.00 0.00 249.93 179.80\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 175.8)\">\n", |
| "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-175.8 245.93,-175.8 245.93,4 -4,4\"/>\n", |
| "<!-- create_prompts -->\n", |
| "<g id=\"node1\" class=\"node\">\n", |
| "<title>create_prompts</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M185.48,-171.8C185.48,-171.8 93.38,-171.8 93.38,-171.8 87.38,-171.8 81.38,-165.8 81.38,-159.8 81.38,-159.8 81.38,-147.2 81.38,-147.2 81.38,-141.2 87.38,-135.2 93.38,-135.2 93.38,-135.2 185.48,-135.2 185.48,-135.2 191.48,-135.2 197.48,-141.2 197.48,-147.2 197.48,-147.2 197.48,-159.8 197.48,-159.8 197.48,-165.8 191.48,-171.8 185.48,-171.8\"/>\n", |
| "<text text-anchor=\"middle\" x=\"139.43\" y=\"-147.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">create_prompts</text>\n", |
| "</g>\n", |
| "<!-- cartesian_product -->\n", |
| "<g id=\"node2\" class=\"node\">\n", |
| "<title>cartesian_product</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M116.85,-104.2C116.85,-104.2 12,-104.2 12,-104.2 6,-104.2 0,-98.2 0,-92.2 0,-92.2 0,-79.6 0,-79.6 0,-73.6 6,-67.6 12,-67.6 12,-67.6 116.85,-67.6 116.85,-67.6 122.85,-67.6 128.85,-73.6 128.85,-79.6 128.85,-79.6 128.85,-92.2 128.85,-92.2 128.85,-98.2 122.85,-104.2 116.85,-104.2\"/>\n", |
| "<text text-anchor=\"middle\" x=\"64.43\" y=\"-80.1\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">cartesian_product</text>\n", |
| "</g>\n", |
| "<!-- create_prompts->cartesian_product -->\n", |
| "<g id=\"edge1\" class=\"edge\">\n", |
| "<title>create_prompts->cartesian_product</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M119.33,-134.92C111.26,-127.86 101.8,-119.59 93.04,-111.93\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"95.55,-109.47 85.72,-105.52 90.94,-114.74 95.55,-109.47\"/>\n", |
| "</g>\n", |
| "<!-- evaluate_responses -->\n", |
| "<g id=\"node3\" class=\"node\">\n", |
| "<title>evaluate_responses</title>\n", |
| "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M173.6,-36.6C173.6,-36.6 55.25,-36.6 55.25,-36.6 49.25,-36.6 43.25,-30.6 43.25,-24.6 43.25,-24.6 43.25,-12 43.25,-12 43.25,-6 49.25,0 55.25,0 55.25,0 173.6,0 173.6,0 179.6,0 185.6,-6 185.6,-12 185.6,-12 185.6,-24.6 185.6,-24.6 185.6,-30.6 179.6,-36.6 173.6,-36.6\"/>\n", |
| "<text text-anchor=\"middle\" x=\"114.43\" y=\"-12.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">evaluate_responses</text>\n", |
| "</g>\n", |
| "<!-- cartesian_product->evaluate_responses -->\n", |
| "<g id=\"edge2\" class=\"edge\">\n", |
| "<title>cartesian_product->evaluate_responses</title>\n", |
| "<path fill=\"none\" stroke=\"black\" d=\"M77.82,-67.32C82.87,-60.7 88.72,-53.02 94.25,-45.77\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"96.79,-48.21 100.07,-38.13 91.23,-43.96 96.79,-48.21\"/>\n", |
| "</g>\n", |
| "<!-- evaluate_responses->create_prompts -->\n", |
| "<g id=\"edge3\" class=\"edge\">\n", |
| "<title>evaluate_responses->create_prompts</title>\n", |
| "<path fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" d=\"M125.05,-36.98C129.8,-45.84 134.88,-56.94 137.43,-67.6 141.8,-85.95 142.3,-107.16 141.71,-123.8\"/>\n", |
| "<polygon fill=\"black\" stroke=\"black\" points=\"138.23,-123.31 141.2,-133.48 145.22,-123.68 138.23,-123.31\"/>\n", |
| "<text text-anchor=\"middle\" x=\"191.68\" y=\"-80.85\" font-family=\"Times,serif\" font-size=\"14.00\">should_loop=True</text>\n", |
| "</g>\n", |
| "</g>\n", |
| "</svg>\n" |
| ], |
| "text/plain": [ |
| "<burr.core.application.Application at 0x12a422f20>" |
| ] |
| }, |
| "execution_count": 17, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "from burr.core import action, state, when\n", |
| "from burr.core.parallelism import MapActionsAndStates, RunnableGraph\n", |
| "from typing import Callable, Generator, List\n", |
| "import openai\n", |
| "import anthropic\n", |
| "\n", |
| "@action(reads=[\"prompt\"], writes=[\"llm_output\"])\n", |
| "def query_llm_openai(state: State) -> State:\n", |
| " \"\"\"openai action\"\"\"\n", |
| " prompt = state[\"prompt\"]\n", |
| " client = openai.Client()\n", |
| " response = client.chat.completions.create(\n", |
| "        model=\"gpt-4o-mini\",\n", |
| " messages=[{\"role\": \"user\", \"content\": prompt}]\n", |
| " )\n", |
| " llm_output = response.choices[0].message.content\n", |
| " return state.update(llm_output=(llm_output, \"openai\"))\n", |
| "\n", |
| "@action(reads=[\"prompt\"], writes=[\"llm_output\"])\n", |
| "def query_llm_claude(state: State) -> State:\n", |
| " \"\"\"claude action\"\"\"\n", |
| " prompt = state[\"prompt\"]\n", |
| " client = anthropic.Anthropic(\n", |
| " # defaults to os.environ.get(\"ANTHROPIC_API_KEY\")\n", |
| " api_key=\"my_api_key\",\n", |
| " )\n", |
| " message = client.messages.create(\n", |
| " model=\"claude-3-5-sonnet-20241022\",\n", |
| " max_tokens=1024,\n", |
| " messages=[{\"role\": \"user\", \"content\": prompt}]\n", |
| " )\n", |
| "    # message.content is a list of content blocks; take the text of the first one\n", |
| "    return state.update(llm_output=(message.content[0].text, \"anthropic\"))\n", |
| "\n", |
| "# def query_llm_ollama(...)\n", |
| "\n", |
| "class TestModelsOverPromptsAction(MapActionsAndStates):\n", |
| "    \"\"\"This is our action that will create a cartesian product over the states and actions returned.\n", |
| "\n", |
| "    I.e. (number of states) x (number of actions) Burr sub-applications will be run in parallel.\n", |
| " \"\"\"\n", |
| "\n", |
| " def actions(self, state: State, context: ApplicationContext, inputs: Dict[str, Any]) -> Generator[Action | Callable | RunnableGraph, None, None]:\n", |
| " # make sure to add a name to the action\n", |
| " # This is not necessary for subgraphs, as actions will already have names\n", |
| " for action in [\n", |
| " query_llm_openai.with_name(\"openai\"),\n", |
| " query_llm_claude.with_name(\"claude\"),\n", |
| " #query_llm_ollama.with_name(\"ollama\"),\n", |
| " ]:\n", |
| " yield action\n", |
| "\n", |
| " def states(self, state: State, context: ApplicationContext, inputs: Dict[str, Any]) -> Generator[State, None, None]:\n", |
| " for prompt in state[\"prompts\"]:\n", |
| " yield state.update(prompt=prompt)\n", |
| "\n", |
| " def reduce(self, state: State, states: Generator[State, None, None]) -> State:\n", |
| " all_llm_outputs = []\n", |
| " for sub_state in states:\n", |
| " all_llm_outputs.append(\n", |
| " {\n", |
| " \"output\" : sub_state[\"llm_output\"][0],\n", |
| " \"provider\" : sub_state[\"llm_output\"][1],\n", |
| " \"prompt\" : sub_state[\"prompt\"],\n", |
| " }\n", |
| " )\n", |
| " return state.update(all_llm_outputs=all_llm_outputs)\n", |
| "\n", |
| " @property\n", |
| " def reads(self) -> List[str]:\n", |
| " return [\"prompts\"]\n", |
| "\n", |
| " @property\n", |
| " def writes(self) -> List[str]:\n", |
| " return [\"all_llm_outputs\"]\n", |
| "\n", |
| "@action(reads=[\"best_one\"], writes=[\"prompts\"])\n", |
| "def create_prompts(state: State) -> State:\n", |
| " \"\"\"action that comes up with prompts - here we hard code for the example, but it could be dynamic.\"\"\"\n", |
| " prompts = [\n", |
| " \"What is the meaning of life?\",\n", |
| " \"What is the airspeed velocity of an unladen swallow?\",\n", |
| " \"What is the best way to cook a steak?\",\n", |
| " ]\n", |
| " return state.update(prompts=prompts)\n", |
| "\n", |
| "@action(reads=[\"all_llm_outputs\"], writes=[\"best_one\", \"should_loop\"])\n", |
| "def evaluate_responses(state: State) -> State:\n", |
| " # some code here...\n", |
| " return state.update(best_one=..., should_loop=...)\n", |
| " \n", |
| "\n", |
| "map_actions_and_state_app = (\n", |
| "    ApplicationBuilder()\n", |
| "    .with_actions(\n", |
| "        create_prompts=create_prompts,\n", |
| "        cartesian_product=TestModelsOverPromptsAction(),\n", |
| "        evaluate_responses=evaluate_responses,\n", |
| "    )\n", |
| "    .with_entrypoint(\"create_prompts\")\n", |
| "    .with_tracker(project=\"parallelism_mapstateandactions\", use_otel_tracing=True)\n", |
| "    .with_transitions(\n", |
| "        (\"create_prompts\", \"cartesian_product\"),\n", |
| "        (\"cartesian_product\", \"evaluate_responses\"),\n", |
| "        (\"evaluate_responses\", \"create_prompts\", when(should_loop=True)),\n", |
| "    )\n", |
| "    .build()\n", |
| ")\n", |
| "map_actions_and_state_app\n", |
| "\n", |
| "# To run it you'd do something like:\n", |
| "# action, result, state = map_actions_and_state_app.run(\n", |
| "#     halt_after=[\"evaluate_responses\"], inputs={...}\n", |
| "# )" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "c9444365-7c84-473a-8f9f-fd20e9863d94", |
| "metadata": {}, |
| "source": [ |
| "# To close\n", |
| "We're excited by what you can now model and, importantly, observe and iterate on with Burr and the Burr UI (e.g. see [test-case creation](https://burr.dagworks.io/examples/guardrails/creating_tests/), [time-travel](https://blog.dagworks.io/p/travel-back-in-time-with-burr), or [annotation](https://blog.dagworks.io/p/annotating-data-in-burr?r=2cg5z1&utm_campaign=post&utm_medium=web)).\n", |
| "\n", |
| "We have an active roadmap planned (we're looking for contributors!). If you like what you see, or have thoughts or questions, please drop by our [Discord community](https://discord.gg/6Zy2DwP4f3)." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "67436e78-82c7-4a62-84c0-d81dcd53e428", |
| "metadata": {}, |
| "outputs": [], |
| "source": [] |
| } |
| ], |
| "metadata": { |
| "kernelspec": { |
| "display_name": "Python 3 (ipykernel)", |
| "language": "python", |
| "name": "python3" |
| }, |
| "language_info": { |
| "codemirror_mode": { |
| "name": "ipython", |
| "version": 3 |
| }, |
| "file_extension": ".py", |
| "mimetype": "text/x-python", |
| "name": "python", |
| "nbconvert_exporter": "python", |
| "pygments_lexer": "ipython3", |
| "version": "3.10.4" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 5 |
| } |