blob: b012fd2791c17d4c7e1c8df8346fedc4c7f37a42 [file]
import functools
import logging
import sys
import time
# required if we want to run this code stand alone.
import typing
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from types import ModuleType
from typing import Any, Callable, Collection, Dict, List, Optional, Set, Tuple, Union
import pandas as pd
# Help text appended to error logs so users know where to go for support.
SLACK_ERROR_MESSAGE = (
    "-------------------------------------------------------------------\n"
    "Oh no an error! Need help with Hamilton?\n"
    "Join our slack and ask for help! https://join.slack.com/t/hamilton-opensource/shared_invite/zt-1bjs72asx-wcUTgH7q7QX1igiQ5bbdcg\n"
    "-------------------------------------------------------------------\n"
)

# required if we want to run this code stand alone: when executed as a script the
# package-relative imports are unavailable, so import the sibling modules directly.
if __name__ == "__main__":
    import base
    import graph
    import node
    import telemetry
else:
    from . import base, graph, node, telemetry

logger = logging.getLogger(__name__)
def capture_function_usage(call_fn: Callable) -> Callable:
    """Decorator that records telemetry for non-constructor, non-execute Driver functions.

    Only the function name is captured -- no information about the arguments is
    recorded at this stage. Telemetry failures are swallowed so they can never
    break the wrapped call.

    :param call_fn: the Driver function to wrap.
    :return: the wrapped function.
    """

    @functools.wraps(call_fn)
    def wrapped_fn(*args, **kwargs):
        try:
            return call_fn(*args, **kwargs)
        finally:
            # Fire the usage event whether or not the call raised.
            if telemetry.is_telemetry_enabled():
                try:
                    event = telemetry.create_driver_function_invocation_event(call_fn.__name__)
                    telemetry.send_event_json(event)
                except Exception as e:
                    # Telemetry must never surface as a user-facing failure.
                    if logger.isEnabledFor(logging.DEBUG):
                        logger.error(
                            f"Failed to send telemetry for function usage. Encountered:{e}\n"
                        )

    return wrapped_fn
@dataclass
class Variable:
    """External facing API for hamilton. Having this as a dataclass allows us
    to hide the internals of the system but expose what the user might need.
    Furthermore, we can always add attributes and maintain backwards compatibility."""

    # name of the node/variable
    name: str
    # python type the node produces
    type: typing.Type
    # tags attached to the node. BUGFIX: the default_factory was `frozenset`,
    # which produced a frozenset rather than the annotated Dict[str, str];
    # `dict` yields a proper (empty) dict default.
    tags: Dict[str, str] = field(default_factory=dict)
