import functools
import logging
import sys
import time as py_time
import traceback
from contextlib import contextmanager
from datetime import datetime, timezone
from typing import Any, Callable, Dict, List, Optional
from hamilton_sdk.tracking import stats
from hamilton_sdk.tracking.trackingtypes import DAGRun, Status, TaskRun
from hamilton import node as h_node
from hamilton.data_quality import base as dq_base
from hamilton.lifecycle import base as lifecycle_base
try:
from hamilton_sdk.tracking import numpy_stats # noqa: F401
from hamilton_sdk.tracking import pandas_stats # noqa: F401
except ImportError:
pass
try:
from hamilton_sdk.tracking import polars_stats # noqa: F401
except ImportError:
pass
logger = logging.getLogger(__name__)
def process_result(result: Any, node: h_node.Node) -> Any:
"""Processes result -- this is purely a by-type mapping.
Note that this doesn't actually do anything yet -- the idea is that we can return DQ
results, and do other stuff with other results -- E.G. summary stats on dataframes,
pass small strings through, etc...
The return type is left as Any for now, but we should probably make it a union of
types that we support.
Note this should keep the cardinality of the output as low as possible.
These results will be used on the FE to display results, and we don't want
to crowd out storage.
:param result: The result of the node's execution
:param node: The node that produced the result
:return: The processed result - it has to be JSON serializable!
"""
try:
start = py_time.time()
statistics = stats.compute_stats(result, node.name, node.tags)
end = py_time.time()
logger.debug(f"Took {end - start} seconds to describe {node.name}")
return statistics
# TODO: introspect other nodes
# if it's a check_output node, then we want to process the pandera result/the result from it.
    except Exception as e:
        logger.warning(f"Failed to introspect result for {node.name}. Error:\n{e}")
        return None  # fall back to no summary if introspection fails
class TrackingState:
"""Mutable class that tracks data"""
def __init__(self, run_id: str):
"""Initializes the tracking state"""
self.status = Status.UNINITIALIZED
self.start_time = None
self.end_time = None
self.run_id = run_id
self.task_map: Dict[str, TaskRun] = {}
self.update_status(Status.UNINITIALIZED)
def clock_start(self):
"""Called at start of run"""
logger.info("Clocked beginning of run")
self.status = Status.RUNNING
self.start_time = datetime.now(timezone.utc)
def clock_end(self, status: Status):
"""Called at end of run"""
logger.info(f"Clocked end of run with status: {status}")
self.end_time = datetime.now(timezone.utc)
self.status = status
def update_task(self, task_name: str, task_run: TaskRun):
"""Updates a task"""
self.task_map.update({task_name: task_run})
logger.debug(f"Updating task: {task_name} with data: {task_run}")
def update_status(self, status: Status):
"""Updates the status of the run"""
self.status = status
logger.info(f"Updating run status with value: {status}")
def get(self) -> DAGRun:
"""Gives the final result as a DAG run"""
return DAGRun(
run_id=self.run_id,
status=self.status,
# TODO -- think about using a json dumper and referring to this as a status
tasks=list(self.task_map.values()),
start_time=self.start_time,
end_time=self.end_time,
schema_version=0,
)
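
# A minimal sketch of the intended lifecycle, driving TrackingState by hand
# (in practice, the RunTracker below does this). The run ID is hypothetical;
# this function exists purely as an example and is never called.
def _example_tracking_state_usage() -> DAGRun:
    state = TrackingState(run_id="example-run")
    state.clock_start()
    task = TaskRun(node_name="my_node")
    task.status = Status.SUCCESS
    state.update_task("my_node", task)
    state.clock_end(status=Status.SUCCESS)
    return state.get()
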
def serialize_error() -> List[str]:
"""Serialize an error to a string.
Note we should probably have this *unparsed*, so we can display in the UI,
but its OK for now to just have the string.
*note* this has to be called from within an except block.
:param error:
:return:
"""
exc_type, exc_value, exc_tb = sys.exc_info()
return traceback.format_exception(exc_type, exc_value, exc_tb)
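
# Usage sketch: serialize_error() reads sys.exc_info(), so it only returns
# something meaningful when called from inside an except block, as shown here.
def _example_serialize_error() -> List[str]:
    try:
        raise ValueError("boom")  # stand-in for a failing node
    except ValueError:
        return serialize_error()
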
def serialize_data_quality_error(e: dq_base.DataValidationError) -> List[str]:
"""Santizes data quality errors to make them more readable for the platform.
Note: this is hacky code.
:param e: Data quality error to inspect
:return: list of failures.
"""
validation_failures = e.args[0]
sanitized_failures = []
for failure in validation_failures:
if "pandera_schema_validator" in failure: # hack to know what type of validator.
sanitized_failures.append(failure.split("Usage Tip")[0]) # remove usage tip
else:
sanitized_failures.append(failure)
return sanitized_failures
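
# A hedged sketch of the sanitization behavior above: a pandera-style failure
# string has its "Usage Tip" suffix stripped, while other failures pass through
# unchanged. The failure strings here are fabricated for illustration.
def _example_sanitize_dq_error() -> List[str]:
    failures = [
        "pandera_schema_validator: column 'a' failed a check. Usage Tip: ...",
        "some_other_validator: value out of range",
    ]
    return serialize_data_quality_error(dq_base.DataValidationError(failures))
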
class RunTracker:
"""This class allows you to track results of runs"""
def __init__(self, tracking_state: TrackingState):
"""Tracks runs given run IDs. Note that this needs to be re-initialized
on each run, we'll want to fix that.
:param result_builder: Result builder to use
:param run_id: Run ID to save with
"""
self.tracking_state = tracking_state
def execute_node(
self,
original_do_node_execute: Callable,
run_id: str,
node_: h_node.Node,
kwargs: Dict[str, Any],
task_id: Optional[str],
) -> Any:
"""Given a node that represents a hamilton function, execute it.
Note, in some adapters this might just return some type of "future".
:param node_: the Hamilton Node
:param kwargs: the kwargs required to exercise the node function.
:return: the result of exercising the node.
:param original_execute_node: The original adapter's callable
"""
logger.debug(f"Executing node: {node_.name}")
        # If the tracking state hasn't started yet, mark the run as running.
if self.tracking_state.status == Status.UNINITIALIZED:
self.tracking_state.update_status(Status.RUNNING)
        task_run = TaskRun(node_name=node_.name)
task_run.status = Status.RUNNING
task_run.start_time = datetime.now(timezone.utc)
self.tracking_state.update_task(node_.name, task_run)
try:
result = original_do_node_execute(run_id, node_, kwargs, task_id)
task_run.status = Status.SUCCESS
task_run.result_type = type(result)
            task_run.result_summary = process_result(result, node_)  # summarize the result for display
task_run.end_time = datetime.now(timezone.utc)
self.tracking_state.update_task(node_.name, task_run)
logger.debug(f"Node: {node_.name} ran successfully")
return result
except dq_base.DataValidationError as e:
task_run.status = Status.FAILURE
task_run.end_time = datetime.now(timezone.utc)
task_run.error = serialize_data_quality_error(e)
self.tracking_state.update_status(Status.FAILURE)
self.tracking_state.update_task(node_.name, task_run)
logger.debug(f"Node: {node_.name} encountered data quality issue...")
raise e
except Exception as e:
task_run.status = Status.FAILURE
task_run.end_time = datetime.now(timezone.utc)
task_run.error = serialize_error()
self.tracking_state.update_status(Status.FAILURE)
self.tracking_state.update_task(node_.name, task_run)
logger.debug(f"Node: {node_.name} failed to run...")
raise e
@contextmanager
def monkey_patch_adapter(
adapter: lifecycle_base.LifecycleAdapterSet, tracking_state: TrackingState
):
"""Monkey patches the graph adapter to track the results o fthe run
:param adapter: Adapter to modify the execute_node functionality
:param tracking_state: State of the DAG -- used for tracking
"""
    (adapter_to_patch,) = [
        item for item in adapter.adapters if hasattr(item, "do_node_execute")
    ]  # exactly one adapter is expected to implement do_node_execute
original_do_node_execute = adapter_to_patch.do_node_execute
try:
run_tracker = RunTracker(tracking_state=tracking_state)
# monkey patch the adapter
adapter_to_patch.do_node_execute = functools.partial(
run_tracker.execute_node, original_do_node_execute=original_do_node_execute
)
yield
finally:
adapter_to_patch.do_node_execute = original_do_node_execute
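
# A minimal usage sketch, assuming a LifecycleAdapterSet in which exactly one
# adapter implements do_node_execute (the DAG execution itself is elided).
# This function exists purely as an example and is never called.
def _example_monkey_patch(adapter: lifecycle_base.LifecycleAdapterSet) -> DAGRun:
    tracking_state = TrackingState(run_id="example-run")  # hypothetical run ID
    tracking_state.clock_start()
    with monkey_patch_adapter(adapter, tracking_state):
        pass  # execute the DAG here; node results/errors get recorded
    tracking_state.clock_end(status=Status.SUCCESS)
    return tracking_state.get()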