WIP: prototyping proxy object spans This prototypes connecting an object (here, the OpenAI client) to Burr's tracer by wrapping it in a proxy, so that calls made on it are recorded as spans. Open questions about creating the spans: 1. When do we proxy the object? Before and after each step? (We can't do it with the step hooks alone because the post-step hook doesn't receive the inputs... but it could?) 2. Do we do it once at the beginning (which would require exposing the inputs)? But then how do we provide the tracer -- with before and after hooks? Open questions about the spans themselves: 1. We need to log what can be pulled from the proxied object -- where? How? 2. How does this then display in the UI? Otherwise, this should technically work for any object. We'd want to namespace the internal variables (e.g. with a `__burr_` prefix).
diff --git a/burr/core/application.py b/burr/core/application.py index 98091af..9930367 100644 --- a/burr/core/application.py +++ b/burr/core/application.py
@@ -525,6 +525,7 @@ sequence_id=self.sequence_id, app_id=self._uid, partition_key=self._partition_key, + _ProxyClassHook__tracer=self.dependency_factory["__tracer"], ) exc = None result = None
diff --git a/examples/client_instrument/application.py b/examples/client_instrument/application.py new file mode 100644 index 0000000..43681ae --- /dev/null +++ b/examples/client_instrument/application.py
@@ -0,0 +1,329 @@
+import copy
+import os
+from typing import Any, List, Optional
+
+import openai
+
+from burr.core import Action, Application, ApplicationBuilder, State, default, graph, when
+from burr.core.action import action
+from burr.lifecycle import LifecycleAdapter, PostRunStepHook, PreRunStepHook
+from burr.tracking import LocalTrackingClient
+from burr.visibility import ActionSpanTracer, TracerFactory
+
+# Attributes owned by ProxyClass itself (set in its __init__); they are never forwarded.
+_PROXY_OWN_ATTRIBUTES = frozenset({"_wrapped_object", "_tracer", "_active_spans", "_name"})
+
+# Attribute values of these types are returned as-is rather than proxied: they have no
+# methods worth tracing, and a proxy would break indexing, equality, isinstance checks, etc.
+# (dunder methods are looked up on the type and are not forwarded through __getattr__).
+_PASSTHROUGH_TYPES = (
+    type(None), bool, int, float, complex, str, bytes, bytearray, list, tuple, dict, set, frozenset
+)
+
+
+class ProxyClass:
+    """Forwards attribute access to a wrapped object, tracing every method call in a span.
+
+    Callable attributes are returned wrapped, so each call runs inside the span returned by
+    ``tracer(<dotted name>)``. Other attributes are wrapped in a nested ``ProxyClass`` (so that
+    chained access such as ``client.chat.completions.create`` is traced with a dotted name),
+    except values of builtin types, which are returned unchanged.
+
+    :param wrapped_object: Object whose method calls should be traced.
+    :param tracer: Callable that takes a span name and returns a context manager for the span.
+    :param name: Dotted path of this object relative to the root proxy ("" for the root).
+    """
+
+    def __init__(self, wrapped_object, tracer, name=""):
+        self._wrapped_object = wrapped_object
+        self._tracer: TracerFactory = tracer
+        # method name -> span currently open for a call to that method
+        self._active_spans = {}
+        self._name = name
+
+    def __getattr__(self, name):
+        # __getattr__ only runs when normal lookup fails, so getting here for one of the proxy's
+        # own attributes means __init__ never ran (e.g. an instance built by copy/pickle).
+        # Raise instead of reading self.<attr> again, which would recurse forever.
+        if name in _PROXY_OWN_ATTRIBUTES:
+            raise AttributeError(name)
+        attr = getattr(self._wrapped_object, name)
+        # Avoid a leading "." in span names produced by the root proxy (name="").
+        qualified_name = f"{self._name}.{name}" if self._name else name
+
+        if callable(attr):
+
+            def hooked(*args, **kwargs):
+                print(f"Calling method: {name}")
+                print(f"Arguments: {args}")
+                print(f"Keyword Arguments: {kwargs}")
+                context_manager: ActionSpanTracer = self._tracer(qualified_name)
+                result = None
+                # `with` closes the span (passing along any exception) even if the call raises.
+                with context_manager:
+                    self._active_spans[name] = context_manager
+                    try:
+                        result = attr(*args, **kwargs)
+                    finally:
+                        self._active_spans.pop(name, None)
+                print(f"Result: {result}")
+                return result
+
+            return hooked
+        if isinstance(attr, _PASSTHROUGH_TYPES):
+            return attr
+        return ProxyClass(attr, self._tracer, name=qualified_name)
+
+
+class ProxyClassHook(PreRunStepHook, PostRunStepHook):
+    """Wraps an ``openai.OpenAI`` ``client`` input in a ``ProxyClass`` that traces its calls.
+
+    NOTE(review): this relies on the mutated ``inputs`` dict being the one handed to the action,
+    and on the application passing ``_ProxyClassHook__tracer`` to ``pre_run_step`` -- confirm.
+    """
+
+    def __init__(self):
+        pass
+
+    def pre_run_step(
+        self,
+        action: Action,
+        sequence_id: int,
+        inputs: dict[str, Any],
+        _ProxyClassHook__tracer: TracerFactory,
+        **future_kwargs,
+    ):
+        client = inputs.get("client")
+        if isinstance(client, openai.OpenAI):
+            inputs["client"] = ProxyClass(client, _ProxyClassHook__tracer(action.name, sequence_id))
+        elif isinstance(client, ProxyClass):
+            # Already proxied: point the existing proxy at this step's tracer.
+            client._tracer = _ProxyClassHook__tracer(action.name, sequence_id)
+
+    def post_run_step(self, *, state: "State", action: Action, result: dict, **future_kwargs):
+        print(f"Action: {action.name}")
+
+
+MODES = {
+    "answer_question": "text",
+    "generate_image": "image",
+    "generate_code": "code",
+    "unknown": "text",
+}
+
+
+@action(reads=[], writes=["chat_history", "prompt"])
+def process_prompt(state: State, prompt: str) -> State:
+    result = {"chat_item": {"role": "user", "content": prompt, "type": "text"}}
+    return (
+        state.wipe(keep=["prompt", "chat_history"])
+        .append(chat_history=result["chat_item"])
+        .update(prompt=prompt)
+    )
+
+
+@action(reads=[], writes=["has_openai_key"])
+def check_openai_key(state: State) -> State:
+    result = {"has_openai_key": "OPENAI_API_KEY" in os.environ}
+    return state.update(**result)
+
+
+@action(reads=["prompt"], writes=["safe"])
+def check_safety(state: State) -> State:
+    result = {"safe": "unsafe" not in state["prompt"]}  # quick hack to demonstrate
+    return state.update(safe=result["safe"])
+
+
+@action(reads=["prompt"], writes=["mode"])
+def choose_mode(state: State, client: openai.OpenAI) -> State:
+    """Asks the model to pick one of ``MODES`` for the prompt, defaulting to "unknown"."""
+    prompt = (
+        f"You are a chatbot. You've been prompted this: {state['prompt']}. "
+        f"You have the capability of responding in the following modes: {', '.join(MODES)}. "
+        "Please respond with *only* a single word representing the mode that most accurately "
+        "corresponds to the prompt. For instance, if the prompt is 'draw a picture of a cat', "
+        "the mode would be 'generate_image'. If the prompt is 'what is the capital of France', "
+        "the mode would be 'answer_question'. "
+        "If none of these modes apply, please respond with 'unknown'."
+    )
+
+    result = client.chat.completions.create(
+        model="gpt-4",
+        messages=[
+            {"role": "system", "content": "You are a helpful assistant"},
+            {"role": "user", "content": prompt},
+        ],
+    )
+    content = result.choices[0].message.content
+    # The reply may be None or carry whitespace / a trailing period; normalize before matching.
+    mode = (content or "").strip().rstrip(".").lower()
+    if mode not in MODES:
+        mode = "unknown"
+    result = {"mode": mode}
+    return state.update(**result)
+
+
+@action(reads=["prompt", "chat_history"], writes=["response"])
+def prompt_for_more(state: State) -> State:
+    result = {
+        "response": {
+            "content": "None of the response modes I support apply to your question. Please clarify?",
+            "type": "text",
+            "role": "assistant",
+        }
+    }
+    return state.update(**result)
+
+
+@action(reads=["prompt", "chat_history", "mode"], writes=["response"])
+def chat_response(
+    state: State,
+    prepend_prompt: str,
+    client: openai.OpenAI,
+    display_type: str = "text",
+    model: str = "gpt-3.5-turbo",
+) -> State:
+    """Answers from the chat history, prefixing the latest message with ``prepend_prompt``.
+
+    ``display_type`` is currently unused; the response type is looked up from ``MODES``.
+    """
+    chat_history = copy.deepcopy(state["chat_history"])
+    chat_history[-1]["content"] = f"{prepend_prompt}: {chat_history[-1]['content']}"
+    chat_history_api_format = [
+        {
+            "role": chat["role"],
+            "content": chat["content"],
+        }
+        for chat in chat_history
+    ]
+    result = client.chat.completions.create(
+        model=model,
+        messages=chat_history_api_format,
+    )
+    response = result.choices[0].message.content
+    result = {"response": {"content": response, "type": MODES[state["mode"]], "role": "assistant"}}
+    return state.update(**result)
+
+
+@action(reads=["prompt", "chat_history", "mode"], writes=["response"])
+def image_response(state: State, client: openai.OpenAI, model: str = "dall-e-2") -> State:
+    """Generates an image for the prompt and stores its URL as the response content."""
+    result = client.images.generate(
+        model=model, prompt=state["prompt"], size="1024x1024", quality="standard", n=1
+    )
+    response = result.data[0].url
+    result = {"response": {"content": response, "type": MODES[state["mode"]], "role": "assistant"}}
+    return state.update(**result)
+
+
+@action(reads=["response", "mode", "safe", "has_openai_key"], writes=["chat_history"])
+def response(state: State) -> State:
+    if not state["has_openai_key"]:
+        result = {
+            "chat_item": {
+                "role": "assistant",
+                "content": "You have not set an API key for [OpenAI](https://www.openai.com). Do this "
+                "by setting the environment variable `OPENAI_API_KEY` to your key. "
+                "You can get a key at [OpenAI](https://platform.openai.com). "
+                "You can still look at chat history/examples.",
+                "type": "error",
+            }
+        }
+    elif not state["safe"]:
+        result = {
+            "chat_item": {
+                "role": "assistant",
+                "content": "I'm sorry, I can't respond to that.",
+                "type": "error",
+            }
+        }
+    else:
+        result = {"chat_item": state["response"]}
+    return state.append(chat_history=result["chat_item"])
+
+
+base_graph = (
+    graph.GraphBuilder()
+    .with_actions(
+        prompt=process_prompt,
+        check_openai_key=check_openai_key,
+        check_safety=check_safety,
+        decide_mode=choose_mode,
+        generate_image=image_response,
+        generate_code=chat_response.bind(
+            prepend_prompt="Please respond with *only* code and no other text (at all) to the following:",
+        ),
+        answer_question=chat_response.bind(
+            prepend_prompt="Please answer the following question:",
+        ),
+        prompt_for_more=prompt_for_more,
+        response=response,
+    )
+    .with_transitions(
+        ("prompt", "check_openai_key", default),
+        ("check_openai_key", "check_safety", when(has_openai_key=True)),
+        ("check_openai_key", "response", default),
+        ("check_safety", "decide_mode", when(safe=True)),
+        ("check_safety", "response", default),
+        ("decide_mode", "generate_image", when(mode="generate_image")),
+        ("decide_mode", "generate_code", when(mode="generate_code")),
+        ("decide_mode", "answer_question", when(mode="answer_question")),
+        ("decide_mode", "prompt_for_more", default),
+        (
+            ["generate_image", "answer_question", "generate_code", "prompt_for_more"],
+            "response",
+        ),
+        ("response", "prompt", default),
+    )
+    .build()
+)
+
+
+def base_application(
+    hooks: Optional[List[LifecycleAdapter]],
+    app_id: str,
+    storage_dir: str,
+    project_id: str,
+):
+    if hooks is None:
+        hooks = []
+    # we're initializing above so we can load from this as well
+    # we could also use `with_tracker("local", project=project_id, params={"storage_dir": storage_dir})`
+    tracker = LocalTrackingClient(project=project_id, storage_dir=storage_dir)
+    return (
+        ApplicationBuilder()
+        .with_graph(base_graph)
+        # initializes from the tracking log if it does not already exist
+        .initialize_from(
+            tracker,
+            resume_at_next_action=False,  # always resume from entrypoint in the case of failure
+            default_state={"chat_history": []},
+            default_entrypoint="prompt",
+        )
+        .with_hooks(*hooks)
+        .with_tracker(tracker)
+        .with_identifiers(app_id=app_id)
+        .build()
+    )
+
+
+def application(
+    app_id: Optional[str] = None,
+    project_id: str = "test_proxy",
+    storage_dir: Optional[str] = "~/.burr",
+    hooks: Optional[List[LifecycleAdapter]] = None,
+) -> Application:
+    return base_application(hooks, app_id, storage_dir, project_id=project_id)
+
+
+if __name__ == "__main__":
+    app = application(hooks=[ProxyClassHook()])
+    # app.visualize(
+    #     output_file_path="statemachine", include_conditions=False, view=True, format="png"
+    # )
+    print(
+        app.run(
+            halt_after=["response"],
+            inputs={"prompt": "Who was Aaron Burr, sir?", "client": openai.OpenAI()},
+        )
+    )