blob: 50ec5db42c266d2eef6f07ac6201950491312fcf [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import asyncio
import logging
import uuid
import application as chatbot_application
import streamlit as st
from hamilton.log_setup import setup_logging
import burr.core
from burr.core.action import AsyncStreamingResultContainer
# Configure logging through Hamilton's `setup_logging` helper at INFO level.
# This is a module-level call, so it runs every time this script is executed.
setup_logging(logging.INFO)
def render_chat_message(chat_item: dict):
    """Display one chat message inside a Streamlit chat bubble.

    :param chat_item: mapping that must provide ``"content"`` (the text that is
        written into the bubble) and ``"role"`` (passed to ``st.chat_message`` to
        choose the bubble). Other keys, such as ``"type"``, are not read here.
    """
    # Read "content" before "role" so a malformed item fails on the same key
    # as before.
    message_text = chat_item["content"]
    speaker = chat_item["role"]
    with st.chat_message(speaker):
        st.write(message_text)
async def render_streaming_chat_message(stream: AsyncStreamingResultContainer):
    """Render an assistant reply incrementally while consuming ``stream``.

    Each item produced by ``stream`` is indexed with ``"delta"``, and that value
    is appended to a running string. The whole string so far is then re-written
    into one ``st.empty()`` slot inside an ``"assistant"`` chat bubble.

    :param stream: Burr async streaming container; it is fully iterated by the
        ``async for`` loop below.
    """
    accumulated = ""
    # Streamlit offers no async-generator rendering here, so instead of writing
    # only the new delta, the complete text is re-rendered on every chunk.
    # `st.empty()` is entered inside the chat bubble, as in a nested `with`.
    with st.chat_message("assistant"), st.empty():
        async for chunk in stream:
            accumulated = accumulated + chunk["delta"]
            st.write(accumulated)
def initialize_app() -> burr.core.Application:
    """Return this session's Burr application, building it on first use.

    The application is cached in ``st.session_state`` as ``burr_app``, so later
    reruns in the same session reuse the existing instance. A new instance gets
    an ``app_id`` of the form ``chat_streaming:<6 hex chars>``, taken from a
    random UUID4.

    :returns: the cached or newly created Burr application.
    """
    session = st.session_state
    if "burr_app" in session:
        return session.burr_app
    # The first 6 hex digits of a UUID4; same characters as str(uuid4())[0:6].
    short_id = uuid.uuid4().hex[:6]
    session.burr_app = chatbot_application.application(
        app_id=f"chat_streaming:{short_id}"
    )
    return session.burr_app
async def main():
    """Draw the chat page and stream an answer when a new prompt is submitted.

    Steps:
    1. Render the title and the chat input box.
    2. Replay every message in the Burr app's ``chat_history`` state.
    3. If a prompt was submitted, echo it as a user message.
    4. Start the Burr app (halting after ``TERMINAL_ACTIONS``) and stream its
       reply into an assistant bubble.
    """
    st.title("Streaming chatbot with Burr")
    burr_app = initialize_app()
    user_prompt = st.chat_input("Ask me a question!", key="chat_input")

    for past_message in burr_app.state.get("chat_history", []):
        render_chat_message(past_message)

    if not user_prompt:
        return

    # The app has not been run on this prompt yet, so echo it directly.
    render_chat_message({"role": "user", "content": user_prompt, "type": "text"})
    # The spinner covers only the wait for the streaming container itself; the
    # reply text is rendered after the spinner closes.
    with st.spinner(text="Waiting for response..."):
        _halted_action, reply_stream = await burr_app.astream_result(
            halt_after=chatbot_application.TERMINAL_ACTIONS,
            inputs={"prompt": user_prompt},
        )
    await render_streaming_chat_message(reply_stream)


if __name__ == "__main__":
    asyncio.run(main())