# blob: e451ea7e761eb8fcd2c248ebc4e2798a4072d94a [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import time
from copy import deepcopy
from typing import Annotated
import google.generativeai as genai
import instructor
from dotenv import load_dotenv
from pydantic import AfterValidator, BaseModel, Field
from burr.core import Application, ApplicationBuilder, State, expr
from burr.core.action import action
from burr.tracking import LocalTrackingClient
# Load the GOOGLE_API_KEY from your .env file before the Gemini client is created.
load_dotenv()
# Gemini model wrapped by Instructor; `creator` calls `ask_gemini.create(...)`
# with `response_model=Topic` to get structured output.
ask_gemini = instructor.from_gemini(
    client=genai.GenerativeModel(model_name="gemini-1.5-flash-latest")
)
# Word-count range requested in the system prompt for topic, subtopic and concept names.
MIN_WORDS = 3
MAX_WORDS = 5
# Default number of topics to generate for an outline.
NUM_REQUIRED_TOPICS = 3
# Bounds on the number of subtopics per topic (used by the `Topic` model and the system prompt).
MIN_SUBTOPICS = 2
MAX_SUBTOPICS = 4
# Bounds on the number of concepts per subtopic (used by the `Subtopic` model and the system prompt).
MIN_CONCEPTS = 2
MAX_CONCEPTS = 4
# Default value of `creator`'s `attempts` parameter, forwarded to Instructor as `max_retries`.
ATTEMPTS = 3
# Temperature passed to the model on every `creator` call.
TEMPERATURE = 0.3
class Subtopic(BaseModel):
    # One subtopic of a topic: a title-cased name plus a bounded list of
    # title-cased concepts.
    # NOTE(review): documented with `#` comments rather than a class docstring
    # because pydantic includes a model's docstring in its JSON schema, which
    # would presumably change the schema Instructor builds from this model.

    # The name is normalised to title case after validation.
    name: Annotated[str, AfterValidator(str.title)]
    # pydantic enforces MIN_CONCEPTS <= len(concepts) <= MAX_CONCEPTS; when the
    # constraint is violated the LLM is asked to generate the concepts again.
    # Every concept is title-cased after validation.
    concepts: Annotated[
        list[str],
        AfterValidator(lambda raw_concepts: [concept.title() for concept in raw_concepts]),
    ] = Field(
        description=f"{MIN_CONCEPTS}-{MAX_CONCEPTS} concepts covered in the subtopic.",
        min_length=MIN_CONCEPTS,
        max_length=MAX_CONCEPTS,
    )
class Topic(BaseModel):
    # A generated course topic: a title-cased name plus a bounded list of
    # `Subtopic` entries.
    # NOTE(review): documented with `#` comments rather than a class docstring
    # because pydantic includes a model's docstring in its JSON schema, which
    # would presumably change the schema Instructor builds from this model.

    # The name is normalised to title case after validation.
    name: Annotated[str, AfterValidator(str.title)]
    # pydantic enforces MIN_SUBTOPICS <= len(subtopics) <= MAX_SUBTOPICS; when
    # the constraint is violated the LLM is asked to generate the subtopics again.
    subtopics: list[Subtopic] = Field(
        description=f"{MIN_SUBTOPICS}-{MAX_SUBTOPICS} ordered subtopics with concepts.",
        min_length=MIN_SUBTOPICS,
        max_length=MAX_SUBTOPICS,
    )

    def __str__(self) -> str:
        """
        Render the topic as plain text: a TOPIC line, then for each subtopic a
        numbered SUBTOPIC line followed by its numbered CONCEPT lines.
        Every line, including the last, ends with a newline.
        """
        parts = [f"TOPIC: {self.name}\n"]
        for subtopic_number, subtopic in enumerate(self.subtopics, start=1):
            parts.append(f"SUBTOPIC {subtopic_number}: {subtopic.name}\n")
            # concept numbering restarts at 1 inside every subtopic
            parts.extend(
                f"CONCEPT {concept_number}: {concept}\n"
                for concept_number, concept in enumerate(subtopic.concepts, start=1)
            )
        return "".join(parts)
def topics_system_template(
    num_required_topics: int = NUM_REQUIRED_TOPICS,
    min_words: int = MIN_WORDS,
    max_words: int = MAX_WORDS,
    min_subtopics: int = MIN_SUBTOPICS,
    max_subtopics: int = MAX_SUBTOPICS,
    min_concepts: int = MIN_CONCEPTS,
    max_concepts: int = MAX_CONCEPTS,
) -> str:
    """
    Builds the system prompt that tells the model how to turn a course outline into topics.
    Parameters:
    - num_required_topics: how many topics the model is told to generate.
    - min_words / max_words: word-count range requested for topic, subtopic and concept names.
    - min_subtopics / max_subtopics: range of subtopics requested per topic.
    - min_concepts / max_concepts: range of concepts requested per subtopic.
    Returns:
    - str: the prompt text, with no leading or trailing newline.
    """
    # the same "min-max" ranges appear several times in the prompt, so format them once
    word_range = f"{min_words}-{max_words}"
    subtopic_range = f"{min_subtopics}-{max_subtopics}"
    concept_range = f"{min_concepts}-{max_concepts}"
    # the backslashes after the opening quotes and before the closing quotes
    # suppress the first and last newline of the literal
    system_prompt = f"""\
"You are a world class course instructor."
You'll be given a course outline and you have to generate {num_required_topics} topics.
For each topic:
1. Generate a {word_range} word topic name that encapsulates the description.
2. Generate {subtopic_range} subtopics for the topic. Also {word_range} words each.
For each subtopic:
Generate {concept_range} concepts. Also {word_range} words each. The concepts should be related to the subtopic.
Think of concepts as the smallest unit of knowledge that can be taught from the subtopic. And add a verb to the concept to make it actionable.
For example:
"Calculate Derivatives" instead of "Derivatives".
"Identify Finite Sets" instead of "Finite Sets".
"Find the y-intercept" instead of "y-intercept".
The subtopics and concepts should be in the correct order.\
"""
    return system_prompt