class Driver(object):
    """This class orchestrates creating and executing the DAG to create a dataframe.

    Construction builds the function graph from the given modules; `execute()` runs it;
    the remaining public methods provide introspection and visualization.
    """

    def __init__(
        self,
        config: Dict[str, Any],
        *modules: ModuleType,
        adapter: base.HamiltonGraphAdapter = None,
    ):
        """Constructor: creates a DAG given the configuration & modules to crawl.

        :param config: This is a dictionary of initial data & configuration.
            The contents are used to help create the DAG.
        :param modules: Python module objects you want to inspect for Hamilton Functions.
        :param adapter: Optional. A way to wire in another way of "executing" a hamilton graph.
            Defaults to using original Hamilton adapter which is single threaded in memory python.
        """
        self.driver_run_id = uuid.uuid4()  # unique id used to correlate telemetry events
        if adapter is None:
            adapter = base.SimplePythonDataFrameGraphAdapter()
        error = None
        self.graph_modules = modules
        try:
            self.graph = graph.FunctionGraph(*modules, config=config, adapter=adapter)
            self.adapter = adapter
        except Exception as e:
            error = telemetry.sanitize_error(*sys.exc_info())
            logger.error(SLACK_ERROR_MESSAGE)
            raise e
        finally:
            # fire telemetry whether construction succeeded or failed
            self.capture_constructor_telemetry(error, modules, config, adapter)

    def capture_constructor_telemetry(
        self,
        error: Optional[str],
        modules: Tuple[ModuleType],
        config: Dict[str, Any],
        adapter: base.HamiltonGraphAdapter,
    ):
        """Captures constructor telemetry.

        Notes:
        (1) we want to do this in a way that does not break.
        (2) we need to account for all possible states, e.g. someone passing in None, or assuming that
        the entire constructor code ran without issue, e.g. `adapter` was assigned to `self`.

        :param error: the sanitized error string to send.
        :param modules: the list of modules, could be None.
        :param config: the config dict passed, could be None.
        :param adapter: the adapter passed in, might not be attached to `self` yet.
        """
        if telemetry.is_telemetry_enabled():
            try:
                adapter_name = telemetry.get_adapter_name(adapter)
                result_builder = telemetry.get_result_builder_name(adapter)
                # being defensive here with ensuring values exist
                payload = telemetry.create_start_event_json(
                    len(self.graph.nodes) if hasattr(self, "graph") else 0,
                    len(modules) if modules else 0,
                    len(config) if config else 0,
                    dict(self.graph.decorator_counter) if hasattr(self, "graph") else {},
                    adapter_name,
                    result_builder,
                    self.driver_run_id,
                    error,
                )
                telemetry.send_event_json(payload)
            except Exception as e:
                # we don't want this to fail at all!
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"Error caught in processing telemetry: {e}")

    def _node_is_required_by_anything(self, node_: node.Node, node_set: Set[node.Node]) -> bool:
        """Checks dependencies on this node and determines if at least one requires it.

        Nodes can be optionally depended upon, i.e. the function parameter has a default value.
        We want to check that of the nodes that depend on this one, at least one of them requires
        it, i.e. the parameter is not optional.

        :param node_: node in question
        :param node_set: checks that we traverse only nodes in the provided set.
        :return: True if it is required by any downstream node, false otherwise
        """
        for downstream_node in node_.depended_on_by:
            # only consider dependents that are part of the graph we're validating
            if downstream_node not in node_set:
                continue
            _, dep_type = downstream_node.input_types[node_.name]
            if dep_type == node.DependencyType.REQUIRED:
                return True
        return False

    def validate_inputs(
        self,
        user_nodes: Collection[node.Node],
        inputs: typing.Optional[Dict[str, Any]] = None,
        nodes_set: Collection[node.Node] = None,
    ):
        """Validates that inputs meet our expectations. This means that:
        1. The runtime inputs don't clash with the graph's config
        2. All expected graph inputs are provided, either in config or at runtime

        :param user_nodes: The required nodes we need for computation.
        :param inputs: the user inputs provided.
        :param nodes_set: the set of nodes to use for validation; Optional.
        :raises ValueError: if any input is missing or has the wrong type.
        """
        if inputs is None:
            inputs = {}
        if nodes_set is None:
            nodes_set = set(self.graph.nodes.values())
        # FIX: straightforward assignment instead of the original superfluous
        # single-element tuple pack/unpack.
        all_inputs = graph.FunctionGraph.combine_config_and_inputs(self.graph.config, inputs)
        errors = []
        for user_node in user_nodes:
            if user_node.name not in all_inputs:
                if self._node_is_required_by_anything(user_node, nodes_set):
                    errors.append(
                        f"Error: Required input {user_node.name} not provided "
                        f"for nodes: {[n.name for n in user_node.depended_on_by]}."
                    )
            elif all_inputs[user_node.name] is not None and not self.adapter.check_input_type(
                user_node.type, all_inputs[user_node.name]
            ):
                errors.append(
                    f"Error: Type requirement mismatch. Expected {user_node.name}:{user_node.type} "
                    f"got {all_inputs[user_node.name]} instead."
                )
        if errors:
            errors.sort()
            error_str = f"{len(errors)} errors encountered:\n " + "\n ".join(errors)
            raise ValueError(error_str)

    def execute(
        self,
        final_vars: List[Union[str, Callable]],
        overrides: Dict[str, Any] = None,
        display_graph: bool = False,
        inputs: Dict[str, Any] = None,
    ) -> Any:
        """Executes computation.

        :param final_vars: the final list of outputs we want to compute.
        :param overrides: values that will override "nodes" in the DAG.
        :param display_graph: DEPRECATED. Whether we want to display the graph being computed.
        :param inputs: Runtime inputs to the DAG.
        :return: an object consisting of the variables requested, matching the type returned by the GraphAdapter.
            See constructor for how the GraphAdapter is initialized. The default one right now returns a pandas
            dataframe.
        """
        if display_graph:
            logger.warning(
                "display_graph=True is deprecated. It will be removed in the 2.0.0 release. "
                "Please use visualize_execution()."
            )
        start_time = time.time()
        run_successful = True
        error = None
        _final_vars = self._create_final_vars(final_vars)
        try:
            outputs = self.raw_execute(_final_vars, overrides, display_graph, inputs=inputs)
            result = self.adapter.build_result(**outputs)
            return result
        except Exception as e:
            run_successful = False
            logger.error(SLACK_ERROR_MESSAGE)
            error = telemetry.sanitize_error(*sys.exc_info())
            raise e
        finally:
            # always record the run, including duration, success/failure and any error
            duration = time.time() - start_time
            self.capture_execute_telemetry(
                error, _final_vars, inputs, overrides, run_successful, duration
            )

    def _create_final_vars(self, final_vars: List[Union[str, Callable]]) -> List[str]:
        """Creates the final variables list - converting functions names as required.

        :param final_vars: list of strings or functions to resolve to node names.
        :return: list of strings in the order that final_vars was provided.
        :raises ValueError: if a function is not from a driver module, or an entry is
            neither a string nor a function.
        """
        _final_vars = []
        errors = []
        module_set = {_module.__name__ for _module in self.graph_modules}
        for final_var in final_vars:
            if isinstance(final_var, str):
                _final_vars.append(final_var)
            elif isinstance(final_var, Callable):
                # functions are only valid if they came from a module given to the driver
                if final_var.__module__ in module_set:
                    _final_vars.append(final_var.__name__)
                else:
                    errors.append(
                        f"Function {final_var.__module__}.{final_var.__name__} is a function not in a "
                        f"module given to the driver. Valid choices are {module_set}."
                    )
            else:
                errors.append(f"Final var {final_var} is not a string or a function.")
        if errors:
            errors.sort()
            error_str = f"{len(errors)} errors encountered:\n " + "\n ".join(errors)
            raise ValueError(error_str)
        return _final_vars

    def capture_execute_telemetry(
        self,
        error: Optional[str],
        final_vars: List[str],
        inputs: Dict[str, Any],
        overrides: Dict[str, Any],
        run_successful: bool,
        duration: float,
    ):
        """Captures telemetry after execute has run.

        Notes:
        (1) we want to be quite defensive in not breaking anyone's code with things we do here.
        (2) thus we want to double-check that values exist before doing something with them.

        :param error: the sanitized error string to capture, if any.
        :param final_vars: the list of final variables to get.
        :param inputs: the inputs to the execute function.
        :param overrides: any overrides to the execute function.
        :param run_successful: whether this run was successful.
        :param duration: time it took to run execute.
        """
        if telemetry.is_telemetry_enabled():
            try:
                payload = telemetry.create_end_event_json(
                    run_successful,
                    duration,
                    len(final_vars) if final_vars else 0,
                    # use `dict` rather than the deprecated typing.Dict alias for isinstance
                    len(overrides) if isinstance(overrides, dict) else 0,
                    # BUGFIX: the original checked `overrides` here, so the input count
                    # was wrong whenever inputs were passed without overrides (and vice versa).
                    len(inputs) if isinstance(inputs, dict) else 0,
                    self.driver_run_id,
                    error,
                )
                telemetry.send_event_json(payload)
            except Exception as e:
                # we don't want this to fail at all!
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"Error caught in processing telemetry:\n{e}")

    def raw_execute(
        self,
        final_vars: List[str],
        overrides: Dict[str, Any] = None,
        display_graph: bool = False,
        inputs: Dict[str, Any] = None,
    ) -> Dict[str, Any]:
        """Raw execute function that does the meat of execute.

        It does not try to stitch anything together. Thus allowing wrapper executes around this
        to shape the output of the data.

        :param final_vars: Final variables to compute
        :param overrides: Overrides to run.
        :param display_graph: DEPRECATED. DO NOT USE. Whether or not to display the graph when running it
        :param inputs: Runtime inputs to the DAG
        :return: dict mapping requested variable name -> computed value.
        """
        nodes, user_nodes = self.graph.get_upstream_nodes(final_vars, inputs)
        self.validate_inputs(
            user_nodes, inputs, nodes
        )  # TODO -- validate within the function graph itself
        if display_graph:  # deprecated flow.
            logger.warning(
                "display_graph=True is deprecated. It will be removed in the 2.0.0 release. "
                "Please use visualize_execution()."
            )
            self.visualize_execution(final_vars, "test-output/execute.gv", {"view": True})
            if self.has_cycles(final_vars):  # here for backwards compatible driver behavior.
                # FIX: corrected typo in the error message ("you" -> "your").
                raise ValueError("Error: cycles detected in your graph.")
        memoized_computation = dict()  # memoized storage
        self.graph.execute(nodes, memoized_computation, overrides, inputs)
        outputs = {
            c: memoized_computation[c] for c in final_vars
        }  # only want request variables in df.
        del memoized_computation  # trying to cleanup some memory
        return outputs

    @capture_function_usage
    def list_available_variables(self) -> List[Variable]:
        """Returns available variables, i.e. outputs.

        :return: list of available variables (i.e. outputs).
        """
        # `n` rather than `node` to avoid shadowing the imported `node` module
        return [Variable(n.name, n.type, n.tags) for n in self.graph.get_nodes()]

    @capture_function_usage
    def display_all_functions(
        self, output_file_path: str, render_kwargs: dict = None, graphviz_kwargs: dict = None
    ) -> Optional["graphviz.Digraph"]:  # noqa F821
        """Displays the graph of all functions loaded!

        :param output_file_path: the full URI of path + file name to save the dot file to.
            E.g. 'some/path/graph-all.dot'
        :param render_kwargs: a dictionary of values we'll pass to graphviz render function. Defaults to viewing.
            If you do not want to view the file, pass in `{'view':False}`.
            See https://graphviz.readthedocs.io/en/stable/api.html#graphviz.Graph.render for other options.
        :param graphviz_kwargs: Optional. Kwargs to be passed to the graphviz graph object to configure it.
            E.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the produced image.
            See https://graphviz.org/doc/info/attrs.html for options.
        :return: the graphviz object if you want to do more with it.
            If returned as the result in a Jupyter Notebook cell, it will render.
        """
        try:
            return self.graph.display_all(output_file_path, render_kwargs, graphviz_kwargs)
        except ImportError as e:
            # graphviz is an optional dependency -- warn rather than crash
            logger.warning(f"Unable to import {e}", exc_info=True)

    @capture_function_usage
    def visualize_execution(
        self,
        final_vars: List[str],
        output_file_path: str,
        render_kwargs: dict,
        inputs: Dict[str, Any] = None,
        graphviz_kwargs: dict = None,
    ) -> Optional["graphviz.Digraph"]:  # noqa F821
        """Visualizes Execution.

        Note: overrides are not handled at this time.

        :param final_vars: the outputs we want to compute.
        :param output_file_path: the full URI of path + file name to save the dot file to.
            E.g. 'some/path/graph.dot'
        :param render_kwargs: a dictionary of values we'll pass to graphviz render function. Defaults to viewing.
            If you do not want to view the file, pass in `{'view':False}`.
            See https://graphviz.readthedocs.io/en/stable/api.html#graphviz.Graph.render for other options.
        :param inputs: Optional. Runtime inputs to the DAG.
        :param graphviz_kwargs: Optional. Kwargs to be passed to the graphviz graph object to configure it.
            E.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the produced image.
            See https://graphviz.org/doc/info/attrs.html for options.
        :return: the graphviz object if you want to do more with it.
            If returned as the result in a Jupyter Notebook cell, it will render.
        """
        _final_vars = self._create_final_vars(final_vars)
        nodes, user_nodes = self.graph.get_upstream_nodes(_final_vars, inputs)
        self.validate_inputs(user_nodes, inputs, nodes)
        try:
            return self.graph.display(
                nodes,
                user_nodes,
                output_file_path,
                render_kwargs=render_kwargs,
                graphviz_kwargs=graphviz_kwargs,
            )
        except ImportError as e:
            logger.warning(f"Unable to import {e}", exc_info=True)

    @capture_function_usage
    def has_cycles(self, final_vars: List[str]) -> bool:
        """Checks that the created graph does not have cycles.

        :param final_vars: the outputs we want to compute.
        :return: boolean True for cycles, False for no cycles.
        """
        _final_vars = self._create_final_vars(final_vars)
        # get graph we'd be executing over
        nodes, user_nodes = self.graph.get_upstream_nodes(_final_vars)
        return self.graph.has_cycles(nodes, user_nodes)

    @capture_function_usage
    def what_is_downstream_of(self, *node_names: str) -> List[Variable]:
        """Tells you what is downstream of this function(s), i.e. node(s).

        :param node_names: names of function(s) that are starting points for traversing the graph.
        :return: list of "variables" (i.e. nodes), inclusive of the function names, that are downstream of the passed
            in function names.
        """
        downstream_nodes = self.graph.get_impacted_nodes(list(node_names))
        return [Variable(n.name, n.type, n.tags) for n in downstream_nodes]

    @capture_function_usage
    def display_downstream_of(
        self, *node_names: str, output_file_path: str, render_kwargs: dict, graphviz_kwargs: dict
    ) -> Optional["graphviz.Digraph"]:  # noqa F821
        """Creates a visualization of the DAG starting from the passed in function name(s).

        Note: for any "node" visualized, we will also add its parents to the visualization as well, so
        there could be more nodes visualized than strictly what is downstream of the passed in function name(s).

        :param node_names: names of function(s) that are starting points for traversing the graph.
        :param output_file_path: the full URI of path + file name to save the dot file to.
            E.g. 'some/path/graph.dot'. Pass in None to skip saving any file.
        :param render_kwargs: a dictionary of values we'll pass to graphviz render function. Defaults to viewing.
            If you do not want to view the file, pass in `{'view':False}`.
        :param graphviz_kwargs: Kwargs to be passed to the graphviz graph object to configure it.
            E.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the produced image.
        :return: the graphviz object if you want to do more with it.
            If returned as the result in a Jupyter Notebook cell, it will render.
        """
        downstream_nodes = self.graph.get_impacted_nodes(list(node_names))
        try:
            return self.graph.display(
                downstream_nodes,
                set(),
                output_file_path,
                render_kwargs=render_kwargs,
                graphviz_kwargs=graphviz_kwargs,
            )
        except ImportError as e:
            logger.warning(f"Unable to import {e}", exc_info=True)

    @capture_function_usage
    def what_is_upstream_of(self, *node_names: str) -> List[Variable]:
        """Tells you what is upstream of this function(s), i.e. node(s).

        :param node_names: names of function(s) that are starting points for traversing the graph backwards.
        :return: list of "variables" (i.e. nodes), inclusive of the function names, that are upstream of the passed
            in function names.
        """
        upstream_nodes, _ = self.graph.get_upstream_nodes(list(node_names))
        return [Variable(n.name, n.type, n.tags) for n in upstream_nodes]
