class ProxyClass:
    """Recursively wraps an object so every method call is traced as a span.

    Method lookups are intercepted in ``__getattr__``: callables come back
    wrapped in a closure that opens a span via the tracer factory before the
    call and closes it afterwards (including when the call raises).
    Non-callable attributes are returned wrapped in a new ``ProxyClass`` so
    nested clients (e.g. ``client.chat.completions``) are traced as well.
    """

    def __init__(self, wrapped_object, tracer, name=""):
        self._wrapped_object = wrapped_object
        # Quoted annotation: annotations on attribute targets are evaluated at
        # call time, and TracerFactory is imported *after* this class in this
        # file -- quoting removes the import-order dependency.
        self._tracer: "TracerFactory" = tracer
        self._active_spans = {}  # method name -> currently open span
        self._name = name  # dotted path from the root proxy, used in span names

    def __getattr__(self, name):
        # __getattr__ only fires when normal lookup fails. The attributes set
        # in __init__ live in the instance dict, so reaching here with one of
        # them means it was never set: raise AttributeError instead of
        # re-reading self._x, which would re-enter __getattr__ forever.
        if name in ("_wrapped_object", "_tracer", "_active_spans", "_name"):
            raise AttributeError(name)
        attr = getattr(self._wrapped_object, name)

        if callable(attr):

            def hooked(*args, **kwargs):
                print(f"Calling method: {name}")
                print(f"Arguments: {args}")
                print(f"Keyword Arguments: {kwargs}")
                span: "ActionSpanTracer" = self._tracer(f"{self._name}.{name}")
                span.__enter__()
                self._active_spans[name] = span
                try:
                    result = attr(*args, **kwargs)
                except BaseException as e:
                    # Close the span on failure too -- previously it leaked.
                    self._active_spans.pop(name).__exit__(type(e), e, e.__traceback__)
                    raise
                self._active_spans.pop(name).__exit__(None, None, None)
                print(f"Result: {result}")
                return result

            return hooked
        # isinstance(attr, object) is always True in Python, so the original
        # else branch was unreachable: every non-callable gets proxied.
        return ProxyClass(attr, self._tracer, name=f"{self._name}.{name}")
| |
| |
| import copy |
| import os |
| from typing import Any, List, Optional |
| |
| import openai |
| |
| from burr.core import Action, Application, ApplicationBuilder, State, default, graph, when |
| from burr.core.action import action |
| from burr.lifecycle import LifecycleAdapter, PostRunStepHook, PreRunStepHook |
| from burr.tracking import LocalTrackingClient |
| from burr.visibility import ActionSpanTracer, TracerFactory |
| |
| |
class ProxyClassHook(PreRunStepHook, PostRunStepHook):
    """Lifecycle hook that swaps an OpenAI client input for a tracing proxy.

    NOTE(review): the parameter is spelled with its mangled name
    ``_ProxyClassHook__tracer`` -- presumably to receive the tracer keyword
    Burr supplies to hooks; verify against Burr's hook invocation.
    """

    def __init__(self):
        pass

    def pre_run_step(
        self,
        action: Action,
        sequence_id: int,
        inputs: dict[str, Any],
        _ProxyClassHook__tracer: TracerFactory,
        **future_kwargs,
    ):
        # Wrap a raw OpenAI client in a ProxyClass, or refresh the tracer on
        # an already-proxied client so spans attach to the current step.
        client = inputs.get("client")
        if isinstance(client, openai.OpenAI):
            inputs["client"] = ProxyClass(
                client, _ProxyClassHook__tracer(action.name, sequence_id)
            )
        elif isinstance(client, ProxyClass):
            client._tracer = _ProxyClassHook__tracer(action.name, sequence_id)

    def post_run_step(self, *, state: "State", action: Action, result: dict, **future_kwargs):
        # Log which action just finished.
        print(f"Action: {action.name}")
| |
| |
# Maps each routing mode (as chosen by the decide_mode action) to the content
# type recorded on the chat item that is appended to the chat history.
MODES = dict(
    answer_question="text",
    generate_image="image",
    generate_code="code",
    unknown="text",
)
| |
| |
@action(reads=[], writes=["chat_history", "prompt"])
def process_prompt(state: State, prompt: str) -> State:
    """Record the user's prompt as a text chat item and stash the raw prompt."""
    chat_item = {"role": "user", "content": prompt, "type": "text"}
    cleared = state.wipe(keep=["prompt", "chat_history"])
    return cleared.append(chat_history=chat_item).update(prompt=prompt)
