blob: 016e2f30aad387d8d0c359f9ef6a38282cc59c34 [file] [log] [blame]
{
"cells": [
{
"cell_type": "markdown",
"id": "1622e1563a35aa32",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"# Conversational RAG Example"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9069a73d207fd136",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"!pip install burr"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "initial_id",
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-26T20:47:37.515883Z",
"start_time": "2024-03-26T20:47:35.314767Z"
}
},
"outputs": [],
"source": [
"# Importing the necessary libraries\n",
"import pprint\n",
"from typing import Tuple\n",
"from hamilton import dataflows, driver\n",
"import burr.core\n",
"from burr.core import ApplicationBuilder, State, default, expr\n",
"from burr.core.action import action\n",
"from application import PrintStepHook # local import from application.py\n",
"from burr.tracking import LocalTrackingClient"
]
},
{
"cell_type": "markdown",
"id": "8f1d578469a918b1",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"# Load your \"chain\" or conversational RAG \"pipeline\"\n",
"\n",
"We use Hamilton here, but you could use LangChain, etc., or forgo them and write your own code."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "c6a018aff1154f0b",
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-26T20:47:42.057246Z",
"start_time": "2024-03-26T20:47:40.500075Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 9.0.0 (20230911.1827)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"1051pt\" height=\"352pt\"\n",
" viewBox=\"0.00 0.00 1050.60 352.30\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 348.3)\">\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-348.3 1046.6,-348.3 1046.6,4 -4,4\"/>\n",
"<g id=\"clust1\" class=\"cluster\">\n",
"<title>cluster__legend</title>\n",
"<polygon fill=\"#ffffff\" stroke=\"black\" points=\"19.38,-206.3 19.38,-336.3 104.23,-336.3 104.23,-206.3 19.38,-206.3\"/>\n",
"<text text-anchor=\"middle\" x=\"61.8\" y=\"-319\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Legend</text>\n",
"</g>\n",
"<!-- context -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>context</title>\n",
"<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M636.15,-137.1C636.15,-137.1 590.55,-137.1 590.55,-137.1 584.55,-137.1 578.55,-131.1 578.55,-125.1 578.55,-125.1 578.55,-85.5 578.55,-85.5 578.55,-79.5 584.55,-73.5 590.55,-73.5 590.55,-73.5 636.15,-73.5 636.15,-73.5 642.15,-73.5 648.15,-79.5 648.15,-85.5 648.15,-85.5 648.15,-125.1 648.15,-125.1 648.15,-131.1 642.15,-137.1 636.15,-137.1\"/>\n",
"<text text-anchor=\"start\" x=\"589.35\" y=\"-114\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">context</text>\n",
"<text text-anchor=\"start\" x=\"605.85\" y=\"-86\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">str</text>\n",
"</g>\n",
"<!-- answer_prompt -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>answer_prompt</title>\n",
"<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M788.75,-208.1C788.75,-208.1 689.15,-208.1 689.15,-208.1 683.15,-208.1 677.15,-202.1 677.15,-196.1 677.15,-196.1 677.15,-156.5 677.15,-156.5 677.15,-150.5 683.15,-144.5 689.15,-144.5 689.15,-144.5 788.75,-144.5 788.75,-144.5 794.75,-144.5 800.75,-150.5 800.75,-156.5 800.75,-156.5 800.75,-196.1 800.75,-196.1 800.75,-202.1 794.75,-208.1 788.75,-208.1\"/>\n",
"<text text-anchor=\"start\" x=\"687.95\" y=\"-185\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">answer_prompt</text>\n",
"<text text-anchor=\"start\" x=\"731.45\" y=\"-157\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">str</text>\n",
"</g>\n",
"<!-- context&#45;&gt;answer_prompt -->\n",
"<g id=\"edge8\" class=\"edge\">\n",
"<title>context&#45;&gt;answer_prompt</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M648.45,-124.89C655.95,-129.2 664.13,-133.9 672.41,-138.65\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"670.4,-141.53 680.81,-143.48 673.89,-135.46 670.4,-141.53\"/>\n",
"</g>\n",
"<!-- conversational_rag_response -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>conversational_rag_response</title>\n",
"<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1030.6,-238.1C1030.6,-238.1 841.75,-238.1 841.75,-238.1 835.75,-238.1 829.75,-232.1 829.75,-226.1 829.75,-226.1 829.75,-186.5 829.75,-186.5 829.75,-180.5 835.75,-174.5 841.75,-174.5 841.75,-174.5 1030.6,-174.5 1030.6,-174.5 1036.6,-174.5 1042.6,-180.5 1042.6,-186.5 1042.6,-186.5 1042.6,-226.1 1042.6,-226.1 1042.6,-232.1 1036.6,-238.1 1030.6,-238.1\"/>\n",
"<text text-anchor=\"start\" x=\"840.55\" y=\"-215\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">conversational_rag_response</text>\n",
"<text text-anchor=\"start\" x=\"928.68\" y=\"-187\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">str</text>\n",
"</g>\n",
"<!-- standalone_question -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>standalone_question</title>\n",
"<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M537.55,-208.1C537.55,-208.1 404.95,-208.1 404.95,-208.1 398.95,-208.1 392.95,-202.1 392.95,-196.1 392.95,-196.1 392.95,-156.5 392.95,-156.5 392.95,-150.5 398.95,-144.5 404.95,-144.5 404.95,-144.5 537.55,-144.5 537.55,-144.5 543.55,-144.5 549.55,-150.5 549.55,-156.5 549.55,-156.5 549.55,-196.1 549.55,-196.1 549.55,-202.1 543.55,-208.1 537.55,-208.1\"/>\n",
"<text text-anchor=\"start\" x=\"403.75\" y=\"-185\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">standalone_question</text>\n",
"<text text-anchor=\"start\" x=\"463.75\" y=\"-157\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">str</text>\n",
"</g>\n",
"<!-- standalone_question&#45;&gt;context -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>standalone_question&#45;&gt;context</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M535.74,-144.12C546.71,-138.57 557.89,-132.9 568.21,-127.67\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"569.5,-130.94 576.84,-123.3 566.34,-124.69 569.5,-130.94\"/>\n",
"</g>\n",
"<!-- standalone_question&#45;&gt;answer_prompt -->\n",
"<g id=\"edge9\" class=\"edge\">\n",
"<title>standalone_question&#45;&gt;answer_prompt</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M549.82,-176.3C586.3,-176.3 629.57,-176.3 665.4,-176.3\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"665.24,-179.8 675.24,-176.3 665.24,-172.8 665.24,-179.8\"/>\n",
"</g>\n",
"<!-- llm_client -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>llm_client</title>\n",
"<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M288.58,-277.1C288.58,-277.1 227.98,-277.1 227.98,-277.1 221.98,-277.1 215.98,-271.1 215.98,-265.1 215.98,-265.1 215.98,-225.5 215.98,-225.5 215.98,-219.5 221.98,-213.5 227.98,-213.5 227.98,-213.5 288.58,-213.5 288.58,-213.5 294.58,-213.5 300.58,-219.5 300.58,-225.5 300.58,-225.5 300.58,-265.1 300.58,-265.1 300.58,-271.1 294.58,-277.1 288.58,-277.1\"/>\n",
"<text text-anchor=\"start\" x=\"226.78\" y=\"-254\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">llm_client</text>\n",
"<text text-anchor=\"start\" x=\"235.03\" y=\"-226\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">OpenAI</text>\n",
"</g>\n",
"<!-- llm_client&#45;&gt;conversational_rag_response -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>llm_client&#45;&gt;conversational_rag_response</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M300.74,-243.58C391.45,-239.73 614.31,-229.75 800.75,-217.3 806.44,-216.92 812.27,-216.51 818.16,-216.09\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"818.31,-219.59 828.02,-215.36 817.79,-212.6 818.31,-219.59\"/>\n",
"</g>\n",
"<!-- llm_client&#45;&gt;standalone_question -->\n",
"<g id=\"edge7\" class=\"edge\">\n",
"<title>llm_client&#45;&gt;standalone_question</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M300.96,-231.67C324.17,-224.08 353.99,-214.32 382.11,-205.13\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"382.94,-208.54 391.35,-202.1 380.76,-201.89 382.94,-208.54\"/>\n",
"</g>\n",
"<!-- answer_prompt&#45;&gt;conversational_rag_response -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>answer_prompt&#45;&gt;conversational_rag_response</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M801.18,-185.71C806.7,-186.56 812.4,-187.43 818.2,-188.32\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"817.53,-191.76 827.94,-189.82 818.59,-184.84 817.53,-191.76\"/>\n",
"</g>\n",
"<!-- standalone_question_prompt -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>standalone_question_prompt</title>\n",
"<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M351.95,-195.1C351.95,-195.1 164.6,-195.1 164.6,-195.1 158.6,-195.1 152.6,-189.1 152.6,-183.1 152.6,-183.1 152.6,-143.5 152.6,-143.5 152.6,-137.5 158.6,-131.5 164.6,-131.5 164.6,-131.5 351.95,-131.5 351.95,-131.5 357.95,-131.5 363.95,-137.5 363.95,-143.5 363.95,-143.5 363.95,-183.1 363.95,-183.1 363.95,-189.1 357.95,-195.1 351.95,-195.1\"/>\n",
"<text text-anchor=\"start\" x=\"163.4\" y=\"-172\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">standalone_question_prompt</text>\n",
"<text text-anchor=\"start\" x=\"250.78\" y=\"-144\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">str</text>\n",
"</g>\n",
"<!-- standalone_question_prompt&#45;&gt;standalone_question -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>standalone_question_prompt&#45;&gt;standalone_question</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M364.14,-169.76C369.85,-170.11 375.57,-170.47 381.22,-170.81\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"380.99,-174.31 391.19,-171.43 381.42,-167.32 380.99,-174.31\"/>\n",
"</g>\n",
"<!-- vector_store -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>vector_store</title>\n",
"<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M533.05,-126.1C533.05,-126.1 409.45,-126.1 409.45,-126.1 403.45,-126.1 397.45,-120.1 397.45,-114.1 397.45,-114.1 397.45,-74.5 397.45,-74.5 397.45,-68.5 403.45,-62.5 409.45,-62.5 409.45,-62.5 533.05,-62.5 533.05,-62.5 539.05,-62.5 545.05,-68.5 545.05,-74.5 545.05,-74.5 545.05,-114.1 545.05,-114.1 545.05,-120.1 539.05,-126.1 533.05,-126.1\"/>\n",
"<text text-anchor=\"start\" x=\"430.75\" y=\"-103\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">vector_store</text>\n",
"<text text-anchor=\"start\" x=\"408.25\" y=\"-75\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">VectorStoreRetriever</text>\n",
"</g>\n",
"<!-- vector_store&#45;&gt;context -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>vector_store&#45;&gt;context</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M545.17,-100.03C552.58,-100.61 559.95,-101.19 566.94,-101.74\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"566.55,-105.21 576.79,-102.51 567.1,-98.24 566.55,-105.21\"/>\n",
"</g>\n",
"<!-- _context_inputs -->\n",
"<g id=\"node8\" class=\"node\">\n",
"<title>_context_inputs</title>\n",
"<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"512.68,-44.6 429.83,-44.6 429.83,0 512.68,0 512.68,-44.6\"/>\n",
"<text text-anchor=\"start\" x=\"444.63\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">top_k</text>\n",
"<text text-anchor=\"start\" x=\"483.63\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
"</g>\n",
"<!-- _context_inputs&#45;&gt;context -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>_context_inputs&#45;&gt;context</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M513.1,-35.95C525.2,-40.71 538.24,-46.57 549.55,-53.3 556.33,-57.33 563.12,-62.07 569.62,-67.01\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"567.05,-69.44 577.06,-72.9 571.39,-63.95 567.05,-69.44\"/>\n",
"</g>\n",
"<!-- _standalone_question_prompt_inputs -->\n",
"<g id=\"node9\" class=\"node\">\n",
"<title>_standalone_question_prompt_inputs</title>\n",
"<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"123.6,-196.1 0,-196.1 0,-130.5 123.6,-130.5 123.6,-196.1\"/>\n",
"<text text-anchor=\"start\" x=\"25.3\" y=\"-168\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">question</text>\n",
"<text text-anchor=\"start\" x=\"93.3\" y=\"-168\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">str</text>\n",
"<text text-anchor=\"start\" x=\"14.43\" y=\"-147\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">chat_history</text>\n",
"<text text-anchor=\"start\" x=\"92.55\" y=\"-147\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">list</text>\n",
"</g>\n",
"<!-- _standalone_question_prompt_inputs&#45;&gt;standalone_question_prompt -->\n",
"<g id=\"edge10\" class=\"edge\">\n",
"<title>_standalone_question_prompt_inputs&#45;&gt;standalone_question_prompt</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M124.07,-163.3C129.56,-163.3 135.24,-163.3 141,-163.3\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"140.67,-166.8 150.67,-163.3 140.67,-159.8 140.67,-166.8\"/>\n",
"</g>\n",
"<!-- _vector_store_inputs -->\n",
"<g id=\"node10\" class=\"node\">\n",
"<title>_vector_store_inputs</title>\n",
"<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"316.95,-113.6 199.6,-113.6 199.6,-69 316.95,-69 316.95,-113.6\"/>\n",
"<text text-anchor=\"start\" x=\"214.4\" y=\"-85.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input_texts</text>\n",
"<text text-anchor=\"start\" x=\"285.65\" y=\"-85.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">list</text>\n",
"</g>\n",
"<!-- _vector_store_inputs&#45;&gt;vector_store -->\n",
"<g id=\"edge11\" class=\"edge\">\n",
"<title>_vector_store_inputs&#45;&gt;vector_store</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M317.2,-92.12C338.28,-92.42 362.57,-92.77 385.49,-93.09\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"385.41,-96.59 395.46,-93.24 385.51,-89.59 385.41,-96.59\"/>\n",
"</g>\n",
"<!-- input -->\n",
"<g id=\"node11\" class=\"node\">\n",
"<title>input</title>\n",
"<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"88.8,-305.6 34.8,-305.6 34.8,-269 88.8,-269 88.8,-305.6\"/>\n",
"<text text-anchor=\"middle\" x=\"61.8\" y=\"-281.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input</text>\n",
"</g>\n",
"<!-- function -->\n",
"<g id=\"node12\" class=\"node\">\n",
"<title>function</title>\n",
"<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M84.23,-250.6C84.23,-250.6 39.38,-250.6 39.38,-250.6 33.38,-250.6 27.38,-244.6 27.38,-238.6 27.38,-238.6 27.38,-226 27.38,-226 27.38,-220 33.38,-214 39.38,-214 39.38,-214 84.23,-214 84.23,-214 90.23,-214 96.23,-220 96.23,-226 96.23,-226 96.23,-238.6 96.23,-238.6 96.23,-244.6 90.23,-250.6 84.23,-250.6\"/>\n",
"<text text-anchor=\"middle\" x=\"61.8\" y=\"-226.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">function</text>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x12ad5de40>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Loads Hamilton DAG\n",
"conversational_rag = dataflows.import_module(\"conversational_rag\")\n",
"conversational_rag_driver = (\n",
" driver.Builder()\n",
" .with_config({}) # replace with configuration as appropriate\n",
" .with_modules(conversational_rag)\n",
" .build()\n",
")\n",
"conversational_rag_driver.display_all_functions()"
]
},
{
"cell_type": "markdown",
"id": "82b3515afd2de6e4",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"# Create the actions that will constitute our application\n",
"\n",
"We will use the functional (vs. class-based) approach to declaring actions here."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "6433dad5abc6eb16",
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-26T20:47:42.065785Z",
"start_time": "2024-03-26T20:47:42.059539Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"@action(\n",
" reads=[\"question\", \"chat_history\"],\n",
" writes=[\"chat_history\"],\n",
")\n",
"def ai_converse(state: State, vector_store: object) -> Tuple[dict, State]:\n",
" \"\"\"AI conversing step. Uses Hamilton to execute the conversational pipeline.\"\"\"\n",
" result = conversational_rag_driver.execute(\n",
" [\"conversational_rag_response\"],\n",
" inputs={\n",
" \"question\": state[\"question\"],\n",
" \"chat_history\": state[\"chat_history\"],\n",
" },\n",
" # we use overrides here because we want to pass in the vector store\n",
" overrides={\n",
" \"vector_store\": vector_store,\n",
" }\n",
" )\n",
" new_history = f\"AI: {result['conversational_rag_response']}\"\n",
" return result, state.append(chat_history=new_history)\n",
"\n",
"\n",
"@action(\n",
" reads=[],\n",
" writes=[\"question\", \"chat_history\"],\n",
")\n",
"def human_converse(state: State, user_question: str) -> Tuple[dict, State]:\n",
" \"\"\"Human converse step -- make sure we get input, and store it as state.\"\"\"\n",
" state = state.update(question=user_question).append(chat_history=f\"Human: {user_question}\")\n",
" return {\"question\": user_question}, state"
]
},
{
"cell_type": "markdown",
"id": "9579b47aac2c53a0",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"# Create the application\n",
"\n",
"We now create the application, which is a collection of actions, and then set the transitions between the actions based on values in State.\n",
"\n",
"We also set the initial values (and bootstrap the vector store) used to populate the application's state."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "77e9f67b660a0953",
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-26T20:50:25.642660Z",
"start_time": "2024-03-26T20:50:25.616245Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"# what we will do RAG over.\n",
"initial_documents = [\n",
" \"harrison worked at kensho\",\n",
" \"stefan worked at Stitch Fix\",\n",
" \"stefan likes tacos\",\n",
" \"elijah worked at TwoSigma\",\n",
" \"elijah likes mango\",\n",
" \"stefan used to work at IBM\",\n",
" \"elijah likes to go biking\",\n",
" \"stefan likes to bake sourdough\",\n",
"]\n",
"# bootstrap the vector store;\n",
"vector_store = conversational_rag_driver.execute(\n",
" [\"vector_store\"],\n",
" inputs={\"input_texts\": initial_documents})[\"vector_store\"]\n",
"# what we will initialize the application with\n",
"initial_state = {\n",
" \"question\": \"\",\n",
" \"chat_history\": [],\n",
"}\n",
"import uuid\n",
"app_id = str(uuid.uuid4())\n",
"app = (\n",
" ApplicationBuilder()\n",
" # add the actions\n",
" .with_actions(\n",
" # bind the vector store to the AI conversational step\n",
" ai_converse=ai_converse.bind(vector_store=vector_store),\n",
" human_converse=human_converse,\n",
" terminal=burr.core.Result(\"chat_history\"),\n",
" )\n",
" # set the transitions between actions\n",
" .with_transitions(\n",
" (\"ai_converse\", \"human_converse\", default),\n",
" (\"human_converse\", \"terminal\", expr(\"'exit' in question\")),\n",
" (\"human_converse\", \"ai_converse\", default),\n",
" )\n",
" # add identifiers that will help track the application\n",
" .with_identifiers(app_id=app_id, partition_key=\"sample_user\")\n",
" # initialize the state\n",
" .with_state(**initial_state)\n",
" # say what the initial action is\n",
" .with_entrypoint(\"human_converse\")\n",
" # add a hook to print the steps -- optional but shows that Burr is pluggable\n",
" .with_hooks(PrintStepHook())\n",
" # add tracking -- this will show up in the BURR UI.\n",
" .with_tracker(project=\"demo:conversational-rag\")\n",
" # build the application\n",
" .build()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "bf5d1e084a791fa5",
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-26T20:48:39.379712Z",
"start_time": "2024-03-26T20:48:38.701659Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 9.0.0 (20230911.1827)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"292pt\" height=\"194pt\"\n",
" viewBox=\"0.00 0.00 292.28 193.50\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 189.5)\">\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-189.5 288.28,-189.5 288.28,4 -4,4\"/>\n",
"<!-- ai_converse -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>ai_converse</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M82.38,-185.5C82.38,-185.5 25.88,-185.5 25.88,-185.5 19.88,-185.5 13.88,-179.5 13.88,-173.5 13.88,-173.5 13.88,-161.5 13.88,-161.5 13.88,-155.5 19.88,-149.5 25.88,-149.5 25.88,-149.5 82.38,-149.5 82.38,-149.5 88.38,-149.5 94.38,-155.5 94.38,-161.5 94.38,-161.5 94.38,-173.5 94.38,-173.5 94.38,-179.5 88.38,-185.5 82.38,-185.5\"/>\n",
"<text text-anchor=\"middle\" x=\"54.12\" y=\"-162.45\" font-family=\"Times,serif\" font-size=\"14.00\">ai_converse</text>\n",
"</g>\n",
"<!-- human_converse -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>human_converse</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M96.25,-118.5C96.25,-118.5 12,-118.5 12,-118.5 6,-118.5 0,-112.5 0,-106.5 0,-106.5 0,-94.5 0,-94.5 0,-88.5 6,-82.5 12,-82.5 12,-82.5 96.25,-82.5 96.25,-82.5 102.25,-82.5 108.25,-88.5 108.25,-94.5 108.25,-94.5 108.25,-106.5 108.25,-106.5 108.25,-112.5 102.25,-118.5 96.25,-118.5\"/>\n",
"<text text-anchor=\"middle\" x=\"54.12\" y=\"-95.45\" font-family=\"Times,serif\" font-size=\"14.00\">human_converse</text>\n",
"</g>\n",
"<!-- ai_converse&#45;&gt;human_converse -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>ai_converse&#45;&gt;human_converse</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M48.07,-149.08C47.49,-143.25 47.27,-136.59 47.42,-130.14\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"50.89,-130.66 47.98,-120.48 43.9,-130.26 50.89,-130.66\"/>\n",
"</g>\n",
"<!-- human_converse&#45;&gt;ai_converse -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>human_converse&#45;&gt;ai_converse</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M60.19,-118.97C60.76,-124.8 60.98,-131.46 60.83,-137.92\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"57.36,-137.38 60.27,-147.57 64.34,-137.79 57.36,-137.38\"/>\n",
"</g>\n",
"<!-- terminal -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>terminal</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M73,-36C73,-36 35.25,-36 35.25,-36 29.25,-36 23.25,-30 23.25,-24 23.25,-24 23.25,-12 23.25,-12 23.25,-6 29.25,0 35.25,0 35.25,0 73,0 73,0 79,0 85,-6 85,-12 85,-12 85,-24 85,-24 85,-30 79,-36 73,-36\"/>\n",
"<text text-anchor=\"middle\" x=\"54.12\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">terminal</text>\n",
"</g>\n",
"<!-- human_converse&#45;&gt;terminal -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>human_converse&#45;&gt;terminal</title>\n",
"<path fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" d=\"M54.12,-82.03C54.12,-72.01 54.12,-59.18 54.12,-47.71\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"57.63,-47.95 54.13,-37.95 50.63,-47.95 57.63,-47.95\"/>\n",
"<text text-anchor=\"middle\" x=\"98.38\" y=\"-54.2\" font-family=\"Times,serif\" font-size=\"14.00\">&#39;exit&#39; in question</text>\n",
"</g>\n",
"<!-- input__user_question -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>input__user_question</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"198.12\" cy=\"-167.5\" rx=\"86.15\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"198.12\" y=\"-162.45\" font-family=\"Times,serif\" font-size=\"14.00\">input: user_question</text>\n",
"</g>\n",
"<!-- input__user_question&#45;&gt;human_converse -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>input__user_question&#45;&gt;human_converse</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M162.9,-150.6C144.9,-142.48 122.71,-132.46 103.08,-123.6\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"104.72,-120.5 94.16,-119.57 101.84,-126.88 104.72,-120.5\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x12ad5e2f0>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# let's visualize what we have\n",
"app.visualize(include_conditions=True)"
]
},
{
"cell_type": "markdown",
"id": "430bab287b6ad9a",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"# Let's run the app\n",
"\n",
"Let's run it one step at a time."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8bcfe9ca48f87618",
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-26T20:50:38.782857Z",
"start_time": "2024-03-26T20:50:28.797204Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"⏳Processing input from user...\n",
"🎙💬 Who is Stefan? Please answer in English. \n",
"\n",
"Ran action human_converse with result:\n",
" {'question': 'Who is Stefan? Please answer in English.'} \n",
" and state:\n",
" {'__PRIOR_STEP': 'human_converse',\n",
" '__SEQUENCE_ID': 0,\n",
" 'chat_history': ['Human: Who is Stefan? Please answer in English.'],\n",
" 'input_texts': ['harrison worked at kensho',\n",
" 'stefan worked at Stitch Fix',\n",
" 'stefan likes tacos',\n",
" 'elijah worked at TwoSigma',\n",
" 'elijah likes mango',\n",
" 'stefan used to work at IBM',\n",
" 'elijah likes to go biking',\n",
" 'stefan likes to bake sourdough'],\n",
" 'question': 'Who is Stefan? Please answer in English.'}\n"
]
}
],
"source": [
"app.reset_to_entrypoint() # reset the app to the entrypoint\n",
"user_question = input(\"Ask something (or type exit to quit): \")\n",
"previous_action, result, state = app.step(\n",
" inputs={\"user_question\": user_question},\n",
")\n",
"print(f\"Ran action {previous_action.name} with result:\\n {pprint.pformat(result)} \\n and state:\\n {pprint.pformat(state.get_all())}\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "81940578d58fd602",
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-26T20:50:44.662755Z",
"start_time": "2024-03-26T20:50:41.919782Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🤔 AI is thinking...\n",
"🤖💬 Stefan is a person who used to work at IBM, worked at Stitch Fix, likes tacos, and likes to bake sourdough. \n",
"\n",
"Ran action ai_converse with result:\n",
" {'conversational_rag_response': 'Stefan is a person who used to work at IBM, '\n",
" 'worked at Stitch Fix, likes tacos, and likes '\n",
" 'to bake sourdough.'} \n",
" and state:\n",
" {'__PRIOR_STEP': 'ai_converse',\n",
" '__SEQUENCE_ID': 1,\n",
" 'chat_history': ['Human: Who is Stefan? Please answer in English.',\n",
" 'AI: Stefan is a person who used to work at IBM, worked at '\n",
" 'Stitch Fix, likes tacos, and likes to bake sourdough.'],\n",
" 'input_texts': ['harrison worked at kensho',\n",
" 'stefan worked at Stitch Fix',\n",
" 'stefan likes tacos',\n",
" 'elijah worked at TwoSigma',\n",
" 'elijah likes mango',\n",
" 'stefan used to work at IBM',\n",
" 'elijah likes to go biking',\n",
" 'stefan likes to bake sourdough'],\n",
" 'question': 'Who is Stefan? Please answer in English.'}\n"
]
}
],
"source": [
"# now let's run the AI conversational step\n",
"previous_action, result, state = app.step()\n",
"print(f\"Ran action {previous_action.name} with result:\\n {pprint.pformat(result)} \\n and state:\\n {pprint.pformat(state.get_all())}\")"
]
},
{
"cell_type": "markdown",
"id": "36ec2f4908c2dde2",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"# Let's now run the app to completion\n",
"\n",
"You could repeat the above for each action, or you could tell the app to run until certain\n",
"actions/conditions are met."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "be6c573158b65cb1",
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-26T20:52:21.364028Z",
"start_time": "2024-03-26T20:50:52.382808Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running RAG with initial state:\n",
" {'__PRIOR_STEP': 'ai_converse',\n",
" '__SEQUENCE_ID': 1,\n",
" 'chat_history': ['Human: Who is Stefan? Please answer in English.',\n",
" 'AI: Stefan is a person who used to work at IBM, worked at '\n",
" 'Stitch Fix, likes tacos, and likes to bake sourdough.'],\n",
" 'input_texts': ['harrison worked at kensho',\n",
" 'stefan worked at Stitch Fix',\n",
" 'stefan likes tacos',\n",
" 'elijah worked at TwoSigma',\n",
" 'elijah likes mango',\n",
" 'stefan used to work at IBM',\n",
" 'elijah likes to go biking',\n",
" 'stefan likes to bake sourdough'],\n",
" 'question': 'Who is Stefan? Please answer in English.'}\n",
"⏳Processing input from user...\n",
"🎙💬 where does Elijah work? \n",
"\n",
"🤔 AI is thinking...\n",
"🤖💬 Elijah works at TwoSigma. \n",
"⏳Processing input from user...\n",
"🎙💬 does he also like tacos? \n",
"\n",
"🤔 AI is thinking...\n",
"🤖💬 Based on the given context, we cannot determine if Elijah likes tacos, as it is only mentioned that he likes mango, enjoys biking, and worked at TwoSigma. There is no mention of Elijah's preference for tacos. \n",
"⏳Processing input from user...\n",
"🎙💬 where does Harrison work? \n",
"\n",
"🤔 AI is thinking...\n",
"🤖💬 Harrison works at Kensho. \n",
"⏳Processing input from user...\n",
"🎙💬 exit \n",
"\n",
"{'chat_history': ['Human: Who is Stefan? Please answer in English.',\n",
" 'AI: Stefan is a person who used to work at IBM, worked at '\n",
" 'Stitch Fix, likes tacos, and likes to bake sourdough.',\n",
" 'Human: where does Elijah work?',\n",
" 'AI: Elijah works at TwoSigma.',\n",
" 'Human: does he also like tacos?',\n",
" 'AI: Based on the given context, we cannot determine if '\n",
" 'Elijah likes tacos, as it is only mentioned that he likes '\n",
" 'mango, enjoys biking, and worked at TwoSigma. There is no '\n",
" \"mention of Elijah's preference for tacos.\",\n",
" 'Human: where does Harrison work?',\n",
" 'AI: Harrison works at Kensho.',\n",
" 'Human: exit']}\n"
]
}
],
"source": [
"print(f\"Running RAG with initial state:\\n {pprint.pformat(app.state.get_all())}\")\n",
"while True:\n",
" user_question = input(\"Ask something (or type exit to quit): \")\n",
" previous_action, result, state = app.run(\n",
" halt_before=[\"human_converse\"],\n",
" halt_after=[\"terminal\"],\n",
" inputs={\"user_question\": user_question},\n",
" )\n",
" if previous_action.name == \"terminal\":\n",
" # reached the end\n",
" pprint.pprint(result)\n",
" break"
]
},
{
"cell_type": "markdown",
"id": "169946a65f977df9",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"# Reloading from prior state\n",
"\n",
"Burr makes it easy to reload from a prior state. In this example we'll just use what was logged to the tracker to \"go back in time\" and reload the application to that state.\n",
"\n",
"This is useful for debugging, building the application itself, etc.\n",
"\n",
"There are two ways to load prior state:\n",
"1. Load the state outside the Burr Application, i.e. pass it in as the initial state.\n",
"2. Use the `ApplicationBuilder.initialize_from()` method.\n",
"\n",
"The difference between them is that the first method is more flexible, allowing you to create\n",
"new \"app_ids\", i.e. traces. The second method keeps the same app_id, and thus allows you to\n",
"\"pick up where you left off\", e.g. in the case of a crash, or if you want to resume\n",
"the last conversation with a user.\n",
"\n",
"Below we show the first method, followed by the second method, which demonstrates how\n",
"to pick up the prior conversation from where it left off.\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "c7f4dd64f73ed2d8",
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-26T20:52:37.419848Z",
"start_time": "2024-03-26T20:52:37.393869Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded state from app_id:85747466-1524-4708-a283-8e7faa67b8ed, sequence_id:7::\n",
" {'__SEQUENCE_ID': 7,\n",
" 'chat_history': ['Human: Who is Stefan? Please answer in English.',\n",
" 'AI: Stefan is a person who used to work at IBM, worked at '\n",
" 'Stitch Fix, likes tacos, and likes to bake sourdough.',\n",
" 'Human: where does Elijah work?',\n",
" 'AI: Elijah works at TwoSigma.',\n",
" 'Human: does he also like tacos?',\n",
" 'AI: Based on the given context, we cannot determine if '\n",
" 'Elijah likes tacos, as it is only mentioned that he likes '\n",
" 'mango, enjoys biking, and worked at TwoSigma. There is no '\n",
" \"mention of Elijah's preference for tacos.\",\n",
" 'Human: where does Harrison work?',\n",
" 'AI: Harrison works at Kensho.'],\n",
" 'input_texts': ['harrison worked at kensho',\n",
" 'stefan worked at Stitch Fix',\n",
" 'stefan likes tacos',\n",
" 'elijah worked at TwoSigma',\n",
" 'elijah likes mango',\n",
" 'stefan used to work at IBM',\n",
" 'elijah likes to go biking',\n",
" 'stefan likes to bake sourdough'],\n",
" 'question': 'where does Harrison work?'}\n"
]
}
],
"source": [
"# set up for rewinding to a prior state -- loading it in as initial state\n",
"REWIND_STEPS = 2  # how many steps back from the latest sequence_id to rewind\n",
"prior_app_id = app_id\n",
"last_sequence_id = app.sequence_id\n",
"# clamp at 0 so a very short conversation doesn't yield a negative sequence_id\n",
"rewind_to_sequence_id = max(last_sequence_id - REWIND_STEPS, 0)\n",
"new_app_id = str(uuid.uuid4())\n",
"\n",
"project_name = \"demo:conversational-rag\"\n",
"# we use the tracking client here to get the state of the application at a prior sequence_id\n",
"tracker = LocalTrackingClient(project=project_name)\n",
"persisted_state = tracker.load(partition_key=\"sample_user\",\n",
"                               app_id=prior_app_id,\n",
"                               sequence_id=rewind_to_sequence_id)\n",
"# guard: load() may return None when nothing was persisted for this key;\n",
"# fail with a clear message instead of a TypeError on the subscript below\n",
"if persisted_state is None:\n",
"    raise ValueError(f\"No persisted state found for app_id:{prior_app_id}, \"\n",
"                     f\"sequence_id:{rewind_to_sequence_id}\")\n",
"state_values = persisted_state['state'].get_all()\n",
"print(f\"Loaded state from app_id:{prior_app_id}, \"\n",
"      f\"sequence_id:{rewind_to_sequence_id}::\\n \"\n",
"      f\"{pprint.pformat(state_values)}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5ee618e3-15c0-403b-bc96-3a2faaea457e",
"metadata": {},
"outputs": [],
"source": [
"# A new application (fresh app_id, i.e. a new trace) seeded with the rewound state.\n",
"other_app = (\n",
"    ApplicationBuilder()\n",
"    # register the actions; the AI step gets the vector store bound in\n",
"    .with_actions(\n",
"        ai_converse=ai_converse.bind(vector_store=vector_store),\n",
"        human_converse=human_converse,\n",
"        terminal=burr.core.Result(\"chat_history\"),\n",
"    )\n",
"    # wiring: AI -> human; human -> terminal if 'exit' is in the question, else back to AI\n",
"    .with_transitions(\n",
"        (\"ai_converse\", \"human_converse\", default),\n",
"        (\"human_converse\", \"terminal\", expr(\"'exit' in question\")),\n",
"        (\"human_converse\", \"ai_converse\", default),\n",
"    )\n",
"    # new app_id, same partition_key as the original conversation\n",
"    .with_identifiers(app_id=new_app_id, partition_key=\"sample_user\")\n",
"    # seed with the state values loaded in the previous cell\n",
"    .with_state(**state_values)\n",
"    # always begin by waiting on the user\n",
"    .with_entrypoint(\"human_converse\")\n",
"    # optional hook that prints each step -- shows that Burr is pluggable\n",
"    .with_hooks(PrintStepHook())\n",
"    # log to the local tracker so the run shows up in the Burr UI\n",
"    .with_tracker(tracker)\n",
"    .build()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "34140c5864b940dc",
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-26T20:54:24.035153Z",
"start_time": "2024-03-26T20:53:19.237522Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running RAG with loaded state:\n",
" {'__SEQUENCE_ID': 7,\n",
" 'chat_history': ['Human: Who is Stefan? Please answer in English.',\n",
" 'AI: Stefan is a person who used to work at IBM, worked at '\n",
" 'Stitch Fix, likes tacos, and likes to bake sourdough.',\n",
" 'Human: where does Elijah work?',\n",
" 'AI: Elijah works at TwoSigma.',\n",
" 'Human: does he also like tacos?',\n",
" 'AI: Based on the given context, we cannot determine if '\n",
" 'Elijah likes tacos, as it is only mentioned that he likes '\n",
" 'mango, enjoys biking, and worked at TwoSigma. There is no '\n",
" \"mention of Elijah's preference for tacos.\",\n",
" 'Human: where does Harrison work?',\n",
" 'AI: Harrison works at Kensho.'],\n",
" 'input_texts': ['harrison worked at kensho',\n",
" 'stefan worked at Stitch Fix',\n",
" 'stefan likes tacos',\n",
" 'elijah worked at TwoSigma',\n",
" 'elijah likes mango',\n",
" 'stefan used to work at IBM',\n",
" 'elijah likes to go biking',\n",
" 'stefan likes to bake sourdough'],\n",
" 'question': 'where does Harrison work?'}\n",
"⏳Processing input from user...\n",
"🎙💬 does Harrison like pizza? \n",
"\n",
"🤔 AI is thinking...\n",
"🤖💬 Based on the context given, we do not have any information about whether Harrison likes pizza or not. \n",
"⏳Processing input from user...\n",
"🎙💬 I am going to a mexican restaurant. Who should I take with me? \n",
"\n",
"🤔 AI is thinking...\n",
"🤖💬 You should take Stefan with you to the Mexican restaurant. \n",
"⏳Processing input from user...\n",
"🎙💬 exit \n",
"\n",
"{'chat_history': ['Human: Who is Stefan? Please answer in English.',\n",
" 'AI: Stefan is a person who used to work at IBM, worked at '\n",
" 'Stitch Fix, likes tacos, and likes to bake sourdough.',\n",
" 'Human: where does Elijah work?',\n",
" 'AI: Elijah works at TwoSigma.',\n",
" 'Human: does he also like tacos?',\n",
" 'AI: Based on the given context, we cannot determine if '\n",
" 'Elijah likes tacos, as it is only mentioned that he likes '\n",
" 'mango, enjoys biking, and worked at TwoSigma. There is no '\n",
" \"mention of Elijah's preference for tacos.\",\n",
" 'Human: where does Harrison work?',\n",
" 'AI: Harrison works at Kensho.',\n",
" 'Human: does Harrison like pizza?',\n",
" 'AI: Based on the context given, we do not have any '\n",
" 'information about whether Harrison likes pizza or not.',\n",
" 'Human: I am going to a mexican restaurant. Who should I '\n",
" 'take with me?',\n",
" 'AI: You should take Stefan with you to the Mexican '\n",
" 'restaurant.',\n",
" 'Human: exit']}\n"
]
}
],
"source": [
"# We can now test, debug, etc. starting from this prior state.\n",
"print(f\"Running RAG with loaded state:\\n {pprint.pformat(state_values)}\")\n",
"while True:\n",
"    user_question = input(\"Ask something (or type exit to quit): \")\n",
"    previous_action, result, state = other_app.run(\n",
"        halt_before=[\"human_converse\"],\n",
"        halt_after=[\"terminal\"],\n",
"        inputs={\"user_question\": user_question},\n",
"    )\n",
"    if previous_action and previous_action.name == \"terminal\":\n",
"        # reached the end: `result` is the terminal step's output (the chat history)\n",
"        pprint.pprint(result)\n",
"        break"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "fc62a033644c7b80",
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-26T21:04:56.748649Z",
"start_time": "2024-03-26T21:04:56.742019Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"# Now let's show how to use the ApplicationBuilder.initialize_from() method to pick up where we left off.\n",
"# This is useful if you want to continue a conversation with a user, or if you had a crash, etc.\n",
"\n",
"# Reuse the *same* app_id as the original run so initialize_from() can find its latest state.\n",
"# (Unlike the rewind example above, no new app_id is needed here.)\n",
"prior_app_id = app_id\n",
"\n",
"project_name = \"demo:conversational-rag\"\n",
"# the tracking client is passed to initialize_from() below as the state initializer\n",
"tracker = LocalTrackingClient(project=project_name)\n",
"pick_up_where_we_left_off_app = (\n",
"    ApplicationBuilder()\n",
"    # add the actions\n",
"    .with_actions(\n",
"        # bind the vector store to the AI conversational step\n",
"        ai_converse=ai_converse.bind(vector_store=vector_store),\n",
"        human_converse=human_converse,\n",
"        terminal=burr.core.Result(\"chat_history\"),\n",
"    )\n",
"    # set the transitions between actions\n",
"    .with_transitions(\n",
"        (\"ai_converse\", \"human_converse\", default),\n",
"        (\"human_converse\", \"terminal\", expr(\"'exit' in question\")),\n",
"        (\"human_converse\", \"ai_converse\", default),\n",
"    )\n",
"    # add identifiers that will help track the application\n",
"    .with_identifiers(app_id=prior_app_id, partition_key=\"sample_user\")\n",
"    .initialize_from(\n",
"        initializer=tracker,\n",
"        resume_at_next_action=False,  # we want to always start at human_converse; our entrypoint\n",
"        default_entrypoint=\"human_converse\",\n",
"        default_state=initial_state,  # set some default state in case we can't find the prior state\n",
"    )\n",
"    # add a hook to print the steps -- optional but shows that Burr is pluggable\n",
"    .with_hooks(PrintStepHook())\n",
"    # add tracking -- this will show up in the BURR UI.\n",
"    .with_tracker(tracker)\n",
"    # build the application\n",
"    .build()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "b6d23d6d6a6643d0",
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-26T21:05:41.246005Z",
"start_time": "2024-03-26T21:05:02.855430Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running RAG with loaded state:\n",
" {'__PRIOR_STEP': 'terminal',\n",
" '__SEQUENCE_ID': 9,\n",
" 'chat_history': ['Human: Who is Stefan? Please answer in English.',\n",
" 'AI: Stefan is a person who used to work at IBM, worked at '\n",
" 'Stitch Fix, likes tacos, and likes to bake sourdough.',\n",
" 'Human: where does Elijah work?',\n",
" 'AI: Elijah works at TwoSigma.',\n",
" 'Human: does he also like tacos?',\n",
" 'AI: Based on the given context, we cannot determine if '\n",
" 'Elijah likes tacos, as it is only mentioned that he likes '\n",
" 'mango, enjoys biking, and worked at TwoSigma. There is no '\n",
" \"mention of Elijah's preference for tacos.\",\n",
" 'Human: where does Harrison work?',\n",
" 'AI: Harrison works at Kensho.',\n",
" 'Human: exit'],\n",
" 'input_texts': ['harrison worked at kensho',\n",
" 'stefan worked at Stitch Fix',\n",
" 'stefan likes tacos',\n",
" 'elijah worked at TwoSigma',\n",
" 'elijah likes mango',\n",
" 'stefan used to work at IBM',\n",
" 'elijah likes to go biking',\n",
" 'stefan likes to bake sourdough'],\n",
" 'question': 'exit'}\n",
"⏳Processing input from user...\n",
"🎙💬 who would most likely enjoy a fruit salad? \n",
"\n",
"🤔 AI is thinking...\n",
"🤖💬 Elijah would most likely enjoy a fruit salad since he is known to like mango. \n",
"⏳Processing input from user...\n",
"🎙💬 who would be the most helpful in terms of financial advice? \n",
"\n",
"🤔 AI is thinking...\n",
"🤖💬 Elijah would be the most helpful person for financial advice since he worked at TwoSigma, a financial services company. \n",
"⏳Processing input from user...\n",
"🎙💬 exit \n",
"\n",
"{'chat_history': ['Human: Who is Stefan? Please answer in English.',\n",
" 'AI: Stefan is a person who used to work at IBM, worked at '\n",
" 'Stitch Fix, likes tacos, and likes to bake sourdough.',\n",
" 'Human: where does Elijah work?',\n",
" 'AI: Elijah works at TwoSigma.',\n",
" 'Human: does he also like tacos?',\n",
" 'AI: Based on the given context, we cannot determine if '\n",
" 'Elijah likes tacos, as it is only mentioned that he likes '\n",
" 'mango, enjoys biking, and worked at TwoSigma. There is no '\n",
" \"mention of Elijah's preference for tacos.\",\n",
" 'Human: where does Harrison work?',\n",
" 'AI: Harrison works at Kensho.',\n",
" 'Human: exit',\n",
" 'Human: who would most likely enjoy a fruit salad?',\n",
" 'AI: Elijah would most likely enjoy a fruit salad since he '\n",
" 'is known to like mango.',\n",
" 'Human: who would be the most helpful in terms of financial '\n",
" 'advice?',\n",
" 'AI: Elijah would be the most helpful person for financial '\n",
" 'advice since he worked at TwoSigma, a financial services '\n",
" 'company.',\n",
" 'Human: exit']}\n"
]
}
],
"source": [
"# Show the state the resumed application actually starts from (loaded via initialize_from).\n",
"print(f\"Running RAG with loaded state:\\n {pprint.pformat(pick_up_where_we_left_off_app.state.get_all())}\")\n",
"while True:\n",
"    user_question = input(\"Ask something (or type exit to quit): \")\n",
"    previous_action, result, state = pick_up_where_we_left_off_app.run(\n",
"        halt_before=[\"human_converse\"],\n",
"        halt_after=[\"terminal\"],\n",
"        inputs={\"user_question\": user_question},\n",
"    )\n",
"    if previous_action and previous_action.name == \"terminal\":\n",
"        # reached the end: `result` is the terminal step's output (the chat history)\n",
"        pprint.pprint(result)\n",
"        break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe401e84026db9bb",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}