| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "id": "7fb27b941602401d91542211134fc71a", |
| "metadata": {}, |
| "source": [ |
| "Licensed to the Apache Software Foundation (ASF) under one\nor more contributor license agreements. See the NOTICE file\ndistributed with this work for additional information\nregarding copyright ownership. The ASF licenses this file\nto you under the Apache License, Version 2.0 (the\n\"License\"); you may not use this file except in compliance\nwith the License. You may obtain a copy of the License at\n\n http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing,\nsoftware distributed under the License is distributed on an\n\"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\nKIND, either express or implied. See the License for the\nspecific language governing permissions and limitations\nunder the License." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "43b8d11d", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Execute this cell to install the third-party packages this notebook imports\n", |
| "# (falkordb, openai, burr).\n", |
| "# NOTE(review): the graphviz extra of burr is presumably what burr_application.visualize() needs -- confirm.\n", |
| "%pip install falkordb openai burr[graphviz]" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "cbd976a6", |
| "metadata": {}, |
| "source": [ |
| "# Question & answer notebook [Open in Colab](https://colab.research.google.com/github/dagworks-inc/hamilton/blob/main/examples/LLM_Workflows/GraphRAG/notebook.ipynb) [View on GitHub](https://github.com/apache/hamilton/blob/main/examples/LLM_Workflows/GraphRAG/notebook.ipynb)\n", |
| "\n", |
| "\n", |
| "This notebook walks you through how to build a Burr application that talks to FalkorDB and OpenAI to answer questions about UFC fights." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 2, |
| "id": "8eaefea7", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:45:35.743153Z", |
| "start_time": "2024-05-27T16:45:35.716206Z" |
| } |
| }, |
| "outputs": [], |
| "source": [ |
| "# import what we need\n", |
| "import json\n", |
| "import uuid\n", |
| "\n", |
| "import falkordb\n", |
| "import openai\n", |
| "from burr.core import ApplicationBuilder, State, default, expr\n", |
| "from burr.core.action import action\n", |
| "from burr.tracking import LocalTrackingClient\n", |
| "from falkordb import FalkorDB\n", |
| "from graph_schema import graph_schema" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "86797dfa", |
| "metadata": {}, |
| "source": [ |
| "## Helper functions\n", |
| "We first set up some helper functions that we'll use." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 3, |
| "id": "a32f2019", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:45:40.907749Z", |
| "start_time": "2024-05-27T16:45:40.894703Z" |
| } |
| }, |
| "outputs": [], |
| "source": [ |
| "def _attributes_to_prompt(kind, name, attributes):\n", |
| "    \"\"\"Describe the attributes of one node or edge type as prompt text.\n", |
| "\n", |
| "    Args:\n", |
| "        kind: Either \"node\" or \"edge\"; inserted verbatim into the sentence.\n", |
| "        name: Label of the node type or edge type being described.\n", |
| "        attributes: Mapping of attribute name to a dict with a \"type\" entry.\n", |
| "\n", |
| "    Returns:\n", |
| "        The newline-terminated prompt fragment for this type.\n", |
| "    \"\"\"\n", |
| "    if len(attributes) == 0:\n", |
| "        return f\"The {name} {kind} type has no attributes:\\n\"\n", |
| "\n", |
| "    fragment = f\"The {name} {kind} type has the following set of attributes:\\n\"\n", |
| "    for attr_name, attr in attributes.items():\n", |
| "        attr_type = attr[\"type\"]\n", |
| "        fragment += f\"The {attr_name} attribute is of type {attr_type}\\n\"\n", |
| "    return fragment\n", |
| "\n", |
| "\n", |
| "def schema_to_prompt(schema):\n", |
| "    \"\"\"Prompt to help tell the LLM what is in the graph DB.\n", |
| "\n", |
| "    Args:\n", |
| "        schema: Dict with \"nodes\" and \"edges\" entries, each mapping a label to a dict\n", |
| "            that has an \"attributes\" mapping; edge entries also have \"connects\",\n", |
| "            a list of (source label, destination label) pairs.\n", |
| "\n", |
| "    Returns:\n", |
| "        Text listing every node type, edge type and connection, one statement per line.\n", |
| "    \"\"\"\n", |
| "    prompt = \"The Knowledge graph contains nodes of the following types:\\n\"\n", |
| "    for label, node in schema[\"nodes\"].items():\n", |
| "        # Pass the label (not the node dict) so attribute-less node types are named correctly.\n", |
| "        prompt += _attributes_to_prompt(\"node\", label, node[\"attributes\"])\n", |
| "\n", |
| "    prompt += \"In addition the Knowledge graph contains edge of the following types:\\n\"\n", |
| "    for relation, edge in schema[\"edges\"].items():\n", |
| "        prompt += _attributes_to_prompt(\"edge\", relation, edge[\"attributes\"])\n", |
| "\n", |
| "        prompt += f\"The {relation} edge connects the following entities:\\n\"\n", |
| "        for conn in edge[\"connects\"]:\n", |
| "            src, dest = conn[0], conn[1]\n", |
| "            prompt += f\"{src} is connected via {relation} to {dest}, (:{src})-[:{relation}]->(:{dest})\\n\"\n", |
| "\n", |
| "    return prompt\n", |
| "\n", |
| "\n", |
| "def set_inital_chat_history(schema_prompt: str) -> list[dict]:\n", |
| "    \"\"\"Helper to set the initial system message.\n", |
| "\n", |
| "    Args:\n", |
| "        schema_prompt: Text description of the graph schema; it is embedded in the system message.\n", |
| "\n", |
| "    Returns:\n", |
| "        A one-element list holding the system message dict (\"role\" and \"content\" keys).\n", |
| "    \"\"\"\n", |
| "    system_message = \"You are a Cypher expert with access to a directed knowledge graph\\n\"\n", |
| "    system_message += schema_prompt\n", |
| "    # The trailing newline keeps the worked example below from running into this sentence.\n", |
| "    system_message += (\n", |
| "        \"Query the knowledge graph to extract relevant information to help you answer the user's \"\n", |
| "        \"questions, base your answer only on the context retrieved from the knowledge graph, \"\n", |
| "        \"do not use preexisting knowledge.\\n\"\n", |
| "    )\n", |
| "    system_message += (\n", |
| "        \"For example to find out if two fighters had fought each other e.g. did Conor McGregor \"\n", |
| "        \"ever compete against Jose Aldo issue the following query: \"\n", |
| "        \"MATCH (a:Fighter)-[]->(f:Fight)<-[]-(b:Fighter) WHERE a.Name = 'Conor McGregor' AND \"\n", |
| "        \"b.Name = 'Jose Aldo' RETURN a, b\\n\"\n", |
| "    )\n", |
| "\n", |
| "    return [{\"role\": \"system\", \"content\": system_message}]" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "13fd36dc", |
| "metadata": {}, |
| "source": [ |
| "## Tools\n", |
| "Here we define the tool OpenAI will use and the schema that is passed to the model to describe it.\n" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 4, |
| "id": "265b9784", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:45:54.998999Z", |
| "start_time": "2024-05-27T16:45:54.991226Z" |
| } |
| }, |
| "outputs": [], |
| "source": [ |
| "def run_cypher_query(graph, query):\n", |
| "    \"\"\"Execute a read-only Cypher query against FalkorDB and return the outcome as text.\n", |
| "\n", |
| "    Args:\n", |
| "        graph: Graph handle exposing ro_query(query); the returned object has a result_set.\n", |
| "        query: Cypher query string to run.\n", |
| "\n", |
| "    Returns:\n", |
| "        str() of the result set, or str() of an error dict when the query failed or\n", |
| "        returned no rows. This function does not raise on query failure.\n", |
| "    \"\"\"\n", |
| "    try:\n", |
| "        results = graph.ro_query(query).result_set\n", |
| "    except Exception as exc:\n", |
| "        # Deliberately best-effort: report the failure and its cause back to the caller\n", |
| "        # instead of raising, so the LLM can read it and try a corrected query.\n", |
| "        return str(\n", |
| "            {\n", |
| "                \"error\": \"Query failed please try a different variation of this query\",\n", |
| "                \"details\": str(exc),\n", |
| "            }\n", |
| "        )\n", |
| "\n", |
| "    if len(results) == 0:\n", |
| "        results = {\n", |
| "            \"error\": \"The query did not return any data, please make sure you're using the right edge \"\n", |
| "            \"directions and you're following the correct graph schema\"\n", |
| "        }\n", |
| "\n", |
| "    return str(results)\n", |
| "\n", |
| "\n", |
| "# Tool description handed to the chat completion call (tools=[...]) in AI_create_cypher_query.\n", |
| "# The \"name\" must stay \"run_cypher_query\": the tool_call action asserts on that function name,\n", |
| "# and it reads the \"query\" argument declared below.\n", |
| "run_cypher_query_tool_description = {\n", |
| " \"type\": \"function\",\n", |
| " \"function\": {\n", |
| " \"name\": \"run_cypher_query\",\n", |
| " \"description\": \"Runs a Cypher query against the knowledge graph\",\n", |
| " # schema of the arguments the model has to supply\n", |
| " \"parameters\": {\n", |
| " \"type\": \"object\",\n", |
| " \"properties\": {\n", |
| " \"query\": {\n", |
| " \"type\": \"string\",\n", |
| " \"description\": \"Query to execute\",\n", |
| " },\n", |
| " },\n", |
| " \"required\": [\"query\"],\n", |
| " },\n", |
| " },\n", |
| "}" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "549ab550", |
| "metadata": {}, |
| "source": [ |
| "## Actions\n", |
| "Let's now define the actions that our application will make and what they read from & write to with respect to state.\n", |
| "\n", |
| "We'll define four of them:\n", |
| "\n", |
| "1. Human converse: This action will take the user's question and store it in the state.\n", |
| "2. AI create cypher query: This action will use the user's question to create a cypher query.\n", |
| "3. Tool call: This action will execute the cypher query and append the result to the chat history.\n", |
| "4. AI response: This action will take the result of the cypher query and create a response." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 5, |
| "id": "c08907c3", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:46:20.736830Z", |
| "start_time": "2024-05-27T16:46:20.729995Z" |
| } |
| }, |
| "outputs": [], |
| "source": [ |
| "@action(reads=[], writes=[\"question\", \"chat_history\"])\n", |
| "def human_converse(state: State, user_question: str) -> tuple[dict, State]:\n", |
| "    \"\"\"Human converse step -- record the user's question in state and in the chat history.\"\"\"\n", |
| "    user_message = {\"role\": \"user\", \"content\": user_question}\n", |
| "    updated_state = state.update(question=user_question).append(chat_history=user_message)\n", |
| "    return {\"question\": user_question}, updated_state" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 6, |
| "id": "4aa87cfe", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:46:24.286410Z", |
| "start_time": "2024-05-27T16:46:24.277468Z" |
| } |
| }, |
| "outputs": [], |
| "source": [ |
| "@action(\n", |
| "    reads=[\"question\", \"chat_history\"],\n", |
| "    writes=[\"chat_history\", \"tool_calls\"],\n", |
| ")\n", |
| "def AI_create_cypher_query(state: State, client: openai.Client) -> tuple[dict, State]:\n", |
| "    \"\"\"AI step to create the cypher query.\n", |
| "\n", |
| "    Sends the chat history to the model together with the run_cypher_query tool description,\n", |
| "    appends the model's message to chat_history and stores the tool calls it requested.\n", |
| "\n", |
| "    Args:\n", |
| "        state: Current application state; chat_history is sent as the messages.\n", |
| "        client: OpenAI client used for the chat completion.\n", |
| "\n", |
| "    Returns:\n", |
| "        Tuple of (result dict with \"ai_response\" and \"usage\" keys, updated state).\n", |
| "    \"\"\"\n", |
| "    response = client.chat.completions.create(\n", |
| "        model=\"gpt-4-turbo-preview\",\n", |
| "        messages=state[\"chat_history\"],\n", |
| "        tools=[run_cypher_query_tool_description],\n", |
| "        tool_choice=\"auto\",\n", |
| "    )\n", |
| "    response_message = response.choices[0].message\n", |
| "    new_state = state.append(chat_history=response_message.to_dict())\n", |
| "    # Always write tool_calls (an empty list when the model requested none) so this action fulfils\n", |
| "    # its declared writes and the len(tool_calls)>0 transition never sees calls from an earlier turn.\n", |
| "    # No tool calls means the model did not request a query for this question.\n", |
| "    new_state = new_state.update(tool_calls=response_message.tool_calls or [])\n", |
| "    return {\"ai_response\": response_message.content, \"usage\": response.usage.to_dict()}, new_state" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 7, |
| "id": "e0fcc358", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:46:28.364022Z", |
| "start_time": "2024-05-27T16:46:28.347327Z" |
| } |
| }, |
| "outputs": [], |
| "source": [ |
| "@action(\n", |
| "    reads=[\"tool_calls\", \"chat_history\"],\n", |
| "    writes=[\"tool_calls\", \"chat_history\"],\n", |
| ")\n", |
| "def tool_call(state: State, graph: falkordb.Graph) -> tuple[dict, State]:\n", |
| "    \"\"\"Tool call step -- run every pending Cypher query and record its output in the chat history.\"\"\"\n", |
| "    pending_calls = state.get(\"tool_calls\", [])\n", |
| "    updated_state = state\n", |
| "    executed = []\n", |
| "    for call in pending_calls:\n", |
| "        name = call.function.name\n", |
| "        # run_cypher_query is the only tool offered to the model\n", |
| "        assert name == \"run_cypher_query\"\n", |
| "        arguments = json.loads(call.function.arguments)\n", |
| "        query_output = run_cypher_query(graph, arguments.get(\"query\"))\n", |
| "        tool_message = {\n", |
| "            \"tool_call_id\": call.id,\n", |
| "            \"role\": \"tool\",\n", |
| "            \"name\": name,\n", |
| "            \"content\": query_output,\n", |
| "        }\n", |
| "        updated_state = updated_state.append(chat_history=tool_message)\n", |
| "        executed.append({\"tool_call_id\": call.id, \"response\": query_output})\n", |
| "    # every pending call has been handled, so clear the list\n", |
| "    return {\"tool_calls\": executed}, updated_state.update(tool_calls=[])" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 8, |
| "id": "2f86c8a6", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:47:18.853441Z", |
| "start_time": "2024-05-27T16:47:18.847306Z" |
| } |
| }, |
| "outputs": [], |
| "source": [ |
| "@action(reads=[\"chat_history\"], writes=[\"chat_history\"])\n", |
| "def AI_generate_response(state: State, client: openai.Client) -> tuple[dict, State]:\n", |
| "    \"\"\"AI step -- ask the model for a reply based on the chat history gathered so far.\"\"\"\n", |
| "    # the history now contains the tool output, so the model can see the function response\n", |
| "    completion = client.chat.completions.create(\n", |
| "        model=\"gpt-4-turbo-preview\",\n", |
| "        messages=state[\"chat_history\"],\n", |
| "    )\n", |
| "    reply = completion.choices[0].message\n", |
| "    updated_state = state.append(chat_history=reply.to_dict())\n", |
| "    result = {\"ai_response\": reply.content, \"usage\": completion.usage.to_dict()}\n", |
| "    return result, updated_state" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "f6723e3a", |
| "metadata": {}, |
| "source": [ |
| "## Define the application\n", |
| "Here we wire the actions and transitions together into the Burr application." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 9, |
| "id": "d250a513", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:47:34.643728Z", |
| "start_time": "2024-05-27T16:47:34.515076Z" |
| } |
| }, |
| "outputs": [], |
| "source": [ |
| "# define our clients / connections / IDs\n", |
| "# openai.OpenAI() is created without explicit arguments.\n", |
| "# NOTE(review): presumably it reads the API key from the OPENAI_API_KEY environment variable -- confirm.\n", |
| "openai_client = openai.OpenAI()\n", |
| "# FalkorDB client pointed at localhost:6379; a FalkorDB server is expected to be listening there\n", |
| "db_client = FalkorDB(host=\"localhost\", port=6379)\n", |
| "# name of the graph that is selected from FalkorDB in the next cell\n", |
| "graph_name = \"UFC\"\n", |
| "# random ID that is passed to the ApplicationBuilder as this run's app_id\n", |
| "application_run_id = str(uuid.uuid4())" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 10, |
| "id": "e97c0268", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:47:43.535537Z", |
| "start_time": "2024-05-27T16:47:38.440468Z" |
| } |
| }, |
| "outputs": [], |
| "source": [ |
| "# get a handle to the graph named by graph_name\n", |
| "graph = db_client.select_graph(graph_name)\n", |
| "# get schema -- graph_schema is imported from the graph_schema module\n", |
| "# NOTE(review): schema_to_prompt expects a dict with \"nodes\" and \"edges\" entries -- confirm graph_schema returns that shape.\n", |
| "schema = graph_schema(graph)\n", |
| "# create a prompt from it\n", |
| "schema_prompt = schema_to_prompt(schema)\n", |
| "# set the initial chat history: a single system message that embeds the schema prompt\n", |
| "base_messages = set_inital_chat_history(schema_prompt)" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 11, |
| "id": "6b7a912d", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:49:23.058742Z", |
| "start_time": "2024-05-27T16:49:23.052222Z" |
| } |
| }, |
| "outputs": [], |
| "source": [ |
| "# NOTE(review): \"ufc-falkor\" is presumably the project name this run is tracked under -- confirm.\n", |
| "tracker = LocalTrackingClient(\"ufc-falkor\")\n", |
| "# create the application graph: actions are the nodes, transitions are the edges\n", |
| "burr_application = (\n", |
| " ApplicationBuilder()\n", |
| " .with_actions( # define the actions; the OpenAI client and the graph are bound up front\n", |
| " AI_create_cypher_query.bind(client=openai_client),\n", |
| " tool_call.bind(graph=graph),\n", |
| " AI_generate_response.bind(client=openai_client),\n", |
| " human_converse,\n", |
| " )\n", |
| " .with_transitions( # define the edges between the actions based on state conditions\n", |
| " # after the human speaks, always ask the model\n", |
| " (\"human_converse\", \"AI_create_cypher_query\", default),\n", |
| " # if the model requested tool calls, execute them ...\n", |
| " (\"AI_create_cypher_query\", \"tool_call\", expr(\"len(tool_calls)>0\")),\n", |
| " # ... otherwise go back to the human\n", |
| " (\"AI_create_cypher_query\", \"human_converse\", default),\n", |
| " # after the tool ran, let the model answer, then return to the human\n", |
| " (\"tool_call\", \"AI_generate_response\", default),\n", |
| " (\"AI_generate_response\", \"human_converse\", default),\n", |
| " )\n", |
| " .with_identifiers(app_id=application_run_id)\n", |
| " .with_state( # initial state: system message only, no pending tool calls\n", |
| " **{\"chat_history\": base_messages, \"tool_calls\": []},\n", |
| " )\n", |
| " .with_entrypoint(\"human_converse\")\n", |
| " .with_tracker(tracker)\n", |
| " .build()\n", |
| ")" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 12, |
| "id": "f195f59b", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:47:47.058569Z", |
| "start_time": "2024-05-27T16:47:46.295042Z" |
| } |
| }, |
| "outputs": [], |
| "source": [ |
| "# Draw the application graph; include_conditions=True asks for the transition conditions to be shown.\n", |
| "burr_application.visualize(include_conditions=True)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "dfbfb2ac", |
| "metadata": {}, |
| "source": [ |
| "## Run the application\n", |
| "Here we show how to do a simple loop stopping before `human_converse` each time to get user input before running the graph again.\n", |
| "\n", |
| "\n", |
| "### Viewing a trace of this application in the Burr UI\n", |
| "Note: you can view the logs of the conversation in the Burr UI. \n", |
| "\n", |
| "To see that, in another terminal do:\n", |
| "\n", |
| "> burr\n", |
| "\n", |
| "You'll then have the UI running on [http://localhost:7241/](http://localhost:7241/).\n", |
| "\n", |
| "#### Using the Burr UI in Google Colab\n", |
| "To use the UI in Google Colab do the following:\n", |
| "\n", |
| "1. Run this in a cell\n", |
| "```python\n", |
| "from google.colab import output\n", |
| "output.serve_kernel_port_as_window(7241)\n", |
| "```\n", |
| "\n", |
| "2. Then start the burr UI:\n", |
| "```\n", |
| "!burr &\n", |
| "```\n", |
| "3. Click the link in (1) to open a new tab." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 13, |
| "id": "ddc4585c", |
| "metadata": { |
| "ExecuteTime": { |
| "end_time": "2024-05-27T16:48:59.800468Z", |
| "start_time": "2024-05-27T16:48:02.100003Z" |
| } |
| }, |
| "outputs": [], |
| "source": [ |
| "# run it -- type \"exit\" to stop\n", |
| "# Seed current_state from the application so the final expression below is defined even when the\n", |
| "# very first input is \"exit\" (otherwise this cell would end with a NameError on a fresh kernel).\n", |
| "current_state = burr_application.state\n", |
| "while True:\n", |
| "    # this will ask for input:\n", |
| "    question = input(\"What can I help you with?\\n\")\n", |
| "    if question == \"exit\":\n", |
| "        break\n", |
| "    # run the graph until it is about to ask the human again\n", |
| "    current_action, _, current_state = burr_application.run(\n", |
| "        halt_before=[\"human_converse\"],\n", |
| "        inputs={\"user_question\": question},\n", |
| "    )\n", |
| "    # we'll then see the AI response:\n", |
| "    print(f\"AI: {current_state['chat_history'][-1]['content']}\\n\")\n", |
| "current_state" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "e7ea83f2", |
| "metadata": {}, |
| "source": [ |
| "With Burr we can continue where we left off easily!\n", |
| "\n", |
| "So why not run the conversation through some more?" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 14, |
| "id": "d3961d9f", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# run it again on the same application object -- type \"exit\" to stop\n", |
| "# Seed current_state from the application so the final expression below is defined even when the\n", |
| "# first input is \"exit\" and no earlier cell has assigned it.\n", |
| "current_state = burr_application.state\n", |
| "while True:\n", |
| "    # this will ask for input:\n", |
| "    question = input(\"What can I help you with?\\n\")\n", |
| "    if question == \"exit\":\n", |
| "        break\n", |
| "    # run the graph until it is about to ask the human again\n", |
| "    current_action, _, current_state = burr_application.run(\n", |
| "        halt_before=[\"human_converse\"],\n", |
| "        inputs={\"user_question\": question},\n", |
| "    )\n", |
| "    # we'll then see the AI response:\n", |
| "    print(f\"AI: {current_state['chat_history'][-1]['content']}\\n\")\n", |
| "current_state" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "9d0e91d4", |
| "metadata": {}, |
| "outputs": [], |
| "source": [] |
| } |
| ], |
| "metadata": { |
| "kernelspec": { |
| "display_name": "Python 3 (ipykernel)", |
| "language": "python", |
| "name": "python3" |
| }, |
| "language_info": { |
| "codemirror_mode": { |
| "name": "ipython", |
| "version": 3 |
| }, |
| "file_extension": ".py", |
| "mimetype": "text/x-python", |
| "name": "python", |
| "nbconvert_exporter": "python", |
| "pygments_lexer": "ipython3", |
| "version": "3.10.13" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 5 |
| } |