| |
| |
@action(reads=[], writes=["has_openai_key"])
def check_openai_key(state: State) -> State:
    """Flag whether an OpenAI API key is present in the environment."""
    has_key = "OPENAI_API_KEY" in os.environ
    return state.update(has_openai_key=has_key)
| |
| |
@action(reads=["prompt"], writes=["safe"])
def check_safety(state: State) -> State:
    """Naive safety gate: a prompt is 'safe' unless it contains 'unsafe'."""
    is_safe = "unsafe" not in state["prompt"]  # quick hack to demonstrate
    return state.update(safe=is_safe)
| |
| |
| # def _get_openai_client(): |
| # return openai.Client() |
| |
| |
@action(reads=["prompt"], writes=["mode"])
def choose_mode(state: State, client: openai.OpenAI) -> State:
    """Ask the LLM to classify the prompt into one of the supported MODES.

    Falls back to "unknown" when the model's answer is not an exact mode name.

    :param state: must contain "prompt".
    :param client: OpenAI client (possibly a traced ProxyClass stand-in).
    """
    prompt = (
        f"You are a chatbot. You've been prompted this: {state['prompt']}. "
        f"You have the capability of responding in the following modes: {', '.join(MODES)}. "
        "Please respond with *only* a single word representing the mode that most accurately "
        # fixed typo: "Fr instance" -> "For instance"
        "corresponds to the prompt. For instance, if the prompt is 'draw a picture of a cat', "
        "the mode would be 'generate_image'. If the prompt is 'what is the capital of France', the mode would be 'answer_question'."
        "If none of these modes apply, please respond with 'unknown'."
    )

    result = client.chat.completions.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": prompt},
        ],
    )
    content = result.choices[0].message.content
    # Strip as well as lowercase: a trailing newline/space from the model
    # would otherwise force a valid answer into "unknown".
    mode = content.strip().lower()
    if mode not in MODES:
        mode = "unknown"
    return state.update(mode=mode)
| |
| |
@action(reads=["prompt", "chat_history"], writes=["response"])
def prompt_for_more(state: State) -> State:
    """Fallback response asking the user to clarify when no mode matches."""
    fallback = {
        "content": "None of the response modes I support apply to your question. Please clarify?",
        "type": "text",
        "role": "assistant",
    }
    return state.update(response=fallback)
| |
| |
@action(reads=["prompt", "chat_history", "mode"], writes=["response"])
def chat_response(
    state: State,
    prepend_prompt: str,
    client: openai.OpenAI,
    display_type: str = "text",  # NOTE(review): unused, kept for interface compatibility
    model: str = "gpt-3.5-turbo",
) -> State:
    """Answer the latest prompt via a chat completion, prefixing an instruction.

    The instruction in ``prepend_prompt`` is prepended to the most recent
    chat-history entry (on a deep copy, so stored state is untouched).
    """
    history = copy.deepcopy(state["chat_history"])
    history[-1]["content"] = f"{prepend_prompt}: {history[-1]['content']}"
    api_messages = [{"role": item["role"], "content": item["content"]} for item in history]
    completion = client.chat.completions.create(
        model=model,
        messages=api_messages,
    )
    content = completion.choices[0].message.content
    payload = {"content": content, "type": MODES[state["mode"]], "role": "assistant"}
    return state.update(response=payload)
| |
| |
@action(reads=["prompt", "chat_history", "mode"], writes=["response"])
def image_response(state: State, client: openai.OpenAI, model: str = "dall-e-2") -> State:
    """Generate an image for the prompt and store its URL as the response."""
    generation = client.images.generate(
        model=model, prompt=state["prompt"], size="1024x1024", quality="standard", n=1
    )
    image_url = generation.data[0].url
    payload = {"content": image_url, "type": MODES[state["mode"]], "role": "assistant"}
    return state.update(response=payload)
