WIP: adds first stab at middleware plugin. TODO: 1. test it; 2. figure out how to get the span; 3. fix up serde.
diff --git a/hamilton/plugins/h_middleware.py b/hamilton/plugins/h_middleware.py new file mode 100644 index 0000000..8613728 --- /dev/null +++ b/hamilton/plugins/h_middleware.py
import logging
from typing import Any, Dict, Optional

from hamilton import lifecycle

logger = logging.getLogger(__name__)
try:
    from middleware import MwTracker
    from opentelemetry import trace

    # Instantiating the tracker wires up the middleware.io exporter so that spans
    # created through the OpenTelemetry API below actually get shipped somewhere.
    tracker = MwTracker()
except ImportError as e:
    logger.error("ImportError: %s", e)
    logger.error(
        "To use the h_middleware plugin, please install sf-hamilton[middleware] using "
        "`pip install sf-hamilton[middleware]` (or use your favorite package manager)."
        "Remember to use quotes around the package name if using zsh!"
    )
    raise

# Module-level tracer -- all spans in this plugin are created from it.
# NOTE(review): this could instead live on the adapter instance; revisit once tested.
tracer = trace.get_tracer(__name__)


class MiddlewareTracer(
    lifecycle.NodeExecutionHook, lifecycle.GraphExecutionHook, lifecycle.TaskExecutionHook
):
    """Lifecycle adapter that traces Hamilton execution with OpenTelemetry, exporting via middleware.io.

    This works with the following execution environments:
        1. Vanilla Hamilton -- no task-based computation, just nodes
        2. Task-based, synchronous
        3. Task-based with Multithreading, Ray, and Dask
    It will likely work with others, although we have not yet tested them. This does not work with async (yet).

    One span is opened per graph run, one per task, and one per node, with parent/child
    relationships mirroring the run -> task -> node hierarchy. Spans are created and ended
    explicitly (bypassing context-manager-based context propagation) so we can accurately
    track relationships between nodes/tasks even when execution hops threads or processes.

    To use this, you'll want to run `pip install sf-hamilton[middleware]`
    (or `pip install "sf-hamilton[middleware]"` if using zsh)
    """

    def __init__(
        self, root_name: str, include_causal_links: bool = False, service: Optional[str] = None
    ):
        """Creates a MiddlewareTracer. This has the option to specify some parameters.

        :param root_name: Name of the root span opened for each graph run.
        :param include_causal_links: Whether to add span links from each node's span to the spans
            of the nodes that produced its inputs. Disabled by default -- link support/rendering
            varies by backend and has been slow or broken in some clients (see e.g.
            https://github.com/DataDog/dd-trace-py/issues/8049 for the kind of edge case this hits).
        :param service: Service name -- if not provided, resolution is left to the
            exporter/environment. Currently stored but not forwarded to spans (WIP).
        """
        self.root_name = root_name
        self.service = service
        self.include_causal_links = include_causal_links
        self.run_span_cache = {}  # Cache of run_id -> root span
        # Cache of run_id -> task_id -> span. Entries are pruned after task execution.
        self.task_span_cache = {}
        # Cache of run_id -> (task_id, node_id) -> span. We use this to open/close node spans.
        self.node_span_cache = {}

    # TODO: restore serialization support for remote task execution (commit note: "fix up serde").
    # The implementation below predates the OpenTelemetry port and needs rework -- spans must not
    # be passed across process boundaries; serialize SpanContext instead.
    # @staticmethod
    # def _serialize_span_dict(span_dict: Dict[str, Span]):
    #     """Serializes to a readable format. We're not propagating span links (see note above on
    #     causal links), but that's fine (for now). We have to do this as passing spans back and
    #     forth is frowned upon.
    #
    #     :param span_dict: A key -> span dictionary
    #     :return: The serialized representation.
    #     """
    #     # For some reason this doesn't use the right ser/deser for dask
    #     # Or for some reason it has contexts instead of spans. Well, we can serialize them both!
    #     return {
    #         key: {
    #             "trace_id": span.context.trace_id if isinstance(span, Span) else span.trace_id,
    #             "span_id": span.context.span_id if isinstance(span, Span) else span.span_id,
    #         }
    #         for key, span in span_dict.items()
    #     }
    #
    # @staticmethod
    # def _deserialize_span_dict(serialized_repr: Dict[str, dict]) -> Dict[str, context.Context]:
    #     """Note that we deserialize as contexts, as passing spans is not supported
    #     (the child should never terminate the parent span).
    #
    #     :param serialized_repr: Dict of str -> dict params for contexts
    #     :return: A dictionary of contexts
    #     """
    #     return {key: context.Context(**params) for key, params in serialized_repr.items()}
    #
    # def __getstate__(self):
    #     """Gets the state for serialization"""
    #     return dict(
    #         root_trace_name=self.root_name,
    #         service=self.service,
    #         include_causal_links=self.include_causal_links,
    #         run_span_cache=self._serialize_span_dict(self.run_span_cache),
    #         task_span_cache={
    #             key: self._serialize_span_dict(value) for key, value in self.task_span_cache.items()
    #         },
    #         node_span_cache={
    #             key: self._serialize_span_dict(value) for key, value in self.task_span_cache.items()
    #         },  # Nothing here, we can just wipe it for a new task
    #     )
    #
    # def __setstate__(self, state):
    #     """Sets the state for serialization"""
    #     self.service = state["service"]
    #     self.root_name = state["root_trace_name"]
    #     self.include_causal_links = state["include_causal_links"]
    #     self.run_span_cache = self._deserialize_span_dict(state["run_span_cache"])
    #     self.task_span_cache = {
    #         key: self._deserialize_span_dict(value)
    #         for key, value in state["task_span_cache"].items()
    #     }
    #     self.node_span_cache = {
    #         key: self._deserialize_span_dict(value)
    #         for key, value in state["node_span_cache"].items()
    #     }

    @staticmethod
    def _sanitize_tags(tags: Dict[str, Any]) -> Dict[str, str]:
        """Sanitizes tags to be strings, just in case, and namespaces them under "hamilton.".

        :param tags: Node tags.
        :return: The string -> string representation of tags
        """
        return {f"hamilton.{key}": str(value) for key, value in tags.items()}

    @staticmethod
    def _end_span(span, error: Optional[Exception]) -> None:
        """Ends a span, recording the error (if any) and setting the span status accordingly.

        :param span: The span to close.
        :param error: The exception raised during the spanned work, if any.
        """
        if error is not None:
            span.record_exception(error)
            span.set_status(trace.Status(trace.StatusCode.ERROR, str(error)))
        else:
            span.set_status(trace.Status(trace.StatusCode.OK))
        span.end()

    def run_before_graph_execution(self, *, run_id: str, **future_kwargs: Any):
        """Runs before graph execution -- sets the state so future ones can reference it.

        :param run_id: ID of the run
        :param future_kwargs: reserved for future keyword arguments/backwards compatibility.
        """
        span = tracer.start_span(name=self.root_name)
        self.run_span_cache[run_id] = span  # we save this as a root span
        self.node_span_cache[run_id] = {}
        self.task_span_cache[run_id] = {}

    def run_before_node_execution(
        self,
        *,
        node_name: str,
        node_kwargs: Dict[str, Any],
        node_tags: Dict[str, Any],
        task_id: Optional[str],
        run_id: str,
        **future_kwargs: Any,
    ):
        """Runs before a node's execution. Sets up/stores spans.

        :param node_name: Name of the node.
        :param node_kwargs: Keyword arguments of the node.
        :param node_tags: Tags of the node (they'll get stored as span attributes)
        :param task_id: Task ID that spawned the node
        :param run_id: ID of the run.
        :param future_kwargs: reserved for future keyword arguments/backwards compatibility.
        """
        # Parent is the task span if we're inside a task, else the run's root span.
        parent_span = self.task_span_cache[run_id].get(task_id) or self.run_span_cache[run_id]
        new_span_name = f"{task_id}:{node_name}" if task_id is not None else node_name

        links = None
        if self.include_causal_links:
            links = []
            for input_node in node_kwargs:
                upstream_span = self.node_span_cache[run_id].get((task_id, input_node))
                if upstream_span is not None:
                    links.append(
                        trace.Link(
                            context=upstream_span.get_span_context(),
                            attributes={
                                "link.name": f"{input_node}_to_{node_name}",
                            },
                        )
                    )
        # Spans are created from the tracer; the parent span is threaded through as a Context.
        new_span = tracer.start_span(
            name=new_span_name,
            context=trace.set_span_in_context(parent_span),
            links=links,
        )

        tags = node_tags.copy()
        # Unprefixed here -- _sanitize_tags adds the "hamilton." namespace.
        tags["node_name"] = node_name
        new_span.set_attributes(MiddlewareTracer._sanitize_tags(tags=tags))
        self.node_span_cache[run_id][(task_id, node_name)] = new_span

    def run_after_node_execution(
        self,
        *,
        node_name: str,
        error: Optional[Exception],
        task_id: Optional[str],
        run_id: str,
        **future_kwargs: Any,
    ):
        """Runs after a node's execution -- completes the span.

        :param node_name: Name of the node
        :param error: Error that the node raised, if any
        :param task_id: Task ID that spawned the node
        :param run_id: ID of the run.
        :param future_kwargs: reserved for future keyword arguments/backwards compatibility.
        """
        span = self.node_span_cache[run_id][(task_id, node_name)]
        self._end_span(span, error)

    def run_after_graph_execution(
        self, *, error: Optional[Exception], run_id: str, **future_kwargs: Any
    ):
        """Runs after graph execution. Garbage collects + finishes the root span.

        :param error: Error the graph raised when running, if any
        :param run_id: ID of the run
        :param future_kwargs: reserved for future keyword arguments/backwards compatibility.
        """
        self._end_span(self.run_span_cache[run_id], error)
        del self.run_span_cache[run_id]
        del self.node_span_cache[run_id]
        del self.task_span_cache[run_id]

    def run_before_task_execution(self, *, task_id: str, run_id: str, **future_kwargs):
        """Runs before task execution. Sets up the task span.

        :param task_id: ID of the task
        :param run_id: ID of the run
        :param future_kwargs: reserved for future keyword arguments/backwards compatibility.
        """
        parent_span = self.run_span_cache[run_id]
        self.task_span_cache[run_id][task_id] = tracer.start_span(
            name=task_id,
            context=trace.set_span_in_context(parent_span),
        )

    def run_after_task_execution(
        self,
        *,
        task_id: str,
        run_id: str,
        error: Exception,
        **future_kwargs,
    ):
        """Runs after task execution. Finishes task-level spans and prunes the task cache.

        :param task_id: ID of the task
        :param run_id: ID of the run
        :param error: Error the task raised when running, if any
        :param future_kwargs: Future keyword arguments for backwards compatibility
        """
        # pop() both retrieves and prunes -- nodes of this task have already closed their spans.
        span = self.task_span_cache[run_id].pop(task_id)
        self._end_span(span, error)