def creation_template(
    outline: str,
    topics_so_far: list[Topic],
    num_required_topics: int = NUM_REQUIRED_TOPICS,
) -> str:
    """
    Builds the user prompt that asks the model for the next topic.
    The prompt is made of tagged sections separated by blank lines: the outline,
    the number of required topics, the topics accepted so far (only when there
    are any), and the final instruction.
    Parameters:
    - outline: the course outline text.
    - topics_so_far: topics already accepted; each is rendered with `str(topic)`.
    - num_required_topics: total number of topics to generate, shown as "i/total".
    Returns:
    - str: the assembled prompt.
    """
    sections = [
        f"<outline>\n{outline}\n</outline>",
        f"<num_required_topics>\n{num_required_topics}\n</num_required_topics>",
    ]
    if topics_so_far:
        numbered_topics = "".join(
            f"{position}/{num_required_topics}\n{topic}\n"
            for position, topic in enumerate(topics_so_far, start=1)
        )
        # strip the trailing newlines left by the last topic before closing the tag
        opened_section = "<topics_so_far>\n" + numbered_topics
        sections.append(opened_section.strip() + "\n</topics_so_far>")
    sections.append("Generate the next topic.")
    return "\n\n".join(sections)
@action(reads=[], writes=["outline", "num_required_topics", "chat_history"])
def setup(state: State, outline: str, num_required_topics: int) -> State:
    """
    Seeds the application state with the course outline, the number of topics to
    generate, and a chat history containing only the system prompt.
    Parameters:
    - state (State): The current state of the application.
    - outline (str): The course outline to generate topics from.
    - num_required_topics (int): How many topics should be generated.
    Returns:
    - State: The updated state.
    """
    system_message = {
        "role": "system",
        "content": topics_system_template(num_required_topics=num_required_topics),
    }
    return state.update(
        outline=outline,
        num_required_topics=num_required_topics,
        chat_history=[system_message],
    )
@action(
    reads=[
        "chat_history",
        "outline",
        "num_required_topics",
        "topics_so_far",
        "topic_feedback",
    ],
    writes=["generated_topic", "chat_history"],
)
def creator(
    state: State,
    attempts: int = ATTEMPTS,
) -> tuple[dict, State]:
    """
    Asks the model for a topic through the create function of the Instructor library.
    https://python.useinstructor.com/#using-gemini
    If the state holds feedback on the previous topic, the feedback is sent as the
    user message; otherwise a fresh "next topic" prompt is built from the outline
    and the topics generated so far.
    Parameters:
    - state (State): The current state of the application.
    - attempts (int, optional): Passed to Instructor as `max_retries`. 3 by default. If the call still raises, the error is printed and 'generated_topic' is set to None.
    Returns:
    - tuple[dict, State]: The result dict holding the generated topic (or None) and the updated state. On success the user message and the generated topic are also appended to 'chat_history'.
    Instructor Parameters Used:
    - response_model: The Pydantic model we want the response to be. Here it is the `Topic` model.
    - messages: The chat history plus the new user message.
    - max_retries: The number of times we want to retry if the model fails to generate a response. 1 means just the original attempt, 2 means the original attempt and one retry.
    - temperature: The temperature of the model. 0.3 by default.
    """
    # work on a copy so the history stored in the state is never mutated
    conversation = deepcopy(state["chat_history"])
    feedback = state.get("topic_feedback", "")
    if feedback:
        # the previously generated topic is already the last assistant message,
        # so the feedback simply follows it as the next user message
        prompt = f"User feedback: {feedback.strip()}"
    else:
        # no previous topic, or the previous one was accepted: ask for the next topic
        prompt = creation_template(
            outline=state["outline"],
            topics_so_far=state.get("topics_so_far", []),
            num_required_topics=state.get("num_required_topics", NUM_REQUIRED_TOPICS),
        )
    conversation.append({"role": "user", "content": prompt})
    try:
        topic = ask_gemini.create(  # type: ignore
            response_model=Topic,
            messages=conversation,
            max_retries=attempts,
            temperature=TEMPERATURE,
        )
    except Exception as exc:
        print(f"ERROR in creator: {exc}")
        # leave the chat history untouched; the graph routes back to creator
        # whenever generated_topic is None
        return {"generated_topic": None}, state.update(generated_topic=None)
    # record both sides of the exchange in the chat history
    next_state = (
        state.update(generated_topic=topic)
        .append(chat_history={"role": "user", "content": prompt})
        .append(chat_history={"role": "assistant", "content": str(topic)})
    )
    return {"generated_topic": topic}, next_state