| |
| |
@action(reads=["response", "mode", "safe", "has_openai_key"], writes=["chat_history"])
def response(state: State) -> State:
    """Append the final assistant chat item, substituting an error item when
    the API key is missing or the prompt was flagged unsafe."""
    if not state["has_openai_key"]:
        chat_item = {
            "role": "assistant",
            "content": "You have not set an API key for [OpenAI](https://www.openai.com). Do this "
            "by setting the environment variable `OPENAI_API_KEY` to your key. "
            "You can get a key at [OpenAI](https://platform.openai.com). "
            "You can still look at chat history/examples.",
            "type": "error",
        }
    elif not state["safe"]:
        chat_item = {
            "role": "assistant",
            "content": "I'm sorry, I can't respond to that.",
            "type": "error",
        }
    else:
        chat_item = state["response"]
    return state.append(chat_history=chat_item)
| |
| |
# The shared state-machine graph: prompt -> key check -> safety check -> mode
# routing -> one of the response generators -> final response, then back to
# prompt for the next conversational turn.
base_graph = (
    graph.GraphBuilder()
    .with_actions(
        prompt=process_prompt,
        check_openai_key=check_openai_key,
        check_safety=check_safety,
        decide_mode=choose_mode,
        generate_image=image_response,
        # chat_response is bound twice with different instruction prefixes to
        # serve as both the code generator and the question answerer.
        generate_code=chat_response.bind(
            prepend_prompt="Please respond with *only* code and no other text (at all) to the following:",
        ),
        answer_question=chat_response.bind(
            prepend_prompt="Please answer the following question:",
        ),
        prompt_for_more=prompt_for_more,
        response=response,
    )
    .with_transitions(
        ("prompt", "check_openai_key", default),
        # Missing key or unsafe prompt short-circuits straight to `response`,
        # which emits the corresponding error chat item.
        ("check_openai_key", "check_safety", when(has_openai_key=True)),
        ("check_openai_key", "response", default),
        ("check_safety", "decide_mode", when(safe=True)),
        ("check_safety", "response", default),
        ("decide_mode", "generate_image", when(mode="generate_image")),
        ("decide_mode", "generate_code", when(mode="generate_code")),
        ("decide_mode", "answer_question", when(mode="answer_question")),
        ("decide_mode", "prompt_for_more", default),
        # All generators funnel into the single `response` action.
        (
            ["generate_image", "answer_question", "generate_code", "prompt_for_more"],
            "response",
        ),
        ("response", "prompt", default),
    )
    .build()
)
| |
| |
def base_application(
    hooks: Optional[List[LifecycleAdapter]],
    app_id: Optional[str],
    storage_dir: Optional[str],
    project_id: str,
) -> Application:
    """Assemble the Burr application: shared graph, local tracker, and hooks.

    :param hooks: lifecycle hooks to register; None is treated as no hooks.
    :param app_id: identifier for this application instance; None lets the
        framework assign one -- TODO confirm against Burr's builder docs.
    :param storage_dir: directory for the local tracking client's logs.
    :param project_id: project name used by the tracking client.
    :return: the built Application.
    """
    if hooks is None:
        hooks = []
    # we're initializing above so we can load from this as well
    # we could also use `with_tracker("local", project=project_id, params={"storage_dir": storage_dir})`
    tracker = LocalTrackingClient(project=project_id, storage_dir=storage_dir)
    return (
        ApplicationBuilder()
        .with_graph(base_graph)
        # initializes from the tracking log if it does not already exist
        .initialize_from(
            tracker,
            resume_at_next_action=False,  # always resume from entrypoint in the case of failure
            default_state={"chat_history": []},
            default_entrypoint="prompt",
        )
        .with_hooks(*hooks)
        .with_tracker(tracker)
        .with_identifiers(app_id=app_id)
        .build()
    )
| |
| |
def application(
    app_id: Optional[str] = None,
    project_id: str = "test_proxy",
    storage_dir: Optional[str] = "~/.burr",
    hooks: Optional[List[LifecycleAdapter]] = None,
) -> Application:
    """Build the demo application with sensible defaults for this project."""
    return base_application(
        hooks,
        app_id,
        storage_dir,
        project_id=project_id,
    )
| |
| |
if __name__ == "__main__":
    # Demo entry point: build the app with the tracing hook and run a single
    # conversational turn against a live OpenAI client.
    app = application(hooks=[ProxyClassHook()])
    # app.visualize(
    #     output_file_path="statemachine", include_conditions=False, view=True, format="png"
    # )
    print(
        app.run(
            halt_after=["response"],  # stop once the final response action has run
            inputs={"prompt": "Who was Aaron Burr, sir?", "client": openai.OpenAI()},
        )
    )