blob: a36bcad551d2b7aced79c9a8ce51b3db0ad39e81 [file] [log] [blame]
"""
Module that contains code to house state for an agent. The dialog
right now is hardcoded at the bottom of this file.
"""
import json
import logging
import sys
import functions
import openai
from tenacity import retry, stop_after_attempt, wait_random_exponential
from termcolor import colored
# Module-level logger; the script entry point at the bottom of this file
# configures logging via logging.basicConfig.
logger = logging.getLogger(__name__)
# Model name passed to the ChatCompletion requests made in this module.
GPT_MODEL = "gpt-3.5-turbo-0613"
# Function specifications offered to the model through the `functions`
# argument of the ChatCompletion call. Each entry carries the function name,
# a description for the model, and a JSON-Schema-style description of the
# arguments. The names here must match the branches handled in
# call_arxiv_function.
# TODO: create a programmatic way to generate this from the function definitions themselves.
arxiv_functions = [
# Search function: takes one required string argument, "query".
{
"name": "get_articles",
"description": """Use this function to get academic papers from arXiv to answer user questions.""",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": """
User query in JSON. Responses should be summarized and should include the article URL reference
""",
}
},
"required": ["query"],
},
},
# Read-and-summarize function: also takes one required string argument,
# "query". Its description tells the model not to call it before
# get_articles has been called in the conversation.
{
"name": "read_article_and_summarize",
"description": """Use this function to read whole papers and provide a summary for users.
You should NEVER call this function before get_articles has been called in the conversation.""",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": """
Description of the article in plain text based on the user's query
""",
}
},
"required": ["query"],
},
},
]
class Conversation:
    """Ordered chat transcript with a colorized console printer.

    Messages are stored as ``{"role": ..., "content": ...}`` dicts in
    ``conversation_history``, in the order they were added.
    """

    def __init__(self):
        self.conversation_history = []

    def add_message(self, role, content):
        """Append one message with the given role and content to the history."""
        self.conversation_history.append({"role": role, "content": content})

    def display_conversation(self):
        """Print every stored message as ``role: content``, colored by its role.

        Raises:
            KeyError: if a message's role has no color assigned in the palette.
        """
        palette = dict(
            system="red",
            user="green",
            assistant="blue",
            function="magenta",
        )
        for entry in self.conversation_history:
            speaker = entry["role"]
            rendered = f"{speaker}: {entry['content']}\n\n"
            print(colored(rendered, palette[speaker]))
@retry(
    wait=wait_random_exponential(min=1, max=40),
    stop=stop_after_attempt(3),
    reraise=True,
)
def chat_completion_request(messages, openai_gpt_model: str, functions=None, temperature: float = 0):
    """Make a ChatCompletion API call, optionally offering functions to the model.

    The call is attempted up to three times with randomized exponential
    back-off between attempts.

    Args:
        messages: list of chat message dicts to send.
        openai_gpt_model: name of the chat model to use.
        functions: optional list of function specifications; only forwarded
            to the API when non-empty.
        temperature: sampling temperature forwarded to the API.

    Returns:
        The response object returned by ``openai.ChatCompletion.create``.

    Raises:
        Exception: whatever the final failed attempt raised.
    """
    kwargs = dict(model=openai_gpt_model, messages=messages, temperature=temperature)
    if functions:
        kwargs["functions"] = functions
    try:
        return openai.ChatCompletion.create(**kwargs)
    except Exception as e:
        logger.error("Unable to generate ChatCompletion response")
        logger.error("Exception: %s", e)
        # Re-raise rather than return the exception: the retry decorator only
        # retries when the call raises, and callers index into the result as
        # if it were a response.
        raise
def chat_completion_with_function_execution(messages, functions=None):
    """Request a chat completion and execute the function call the model asks for, if any.

    Args:
        messages: conversation so far as a list of message dicts. It is sent
            to the API and, when a function call is requested, handed on to
            call_arxiv_function, which may append to it.
        functions: optional list of function specifications to offer the model.

    Returns:
        The assistant's message content when the model answers directly,
        otherwise whatever call_arxiv_function returns.

    Raises:
        Exception: the request error, if the completion request failed.
    """
    response = chat_completion_request(messages, openai_gpt_model=GPT_MODEL, functions=functions)
    if isinstance(response, Exception):
        # A failed request may be handed back as an exception object instead
        # of being raised; surface the real error rather than failing on the
        # subscript below.
        raise response
    choice = response["choices"][0]
    if choice["finish_reason"] == "function_call":
        logger.info("Function generation requested, calling function")
        return call_arxiv_function(messages, choice)
    logger.info("Function not required, responding to user")
    return choice["message"]["content"]
def call_arxiv_function(messages, full_message):
    """Execute the function call requested by the model and return its result.

    Supported functions are ``get_articles`` and ``read_article_and_summarize``;
    new ones are added as further branches. This code is pulled directly from
    the cookbook -- it is not well structured.

    Args:
        messages: conversation history list. For ``get_articles`` a
            ``function`` role message holding the search results is appended
            to it in place.
        full_message: the completion choice whose ``message`` contains the
            ``function_call`` (``name`` plus JSON-encoded ``arguments``).

    Returns:
        For ``get_articles``: the response of a follow-up chat completion made
        after the search results were appended. For
        ``read_article_and_summarize``: the value returned by
        ``functions.read_article_and_summarize``.

    Raises:
        Exception: if the requested function is unknown, if running
            ``get_articles`` fails (original error re-raised), or if the
            follow-up chat request fails ("Function chat request failed").
    """
    function_call = full_message["message"]["function_call"]
    function_name = function_call["name"]

    if function_name == "get_articles":
        raw_arguments = function_call["arguments"]
        try:
            parsed_output = json.loads(raw_arguments)
            logger.info("Getting search results")
            # `functions` is the imported project module here, not a parameter.
            results = functions.get_articles(parsed_output["query"])
        except Exception as e:
            # Log the raw arguments: the parsed form does not exist yet when
            # json.loads itself is what failed.
            logger.error(raw_arguments)
            logger.error("Function execution failed")
            logger.error("Error message: %s", e)
            raise
        messages.append(
            {
                "role": "function",
                "name": function_name,
                "content": str(results),
            }
        )
        try:
            logger.info("Got search results, summarizing content")
            response = chat_completion_request(messages, openai_gpt_model=GPT_MODEL)
        except Exception as e:
            logger.error(type(e))
            raise Exception("Function chat request failed") from e
        if isinstance(response, Exception):
            # A failed request may be handed back as an exception object
            # rather than raised; treat it the same as a raised failure.
            logger.error(type(response))
            raise Exception("Function chat request failed") from response
        return response

    if function_name == "read_article_and_summarize":
        parsed_output = json.loads(function_call["arguments"])
        logger.info("Finding and reading paper")
        return functions.read_article_and_summarize(parsed_output["query"])

    raise Exception("Function does not exist and cannot be called")
if __name__ == "__main__":
    # Example dialog interactions, run only when this file is executed as a script.
    logging.basicConfig(level=logging.INFO, stream=sys.stdout)
    paper_system_message = """You are arXivGPT, a helpful assistant pulls academic papers to answer user questions.
You summarize the papers clearly so the customer can decide which to read to answer their question.
You always provide the article_url and title so the user can understand the name of the paper and click through to access it.
Begin!"""
    paper_conversation = Conversation()
    paper_conversation.add_message("system", paper_system_message)

    # Hardcoded user turns, sent one after another on the same conversation.
    user_prompts = (
        "Hi, how does PPO reinforcement learning work?",
        "Can you read the PPO sequence generation paper for me and give me a summary",
        "Can you read the A2C PPO paper for me and give me a summary",
    )
    for user_prompt in user_prompts:
        paper_conversation.add_message("user", user_prompt)
        chat_response = chat_completion_with_function_execution(
            paper_conversation.conversation_history, functions=arxiv_functions
        )
        # The result is either a raw API response object or already a plain
        # value; only the assistant's content is stored and printed, so the
        # history never ends up holding a whole response object.
        if isinstance(chat_response, openai.openai_object.OpenAIObject):
            assistant_message = chat_response["choices"][0]["message"]["content"]
        else:
            assistant_message = chat_response
        paper_conversation.add_message("assistant", assistant_message)
        print(assistant_message)

    print("Conversation history:")
    paper_conversation.display_conversation()