blob: c13a167a9d2d2b19e1383d8f394fda26a8bd420f [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import functools
import importlib
import json
from typing import List, Literal
import pydantic
from fastapi import APIRouter
from starlette.responses import StreamingResponse
from burr.core import Application, ApplicationBuilder
from burr.tracking import LocalTrackingClient
"""This file represents a simple chatbot API backed with Burr.
We manage an application, write to it with post endpoints, and read with
get/ endpoints.
This demonstrates how you can build interactive web applications with Burr!
"""
# We use a dynamic import because this code lives within examples/, and the
# example's package name contains a dash ("streaming-fastapi"), which a plain
# `import` statement cannot express.
# Navigate to the examples directory to read more about this!
chat_application = importlib.import_module(
    "burr.examples.streaming-fastapi.application"
)  # noqa: F401
# No FastAPI app is created here: this module only exposes `router`, which is
# meant to be included into another app (see the commented-out block at the
# bottom of this file for running it standalone).
# app = FastAPI()
router = APIRouter()
# The Burr graph from the example's application module; every application
# built by `_get_application` uses it.
graph = chat_application.graph
# Optional OpenTelemetry support. If the OpenAI instrumentation package can be
# imported, instrument it and remember that OTel is available so the Burr
# tracker can be built with `use_otel_tracing`. If the import fails, the server
# simply runs without OTel tracing.
try:
    from opentelemetry.instrumentation.openai import OpenAIInstrumentor

    OpenAIInstrumentor().instrument()
    opentelemetry_available = True
except ImportError:
    opentelemetry_available = False
class ChatItem(pydantic.BaseModel):
    """Pydantic model for a single chat entry. This is used to render the chat history.

    The ``/history`` endpoint declares ``List[ChatItem]`` as its response model.
    """

    # The message payload.
    content: str
    # Kind of entry; presumably tells the frontend how to render `content` -- TODO confirm.
    type: Literal["image", "text", "code", "error"]
    # Which side of the conversation produced this entry.
    role: Literal["user", "assistant"]
@functools.lru_cache(maxsize=128)
def _get_application(project_id: str, app_id: str) -> Application:
    """Build the Burr application for a project/app pair, memoizing the result.

    Up to 128 applications are kept in an LRU cache, so repeated requests with the
    same arguments reuse one ``Application``. On a cache miss, the application is
    initialized from the local tracking log if one exists. Otherwise it starts at the
    ``prompt`` action with an empty chat history.

    :param project_id: Burr project the application is tracked under.
    :param app_id: Identifier of the application within the project.
    :return: The (possibly cached) application instance.
    """
    local_tracker = LocalTrackingClient(project=project_id, storage_dir="~/.burr")
    builder = ApplicationBuilder().with_graph(graph)
    # Load prior state from the tracker if present; on failure, always resume from
    # the entrypoint rather than the next action.
    builder = builder.initialize_from(
        local_tracker,
        resume_at_next_action=False,
        default_state={"chat_history": []},
        default_entrypoint="prompt",
    )
    builder = builder.with_tracker(local_tracker, use_otel_tracing=opentelemetry_available)
    builder = builder.with_identifiers(app_id=app_id)
    return builder.build()
class PromptInput(pydantic.BaseModel):
    """Request body for the chat response endpoint."""

    # The user's message; passed to the Burr application as the `prompt` input.
    prompt: str
def _sse_event(payload: dict) -> str:
    """Serialize ``payload`` as JSON and frame it as one Server-Sent Events message.

    Every event is built here so all frames share the same layout: a single
    ``data: <json>`` line followed by a blank line.
    """
    return f"data: {json.dumps(payload)}\n\n"


@router.post("/response/{project_id}/{app_id}", response_class=StreamingResponse)
async def chat_response(project_id: str, app_id: str, prompt: PromptInput) -> StreamingResponse:
    """Chat response endpoint. The user passes in a prompt and the system streams
    the reply back as Server-Sent Events (SSE).

    The first event is ``{"type": "chat_history", "value": [...]}``. It carries the
    chat history as it was before this request ran, so the client can render it.
    Each following event is ``{"type": "delta", "value": ...}`` with one streamed chunk.

    :param project_id: Project ID to run
    :param app_id: Application ID to run
    :param prompt: Prompt to send to the chatbot
    :return: A ``text/event-stream`` streaming response of SSE events.
    """
    burr_app = _get_application(project_id, app_id)
    # Snapshot the history before running, so the client receives the prior context.
    prior_history = burr_app.state.get("chat_history", [])
    # The halting action itself is not needed; only the stream of results is.
    _, streaming_container = await burr_app.astream_result(
        halt_after=chat_application.TERMINAL_ACTIONS, inputs=dict(prompt=prompt.prompt)
    )

    async def sse_generator():
        """Yield SSE frames: first the prior chat history, then one frame per delta.

        Payloads are tagged with a simple ``type`` field that the client uses to parse
        them, because typing streamed payloads in FastAPI is not feasible.
        """
        yield _sse_event({"type": "chat_history", "value": prior_history})
        async for item in streaming_container:
            yield _sse_event({"type": "delta", "value": item["delta"]})

    # SSE requires the text/event-stream content type; without it, clients and
    # proxies may not treat the body as an event stream.
    return StreamingResponse(sse_generator(), media_type="text/event-stream")
@router.get("/history/{project_id}/{app_id}", response_model=List[ChatItem])
def chat_history(project_id: str, app_id: str) -> List[ChatItem]:
    """Return the chat history held in the application's current state.

    :param project_id: Project ID
    :param app_id: App ID
    :return: The chat items stored under ``chat_history``, or an empty list if absent
    """
    history = _get_application(project_id, app_id).state.get("chat_history", [])
    return history
@router.post("/create/{project_id}/{app_id}", response_model=str)
async def create_new_application(project_id: str, app_id: str) -> str:
    """Endpoint to create a new application. The frontend calls it when the user
    types in a new App ID.

    Building the application has the side effect of persisting it; see
    ``_get_application`` for how it is initialized and tracked.

    :param project_id: Project ID
    :param app_id: App ID
    :return: The app ID, unchanged
    """
    # Call positionally, exactly like the other endpoints. lru_cache keys on the
    # argument pattern, so a keyword call (especially with swapped order) would
    # create a separate cache entry and a second Application for the same app.
    _get_application(project_id, app_id)
    return app_id  # just return it for now
# # comment this back in for a standalone chatbot API
# import fastapi
#
# app = fastapi.FastAPI()
# app.include_router(router, prefix="/api/v0/chatbot")