if __name__ == "__main__":
    # Example / smoke-test entry point: load a module named on the command line,
    # build a driver over it, and execute two example outputs.
    import importlib

    log_formatter = logging.Formatter("[%(levelname)s] %(asctime)s %(name)s(%(lineno)s): %(message)s")
    console = logging.StreamHandler(sys.stdout)
    console.setFormatter(log_formatter)
    logger.addHandler(console)
    logger.setLevel(logging.INFO)

    if len(sys.argv) < 2:
        logger.error("No modules passed")
        sys.exit(1)
    logger.info(f"Importing {sys.argv[1]}")
    module = importlib.import_module(sys.argv[1])

    # Weekly date range used as example input data.
    x = pd.date_range("2019-01-05", "2020-12-31", freq="7D")
    x.index = x

    example_config = {
        "VERSION": "kids",
        "as_of": datetime.strptime("2019-06-01", "%Y-%m-%d"),
        "end_date": "2020-12-31",
        "start_date": "2019-01-05",
        "start_date_d": datetime.strptime("2019-01-05", "%Y-%m-%d"),
        "end_date_d": datetime.strptime("2020-12-31", "%Y-%m-%d"),
        "segment_filters": {"business_line": "womens"},
    }
    dr = Driver(example_config, module)
    df = dr.execute(["date_index", "some_column"], display_graph=False)
    print(df)