# blob: 26a9f03f33199f54c21658456eb2e88d028e1d17 [file] [log] [blame]
"""Module to house functions for an LLM agent to use."""
import logging
import arxiv_articles
import pandas as pd
import summarize_text
from hamilton import base, driver
from hamilton.execution.executors import MultiThreadingExecutor
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
def get_articles(query: str) -> pd.DataFrame:
    """Use this function to get academic papers from arXiv to answer user questions.

    Builds a Hamilton driver over the ``arxiv_articles`` module, writes a PNG
    rendering of the dataflow to ``./get_articles``, and then executes the
    ``arxiv_result_df`` and ``save_arxiv_result_df`` nodes.

    :param query: User query in JSON. Responses should be summarized and should include the article URL reference
    :return: Whatever the driver returns for the two requested nodes, as
        assembled by ``SimplePythonDataFrameGraphAdapter`` (annotated here as a
        pandas DataFrame). Presumably describes each article by title, summary,
        article_url and pdf_url -- TODO confirm against ``arxiv_articles``.
    """
    # Nodes requested from the dataflow on execution.
    requested_outputs = ["arxiv_result_df", "save_arxiv_result_df"]

    # Runtime inputs handed to the dataflow; the user's query is forwarded
    # unchanged as ``article_query``.
    run_inputs = dict(
        embedding_model_name="text-embedding-ada-002",
        max_arxiv_results=5,
        article_query=query,
        data_dir="./data",
        library_file_path="./data/arxiv_library.csv",
    )

    # Assemble the driver one builder step at a time.
    builder = driver.Builder()
    builder = builder.enable_dynamic_execution(allow_experimental_mode=True)
    builder = builder.with_modules(arxiv_articles)
    builder = builder.with_config({"mock_openai": True})
    builder = builder.with_remote_executor(MultiThreadingExecutor(max_tasks=10))
    builder = builder.with_adapter(base.SimplePythonDataFrameGraphAdapter())
    article_driver = builder.build()

    # Render the dataflow graph before running it (done on every call).
    article_driver.display_all_functions("./get_articles", {"format": "png"})

    return article_driver.execute(requested_outputs, inputs=run_inputs)
def read_article_and_summarize(query: str) -> str:
    """Use this function to read whole papers and provide a summary for users.
    You should NEVER call this function before get_articles has been called in the conversation.

    Builds a Hamilton driver over the ``summarize_text`` module, writes a PNG
    rendering of the dataflow to ``./read_article_and_summarize``, and then
    executes the ``summarize_text`` node. The library file path passed in is
    the same ``./data/arxiv_library.csv`` that ``get_articles`` passes to its
    own dataflow.

    :param query: Description of the article in plain text based on the user's query.
    :return: Summarized text of the article given the query.
    """
    # The single node requested from the dataflow; also the key used to pull
    # the value out of the result returned by the driver.
    target_node = "summarize_text"

    # Runtime inputs handed to the dataflow; the query is forwarded unchanged
    # as ``user_query``.
    run_inputs = dict(
        embedding_model_name="text-embedding-ada-002",
        openai_gpt_model="gpt-3.5-turbo-0613",
        user_query=query,
        top_n=1,
        max_token_length=1500,
        library_file_path="./data/arxiv_library.csv",
    )

    # Empty config, single module, default result adapter.
    summarizer = driver.Driver({}, summarize_text, adapter=base.DefaultAdapter())

    # Render the dataflow graph before running it (done on every call).
    summarizer.display_all_functions("./read_article_and_summarize", {"format": "png"})

    outputs = summarizer.execute([target_node], inputs=run_inputs)
    return outputs[target_node]
if __name__ == "__main__":
    # Code to quickly integration test: fetch articles for a sample query,
    # then summarize one of them, printing both results.
    from hamilton import log_setup

    # Turn on DEBUG-level logging through Hamilton's logging helper.
    debug_level = log_setup.LOG_LEVELS["DEBUG"]
    log_setup.setup_logging(log_level=debug_level)

    # Step 1: retrieve articles.
    _df = get_articles("ppo reinforcement learning")
    print(_df)

    # Step 2: summarize an article for a follow-up query.
    _summary = read_article_and_summarize(
        "PPO reinforcement learning sequence generation"
    )
    print(_summary)