@action(
    reads=["generated_topic", "num_required_topics", "topics_so_far"],
    writes=["topic_feedback"],
)
def get_topic_feedback(state: State) -> tuple[dict, State]:
    """
    Asks the user for feedback on the generated topic.
    If the user is happy with the generated topic, they can leave the feedback empty.
    Empty or whitespace-only input is stored as None, which means "accepted".
    Any other input is stored stripped and is later sent to the model by
    `creator` as a user message.
    Parameters:
    - state (State): The current state of the application.
    Returns:
    - tuple[dict, State]: The result dict holding 'topic_feedback' and the updated state.
    """
    # NOTE(review): the reason for this pause is not visible in this file;
    # presumably it lets earlier output settle before the prompt — confirm.
    time.sleep(2)
    num_required_topics = state.get("num_required_topics", NUM_REQUIRED_TOPICS)
    num_topics_so_far = len(state.get("topics_so_far", []))
    generated_topic = state["generated_topic"]
    # shown as "<position>/<total>" followed by the rendered topic
    generated_topic_str = f"{num_topics_so_far + 1}/{num_required_topics}\n{generated_topic}"
    raw_feedback = input(
        f"Give your feedback on this generated topic. Leave empty if it's perfect.\n\n{generated_topic_str}"
    )
    # strip first: whitespace-only input is truthy and would otherwise be treated
    # as real feedback, sending an empty "User feedback: " message to the model
    feedback = raw_feedback.strip() or None
    return {"topic_feedback": feedback}, state.update(topic_feedback=feedback)
@action(reads=["generated_topic"], writes=["topics_so_far"])
def update_topics_so_far(state: State) -> State:
    """
    Appends the generated topic to the list of topics generated so far.
    This runs when the user gave no feedback, which is taken to mean the topic
    was accepted. The list is included in the prompt for the next topic.
    """
    accepted_topic = state["generated_topic"]
    return state.append(topics_so_far=accepted_topic)
@action(reads=["topics_so_far"], writes=[])
def terminal(state: State) -> tuple[dict, State]:
    """
    Final action of the application, reached once the required number of topics
    has been generated. Returns the generated topics as the result and leaves
    the state unchanged.
    """
    result = {"topics_so_far": state["topics_so_far"]}
    return result, state
def application(
    app_id: str | None = None,
    username: str | None = None,
    project: str = "topics_creator",
) -> Application:
    """
    Builds the topic-creation Burr application.
    The graph is: setup -> creator -> get_topic_feedback -> update_topics_so_far -> terminal,
    with loops back to creator when generation failed, when the user gave feedback,
    or when more topics are still required.
    Parameters:
    - app_id (str | None): Application identifier passed to `with_identifiers`.
    - username (str | None): Passed to `with_identifiers` as the partition key.
    - project (str): Project name used for the local tracker. "topics_creator" by default.
    Returns:
    - Application: The built application. State is initialized from the local tracker
      via `initialize_from`, with `setup` as the default entrypoint and an empty default state.
    """
    tracker = LocalTrackingClient(project=project)
    builder = (
        ApplicationBuilder()
        .with_actions(setup, creator, get_topic_feedback, update_topics_so_far, terminal)
        .with_transitions(
            ("setup", "creator"),
            # go to creator again if the generated topic is None
            ("creator", "creator", expr("generated_topic is None")),  # type: ignore
            # get feedback if the generated topic is not None
            ("creator", "get_topic_feedback"),
            # if the feedback is empty, add the generated topic to the list of topics generated so far
            (
                "get_topic_feedback",
                "update_topics_so_far",
                expr("topic_feedback is None"),  # type: ignore
            ),
            # if the feedback is not empty, go back to creator and add the feedback as a user message
            ("get_topic_feedback", "creator"),
            # once at least the required number of topics has been generated, go to terminal
            # (>= rather than ==, so the loop still ends if num_required_topics is below 1
            # or the count ever overshoots)
            (
                "update_topics_so_far",
                "terminal",
                expr("len(topics_so_far) >= num_required_topics"),  # type: ignore
            ),
            # otherwise, go back to creator after updating the topics generated so far
            ("update_topics_so_far", "creator"),
        )
        .with_tracker("local", project=project)
        .with_identifiers(app_id=app_id, partition_key=username)  # type: ignore
        .initialize_from(
            tracker,
            resume_at_next_action=True,
            default_entrypoint="setup",
            default_state={},
        )
    )
    return builder.build()
if __name__ == "__main__":
    # When run as a script, build the application and write a PNG diagram of
    # its state machine to "statemachine", showing transition conditions but
    # not state fields.
    app = application()
    app.visualize(
        format="png",
        output_file_path="statemachine",
        include_conditions=True,
        include_state=False,
    )