| { |
| "cells": [ |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "28c7f468-79e9-48e6-9e6f-51dac6116fb0", |
| "metadata": {}, |
| "source": [ |
| "# install requirements\n", |
| "!pip install falkordb openai burr[graphviz] " |
| ], |
| "outputs": [] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "5a2c326f-53cc-4eb1-8508-e06584d8051c", |
| "metadata": {}, |
| "source": [ |
| "# Question & answer notebook\n", |
| "\n", |
| "This notebook walks you through how to build a burr application that talks to falkorDB and openai to answer questions about UFC fights." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 2, |
| "id": "39233706-20fd-4043-bc0f-291626664f08", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:45:35.743153Z", |
| "start_time": "2024-05-27T16:45:35.716206Z" |
| } |
| }, |
| "source": [ |
| "# import what we need\n", |
| "import json\n", |
| "from typing import Tuple\n", |
| "\n", |
| "import openai\n", |
| "from burr.core import ApplicationBuilder, State, default, expr, Application\n", |
| "from burr.core.action import action\n", |
| "from burr.tracking import LocalTrackingClient\n", |
| "import uuid\n", |
| "from falkordb import FalkorDB\n", |
| "from graph_schema import graph_schema\n", |
| "import falkordb" |
| ], |
| "outputs": [] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "a2fe5f83-75b4-4d44-a80e-829947b6d240", |
| "metadata": {}, |
| "source": [ |
| "## Helper functions\n", |
| "We first set up some helper functions that we'll use." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 3, |
| "id": "ef809095-566a-40af-a8a6-cbf72ef21227", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:45:40.907749Z", |
| "start_time": "2024-05-27T16:45:40.894703Z" |
| } |
| }, |
| "source": [ |
| "def schema_to_prompt(schema):\n", |
| " \"\"\"Prompt to help tell the LLM what is in the graph DB\"\"\"\n", |
| " prompt = \"The Knowledge graph contains nodes of the following types:\\n\"\n", |
| "\n", |
| " for node in schema['nodes']:\n", |
| " lbl = node\n", |
| " node = schema['nodes'][node]\n", |
| " if len(node['attributes']) > 0:\n", |
| " prompt += f\"The {lbl} node type has the following set of attributes:\\n\"\n", |
| " for attr in node['attributes']:\n", |
| " t = node['attributes'][attr]['type']\n", |
| " prompt += f\"The {attr} attribute is of type {t}\\n\"\n", |
| " else:\n", |
| " prompt += f\"The {node} node type has no attributes:\\n\"\n", |
| "\n", |
| " prompt += \"In addition the Knowledge graph contains edge of the following types:\\n\"\n", |
| "\n", |
| " for edge in schema['edges']:\n", |
| " rel = edge\n", |
| " edge = schema['edges'][edge]\n", |
| " if len(edge['attributes']) > 0:\n", |
| " prompt += f\"The {rel} edge type has the following set of attributes:\\n\"\n", |
| " for attr in edge['attributes']:\n", |
| " t = edge['attributes'][attr]['type']\n", |
| " prompt += f\"The {attr} attribute is of type {t}\\n\"\n", |
| " else:\n", |
| " prompt += f\"The {rel} edge type has no attributes:\\n\"\n", |
| "\n", |
| " prompt += f\"The {rel} edge connects the following entities:\\n\"\n", |
| " for conn in edge['connects']:\n", |
| " src = conn[0]\n", |
| " dest = conn[1]\n", |
| " prompt += f\"{src} is connected via {rel} to {dest}, (:{src})-[:{rel}]->(:{dest})\\n\"\n", |
| "\n", |
| " return prompt\n", |
| "\n", |
| "def set_inital_chat_history(schema_prompt: str) -> list[dict]:\n", |
| " \"\"\"Helper to set initial system message\"\"\"\n", |
| " SYSTEM_MESSAGE = \"You are a Cypher expert with access to a directed knowledge graph\\n\"\n", |
| " SYSTEM_MESSAGE += schema_prompt\n", |
| " SYSTEM_MESSAGE += (\"Query the knowledge graph to extract relevant information to help you anwser the users \"\n", |
| " \"questions, base your answer only on the context retrieved from the knowledge graph, \"\n", |
| " \"do not use preexisting knowledge.\")\n", |
| " SYSTEM_MESSAGE += (\"For example to find out if two fighters had fought each other e.g. did Conor McGregor \"\n", |
| " \"every compete against Jose Aldo issue the following query: \"\n", |
| " \"MATCH (a:Fighter)-[]->(f:Fight)<-[]-(b:Fighter) WHERE a.Name = 'Conor McGregor' AND \"\n", |
| " \"b.Name = 'Jose Aldo' RETURN a, b\\n\")\n", |
| "\n", |
| " messages = [{\"role\": \"system\", \"content\": SYSTEM_MESSAGE}]\n", |
| " return messages" |
| ], |
| "outputs": [] |
| }, |
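| { |
| "cell_type": "markdown", |
| "id": "f2a61c3e-7b0d-4c2a-9e5f-1d3b8a6c4e70", |
| "metadata": {}, |
| "source": [ |
| "To see what `schema_to_prompt` produces, here is a quick, hypothetical run on a toy schema dict. The dict shape (`nodes`/`edges` with `attributes` and `connects`) mirrors what the helper expects; the real schema comes from `graph_schema(graph)` later in the notebook." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "0d9e4b21-5c8f-4a6e-b3d7-2f1a9c8e6b54", |
| "metadata": {}, |
| "source": [ |
| "# illustrative only: a toy schema in the shape schema_to_prompt expects\n", |
| "toy_schema = {\n", |
| "    \"nodes\": {\"Fighter\": {\"attributes\": {\"Name\": {\"type\": \"string\"}}}},\n", |
| "    \"edges\": {\"FOUGHT_IN\": {\"attributes\": {}, \"connects\": [(\"Fighter\", \"Fight\")]}},\n", |
| "}\n", |
| "print(schema_to_prompt(toy_schema))" |
| ], |
| "outputs": [] |
| }, |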
| { |
| "cell_type": "markdown", |
| "id": "9a78e2eb1ab4c1ef", |
| "metadata": {}, |
| "source": [ |
| "## Tools\n", |
| "Here we describe the tool openAI will use & it's schema that will be passed to describe it.\n" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 4, |
| "id": "7134877d85934e9d", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:45:54.998999Z", |
| "start_time": "2024-05-27T16:45:54.991226Z" |
| } |
| }, |
| "source": [ |
| "def run_cypher_query(graph, query):\n", |
| " \"\"\"What executes a query against falkorDB\"\"\"\n", |
| " try:\n", |
| " results = graph.ro_query(query).result_set\n", |
| " except:\n", |
| " results = {\"error\": \"Query failed please try a different variation of this query\"}\n", |
| "\n", |
| " if len(results) == 0:\n", |
| " results = {\n", |
| " \"error\": \"The query did not return any data, please make sure you're using the right edge \"\n", |
| " \"directions and you're following the correct graph schema\"}\n", |
| "\n", |
| " return str(results)\n", |
| "\n", |
| "# description\n", |
| "run_cypher_query_tool_description = {\n", |
| " \"type\": \"function\",\n", |
| " \"function\": {\n", |
| " \"name\": \"run_cypher_query\",\n", |
| " \"description\": \"Runs a Cypher query against the knowledge graph\",\n", |
| " \"parameters\": {\n", |
| " \"type\": \"object\",\n", |
| " \"properties\": {\n", |
| " \"query\": {\n", |
| " \"type\": \"string\",\n", |
| " \"description\": \"Query to execute\",\n", |
| " },\n", |
| " },\n", |
| " \"required\": [\"query\"],\n", |
| " },\n", |
| " },\n", |
| "}" |
| ], |
| "outputs": [] |
| }, |
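| { |
| "cell_type": "markdown", |
| "id": "3e7a1f52-9c4b-4d8e-a6f0-5b2c7d9e1a83", |
| "metadata": {}, |
| "source": [ |
| "As a quick sanity check of the tool plumbing (no LLM call involved), here is a hypothetical round-trip: the arguments OpenAI sends back for a `run_cypher_query` tool call arrive as a JSON string, which gets parsed with `json.loads` exactly as the `tool_call` action below does. The query string is made up for illustration." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "8b4c2d6f-0a1e-4f3b-9c5d-7e8a6b4f2c19", |
| "metadata": {}, |
| "source": [ |
| "# illustrative only: simulate the JSON arguments OpenAI would send for a tool call\n", |
| "fake_arguments = json.dumps({\"query\": \"MATCH (f:Fighter) RETURN f.Name LIMIT 5\"})\n", |
| "parsed_args = json.loads(fake_arguments)\n", |
| "print(parsed_args[\"query\"])  # this is what gets passed to run_cypher_query(graph, ...)" |
| ], |
| "outputs": [] |
| }, |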
| { |
| "cell_type": "markdown", |
| "id": "8f1987f6af0eb2b6", |
| "metadata": {}, |
| "source": [ |
| "## Actions\n", |
| "Let's now define the actions that our application will make and what they read from & write to with respect to state.\n", |
| "\n", |
| "We'll define four of them:\n", |
| "\n", |
| "1. Human converse: This action will take the user's question and store it in the state.\n", |
| "2. AI create cypher query: This action will use the user's question to create a cypher query.\n", |
| "3. Tool call: This action will execute the cypher query and append the result to the chat history.\n", |
| "4. AI response: This action will take the result of the cypher query and create a response." |
| ] |
| }, |
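| { |
| "cell_type": "markdown", |
| "id": "1c5e9a3b-7d2f-4b6c-8e0a-4f9d2b7c5e16", |
| "metadata": {}, |
| "source": [ |
| "First, a minimal sketch of the two `State` operations the actions rely on, using only the `burr.core.State` class imported above (illustrative only): `update` sets or overwrites a key, `append` adds an item to a list-valued key, and both return a new immutable `State`." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "6f0b4d8e-2a5c-4e7f-b1d3-9c6e8a2f4b70", |
| "metadata": {}, |
| "source": [ |
| "# minimal sketch of Burr State semantics (illustrative only)\n", |
| "_s = State({\"chat_history\": []})\n", |
| "_s = _s.update(question=\"who won?\")  # set/overwrite the 'question' key\n", |
| "_s = _s.append(chat_history={\"role\": \"user\", \"content\": \"who won?\"})  # append to the list\n", |
| "print(_s[\"question\"], len(_s[\"chat_history\"]))" |
| ], |
| "outputs": [] |
| }, |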
| { |
| "cell_type": "code", |
| "execution_count": 5, |
| "id": "6c1e5d3a8bc04ec6", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:46:20.736830Z", |
| "start_time": "2024-05-27T16:46:20.729995Z" |
| } |
| }, |
| "source": [ |
| "@action(\n", |
| " reads=[],\n", |
| " writes=[\"question\", \"chat_history\"],\n", |
| ")\n", |
| "def human_converse(state: State, user_question: str) -> Tuple[dict, State]:\n", |
| " \"\"\"Human converse step -- make sure we get input, and store it as state.\"\"\"\n", |
| " new_state = state.update(question=user_question)\n", |
| " new_state = new_state.append(chat_history={\"role\": \"user\", \"content\": user_question})\n", |
| " return {\"question\": user_question}, new_state" |
| ], |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 6, |
| "id": "a29aeb9b7b025591", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:46:24.286410Z", |
| "start_time": "2024-05-27T16:46:24.277468Z" |
| } |
| }, |
| "source": [ |
| "@action(\n", |
| " reads=[\"question\", \"chat_history\"],\n", |
| " writes=[\"chat_history\", \"tool_calls\"],\n", |
| ")\n", |
| "def AI_create_cypher_query(state: State, client: openai.Client) -> tuple[dict, State]:\n", |
| " \"\"\"AI step to create the cypher query.\"\"\"\n", |
| " messages = state[\"chat_history\"]\n", |
| " # Call the function\n", |
| " response = client.chat.completions.create(\n", |
| " model=\"gpt-4-turbo-preview\",\n", |
| " messages=messages,\n", |
| " tools=[run_cypher_query_tool_description],\n", |
| " tool_choice=\"auto\",\n", |
| " )\n", |
| " response_message = response.choices[0].message\n", |
| " new_state = state.append(chat_history=response_message.to_dict())\n", |
| " tool_calls = response_message.tool_calls\n", |
| " if tool_calls:\n", |
| " new_state = new_state.update(tool_calls=tool_calls)\n", |
| " # if there are no tool calls -- it means we didn't know what to do\n", |
| " return {\"ai_response\": response_message.content, \"usage\": response.usage.to_dict()}, new_state\n" |
| ], |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 7, |
| "id": "a945e875be76be26", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:46:28.364022Z", |
| "start_time": "2024-05-27T16:46:28.347327Z" |
| } |
| }, |
| "source": [ |
| "@action(\n", |
| " reads=[\"tool_calls\", \"chat_history\"],\n", |
| " writes=[\"tool_calls\", \"chat_history\"],\n", |
| ")\n", |
| "def tool_call(state: State, graph: falkordb.Graph) -> Tuple[dict, State]:\n", |
| " \"\"\"Tool call step -- execute the query and append to chat history.\"\"\"\n", |
| " tool_calls = state.get(\"tool_calls\", [])\n", |
| " new_state = state\n", |
| " result = {\"tool_calls\": []}\n", |
| " for tool_call in tool_calls:\n", |
| " function_name = tool_call.function.name\n", |
| " assert (function_name == \"run_cypher_query\")\n", |
| " function_args = json.loads(tool_call.function.arguments)\n", |
| " function_response = run_cypher_query(graph, function_args.get(\"query\"))\n", |
| " new_state = new_state.append(chat_history=\n", |
| " {\n", |
| " \"tool_call_id\": tool_call.id,\n", |
| " \"role\": \"tool\",\n", |
| " \"name\": function_name,\n", |
| " \"content\": function_response,\n", |
| " }\n", |
| " )\n", |
| " result[\"tool_calls\"].append({\"tool_call_id\": tool_call.id, \"response\": function_response})\n", |
| " new_state = new_state.update(tool_calls=[])\n", |
| " return result, new_state" |
| ], |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 8, |
| "id": "698d1ef5da45b2c6", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:47:18.853441Z", |
| "start_time": "2024-05-27T16:47:18.847306Z" |
| } |
| }, |
| "source": [ |
| "@action(\n", |
| " reads=[\"chat_history\"],\n", |
| " writes=[\"chat_history\"],\n", |
| ")\n", |
| "def AI_generate_response(state: State, client: openai.Client) -> tuple[dict, State]:\n", |
| " \"\"\"AI step to generate the response given the current chat history.\"\"\"\n", |
| " messages = state[\"chat_history\"]\n", |
| " response = client.chat.completions.create(\n", |
| " model=\"gpt-4-turbo-preview\",\n", |
| " messages=messages,\n", |
| " ) # get a new response from the model where it can see the function response\n", |
| " response_message = response.choices[0].message\n", |
| " new_state = state.append(chat_history=response_message.to_dict())\n", |
| " return {\"ai_response\": response_message.content,\n", |
| " \"usage\": response.usage.to_dict()}, new_state\n" |
| ], |
| "outputs": [] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "9fca50d97f2aca0f", |
| "metadata": {}, |
| "source": [ |
| "## Define the application\n", |
| "This is where we define our application now" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 9, |
| "id": "66411a5074f15a7f", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:47:34.643728Z", |
| "start_time": "2024-05-27T16:47:34.515076Z" |
| } |
| }, |
| "source": [ |
| "# define our clients / connections / IDs\n", |
| "openai_client = openai.OpenAI()\n", |
| "db_client = FalkorDB(host='localhost', port=6379)\n", |
| "graph_name = \"UFC\"\n", |
| "application_run_id = str(uuid.uuid4())" |
| ], |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 10, |
| "id": "e4d89ba2d477efb1", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:47:43.535537Z", |
| "start_time": "2024-05-27T16:47:38.440468Z" |
| } |
| }, |
| "source": [ |
| "# get the graph\n", |
| "graph = db_client.select_graph(graph_name)\n", |
| "# get schema\n", |
| "schema = graph_schema(graph)\n", |
| "# create a prompt from it\n", |
| "schema_prompt = schema_to_prompt(schema)\n", |
| "# set the initial chat history\n", |
| "base_messages = set_inital_chat_history(schema_prompt)" |
| ], |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 11, |
| "id": "dbe987ade3c5b69f", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:49:23.058742Z", |
| "start_time": "2024-05-27T16:49:23.052222Z" |
| } |
| }, |
| "source": [ |
| "tracker = LocalTrackingClient(\"ufc-falkor\")\n", |
| "# create graph\n", |
| "burr_application = (\n", |
| " ApplicationBuilder()\n", |
| " .with_actions( # define the actions\n", |
| " AI_create_cypher_query.bind(client=openai_client),\n", |
| " tool_call.bind(graph=graph),\n", |
| " AI_generate_response.bind(client=openai_client),\n", |
| " human_converse\n", |
| " )\n", |
| " .with_transitions( # define the edges between the actions based on state conditions\n", |
| " (\"human_converse\", \"AI_create_cypher_query\", default),\n", |
| " (\"AI_create_cypher_query\", \"tool_call\", expr(\"len(tool_calls)>0\")),\n", |
| " (\"AI_create_cypher_query\", \"human_converse\", default),\n", |
| " (\"tool_call\", \"AI_generate_response\", default),\n", |
| " (\"AI_generate_response\", \"human_converse\", default)\n", |
| " )\n", |
| " .with_identifiers(app_id=application_run_id)\n", |
| " .with_state( # initial state\n", |
| " **{\"chat_history\": base_messages, \"tool_calls\": []},\n", |
| " )\n", |
| " .with_entrypoint(\"human_converse\")\n", |
| " .with_tracker(tracker)\n", |
| " .build()\n", |
| ")" |
| ], |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 12, |
| "id": "5bb8432243e49d35", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:47:47.058569Z", |
| "start_time": "2024-05-27T16:47:46.295042Z" |
| } |
| }, |
| "source": [ |
| "burr_application.visualize(include_conditions=True)" |
| ], |
| "outputs": [] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "3ea2b29306cbb4d1", |
| "metadata": {}, |
| "source": [ |
| "## Run the application\n", |
| "Here we show how to do a simple loop stopping before `human_converse` each time to get user input before running the graph again.\n", |
| "\n", |
| "\n", |
| "### Viewing a trace of the this application in the Burr UI\n", |
| "Note: you can view the logs of the conversation in the Burr UI. \n", |
| "\n", |
| "To see that, in another terminal do:\n", |
| "\n", |
| "> burr\n", |
| "\n", |
| "You'll then have the UI running on [http://localhost:7241/](http://localhost:7241/).\n", |
| "\n", |
| "#### Using the Burr UI in google collab\n", |
| "To use the UI in google collab do the following:\n", |
| "\n", |
| "1. Run this in a cell\n", |
| "```python\n", |
| "from google.colab import output\n", |
| "output.serve_kernel_port_as_window(7241)\n", |
| "```\n", |
| "\n", |
| "2. Then start the burr UI:\n", |
| "```\n", |
| "!burr &\n", |
| "```\n", |
| "3. Click the link in (1) to open a new tab." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 13, |
| "id": "c2c996de4588512c", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:48:59.800468Z", |
| "start_time": "2024-05-27T16:48:02.100003Z" |
| } |
| }, |
| "source": [ |
| "# run it\n", |
| "while True:\n", |
| " # this will ask for input:\n", |
| " question = input(\"What can I help you with?\\n\")\n", |
| " if question == \"exit\":\n", |
| " break\n", |
| " current_action, _, current_state = burr_application.run(\n", |
| " halt_before=[\"human_converse\"],\n", |
| " inputs={\"user_question\": question},\n", |
| " )\n", |
| " # we'll then see the AI response:\n", |
| " print(f\"AI: {current_state['chat_history'][-1]['content']}\\n\")\n", |
| "current_state" |
| ], |
| "outputs": [] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "05d0b159-3fd8-48b1-9c15-88375a52d60b", |
| "metadata": {}, |
| "source": [ |
| "With Burr we can continue where we left off easily!\n", |
| "\n", |
| "So why not run the conversation through some more?" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 14, |
| "id": "6366fdb4-5b08-4a8f-8f73-ffca67d2091a", |
| "metadata": {}, |
| "source": [ |
| "# run it\n", |
| "while True:\n", |
| " # this will ask for input:\n", |
| " question = input(\"What can I help you with?\\n\")\n", |
| " if question == \"exit\":\n", |
| " break\n", |
| " current_action, _, current_state = burr_application.run(\n", |
| " halt_before=[\"human_converse\"],\n", |
| " inputs={\"user_question\": question},\n", |
| " )\n", |
| " # we'll then see the AI response:\n", |
| " print(f\"AI: {current_state['chat_history'][-1]['content']}\\n\")\n", |
| "current_state" |
| ], |
| "outputs": [] |
| }, |
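| { |
| "cell_type": "markdown", |
| "id": "4a8c2e6f-0b3d-4f7a-9e1c-5d8b2f6a0c34", |
| "metadata": {}, |
| "source": [ |
| "Because the tracker logged every step under `application_run_id`, you could also rebuild the application in a later session and pick up the conversation where it stopped. Below is a minimal, hypothetical sketch assuming Burr's `ApplicationBuilder.initialize_from` API; treat the exact signature as an assumption and check the Burr docs on state persistence for your version." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "9d3f7b1e-5c0a-4b8d-a2e6-7f4c9b3d1e58", |
| "metadata": {}, |
| "source": [ |
| "# illustrative sketch only -- assumes ApplicationBuilder.initialize_from exists\n", |
| "# in your Burr version; consult the Burr docs on state persistence for specifics\n", |
| "resumed_app = (\n", |
| "    ApplicationBuilder()\n", |
| "    .with_actions(\n", |
| "        AI_create_cypher_query.bind(client=openai_client),\n", |
| "        tool_call.bind(graph=graph),\n", |
| "        AI_generate_response.bind(client=openai_client),\n", |
| "        human_converse,\n", |
| "    )\n", |
| "    .with_transitions(\n", |
| "        (\"human_converse\", \"AI_create_cypher_query\", default),\n", |
| "        (\"AI_create_cypher_query\", \"tool_call\", expr(\"len(tool_calls)>0\")),\n", |
| "        (\"AI_create_cypher_query\", \"human_converse\", default),\n", |
| "        (\"tool_call\", \"AI_generate_response\", default),\n", |
| "        (\"AI_generate_response\", \"human_converse\", default),\n", |
| "    )\n", |
| "    .initialize_from(\n", |
| "        tracker,  # load prior state logged under this app_id\n", |
| "        resume_at_next_action=True,\n", |
| "        default_state={\"chat_history\": base_messages, \"tool_calls\": []},\n", |
| "        default_entrypoint=\"human_converse\",\n", |
| "    )\n", |
| "    .with_identifiers(app_id=application_run_id)\n", |
| "    .with_tracker(tracker)\n", |
| "    .build()\n", |
| ")" |
| ], |
| "outputs": [] |
| } |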
| ], |
| "metadata": { |
| "kernelspec": { |
| "display_name": "Python 3 (ipykernel)", |
| "language": "python", |
| "name": "python3" |
| }, |
| "language_info": { |
| "codemirror_mode": { |
| "name": "ipython", |
| "version": 3 |
| }, |
| "file_extension": ".py", |
| "mimetype": "text/x-python", |
| "name": "python", |
| "nbconvert_exporter": "python", |
| "pygments_lexer": "ipython3", |
| "version": "3.10.13" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 5 |
| } |