| import datetime |
| import hashlib |
| import logging |
| import os |
| import random |
| import traceback |
| from datetime import timezone |
| from types import ModuleType |
| from typing import Any, Callable, Dict, List, Optional, Union |
| |
| from hamilton_sdk import driver |
| from hamilton_sdk.api import clients, constants |
| from hamilton_sdk.tracking import runs |
| from hamilton_sdk.tracking.runs import Status, TrackingState |
| from hamilton_sdk.tracking.trackingtypes import TaskRun |
| |
| from hamilton import graph as h_graph |
| from hamilton import node |
| from hamilton.data_quality import base as dq_base |
| from hamilton.lifecycle import base |
| |
| logger = logging.getLogger(__name__) |
| |
| |
| def get_node_name(node_: node.Node, task_id: Optional[str]) -> str: |
| if task_id is not None: |
| return f"{task_id}-{node_.name}" |
| return node_.name |
| |
| |
| LONG_SCALE = float(0xFFFFFFFFFFFFFFF) |
| |
| |
| class HamiltonTracker( |
| base.BasePostGraphConstruct, |
| base.BasePreGraphExecute, |
| base.BasePreNodeExecute, |
| base.BasePostNodeExecute, |
| base.BasePostGraphExecute, |
| ): |
| def __init__( |
| self, |
| project_id: Union[str, int], |
| username: str, |
| dag_name: str, |
| tags: Dict[str, str] = None, |
| client_factory: Callable[ |
| [str, str, str], clients.HamiltonClient |
| ] = clients.BasicSynchronousHamiltonClient, |
| api_key: str = os.environ.get( |
| "HAMILTON_API_KEY", |
| ), |
| hamilton_api_url=os.environ.get("HAMILTON_API_URL", constants.HAMILTON_API_URL), |
| hamilton_ui_url=os.environ.get("HAMILTON_UI_URL", constants.HAMILTON_UI_URL), |
| create_project_if_not_exists: bool = False, |
| ): |
| """This hooks into Hamilton execution to track DAG runs in Hamilton UI. |
| |
| :param project_id: the ID of the project |
| :param username: the username for the API key. |
| :param dag_name: the name of the DAG. |
| :param tags: any tags to help curate and organize the DAG |
| :param client_factory: a factory to create the client to phone Hamilton with. |
| :param api_key: the API key to use. See us if you want to use this. |
| :param hamilton_api_url: API endpoint. |
| :param hamilton_ui_url: UI Endpoint. |
| """ |
| self.project_id = project_id |
| self.api_key = api_key |
| self.username = username |
| self.client = client_factory(api_key, username, hamilton_api_url) |
| self.initialized = False |
| self.project_version = None |
| self.base_tags = tags if tags is not None else {} |
| driver.validate_tags(self.base_tags) |
| self.dag_name = dag_name |
| self.hamilton_ui_url = hamilton_ui_url |
| logger.debug("Validating authentication against Hamilton BE API...") |
| self.client.validate_auth() |
| logger.debug(f"Ensuring project {self.project_id} exists...") |
| try: |
| self.client.project_exists(self.project_id) |
| except clients.UnauthorizedException: |
| logger.exception( |
| f"Authentication failed. Please check your username and try again. " |
| f"Username: {self.username}..." |
| ) |
| raise |
| except clients.ResourceDoesNotExistException: |
| logger.error( |
| f"Project {self.project_id} does not exist/is accessible. Please create it first in the UI! " |
| f"You can do so at {self.hamilton_ui_url}/dashboard/projects" |
| ) |
| raise |
| self.dag_template_id_cache = {} |
| self.tracking_states = {} |
| self.dw_run_ids = {} |
| self.task_runs = {} |
| super().__init__() |
| # set this to a float to sample blocks. 0.1 means 10% of blocks will be sampled. |
| # set this to an int to sample blocks by modulo. |
| self.special_parallel_sample_strategy = None |
| # set this to some constant value if you want to generate the same sample each time. |
| # if you're using a float value. |
| self.seed = None |
| |
| def post_graph_construct( |
| self, graph: h_graph.FunctionGraph, modules: List[ModuleType], config: Dict[str, Any] |
| ): |
| """Registers the DAG to get an ID.""" |
| if self.seed is None: |
| self.seed = random.random() |
| logger.debug("post_graph_construct") |
| fg_id = id(graph) |
| if fg_id in self.dag_template_id_cache: |
| logger.warning("Skipping creation of DAG template as it already exists.") |
| return |
| module_hash = driver._get_modules_hash(modules) |
| vcs_info = driver._derive_version_control_info(module_hash) |
| dag_hash = driver.hash_dag(graph) |
| code_hash = driver.hash_dag_modules(graph, modules) |
| dag_template_id = self.client.register_dag_template_if_not_exists( |
| project_id=self.project_id, |
| dag_hash=dag_hash, |
| code_hash=code_hash, |
| name=self.dag_name, |
| nodes=driver._extract_node_templates_from_function_graph(graph), |
| code_artifacts=driver.extract_code_artifacts_from_function_graph( |
| graph, vcs_info, vcs_info.local_repo_base_path |
| ), |
| config=graph.config, |
| tags=self.base_tags, |
| code=driver._slurp_code(graph, vcs_info.local_repo_base_path), |
| vcs_info=vcs_info, |
| ) |
| self.dag_template_id_cache[fg_id] = dag_template_id |
| |
| def pre_graph_execute( |
| self, |
| run_id: str, |
| graph: h_graph.FunctionGraph, |
| final_vars: List[str], |
| inputs: Dict[str, Any], |
| overrides: Dict[str, Any], |
| ): |
| """Creates a DAG run.""" |
| logger.debug("pre_graph_execute %s", run_id) |
| fg_id = id(graph) |
| if fg_id in self.dag_template_id_cache: |
| dag_template_id = self.dag_template_id_cache[fg_id] |
| else: |
| raise ValueError("DAG template ID not found in cache. This should never happen.") |
| tracking_state = TrackingState(run_id) |
| self.tracking_states[run_id] = tracking_state # cache |
| tracking_state.clock_start() |
| dw_run_id = self.client.create_and_start_dag_run( |
| dag_template_id=dag_template_id, |
| tags=self.base_tags, |
| inputs=inputs if inputs is not None else {}, |
| outputs=final_vars, |
| ) |
| self.dw_run_ids[run_id] = dw_run_id |
| self.task_runs[run_id] = {} |
| logger.warning( |
| f"\nCapturing execution run. Results can be found at " |
| f"{self.hamilton_ui_url}/dashboard/project/{self.project_id}/runs/{dw_run_id}\n" |
| ) |
| |
| def pre_node_execute( |
| self, run_id: str, node_: node.Node, kwargs: Dict[str, Any], task_id: Optional[str] = None |
| ): |
| """Captures start of node execution.""" |
| logger.debug("pre_node_execute %s %s", run_id, task_id) |
| tracking_state = self.tracking_states[run_id] |
| if tracking_state.status == Status.UNINITIALIZED: # not thread safe? |
| tracking_state.update_status(Status.RUNNING) |
| |
| in_sample = self.is_in_sample(task_id) |
| task_run = TaskRun(node_name=node_.name, is_in_sample=in_sample) |
| task_run.status = Status.RUNNING |
| task_run.start_time = datetime.datetime.now(timezone.utc) |
| tracking_state.update_task(node_.name, task_run) |
| self.task_runs[run_id][node_.name] = task_run |
| |
| task_update = dict( |
| node_template_name=node_.name, |
| node_name=get_node_name(node_, task_id), |
| realized_dependencies=[dep.name for dep in node_.dependencies], |
| status=task_run.status, |
| start_time=task_run.start_time, |
| end_time=None, |
| ) |
| # we need a 1-1 mapping of updates for the sample stuff to work. |
| self.client.update_tasks( |
| self.dw_run_ids[run_id], |
| attributes=[None], |
| task_updates=[task_update], |
| in_samples=[task_run.is_in_sample], |
| ) |
| |
| def get_hash(self, block_value: int): |
| """Creates a deterministic hash.""" |
| full_salt = "%s.%s%s" % (self.seed, "DAGWORKS", ".") |
| hash_str = "%s%s" % (full_salt, str(block_value)) |
| hash_str = hash_str.encode("ascii") |
| return int(hashlib.sha1(hash_str).hexdigest()[:15], 16) |
| |
| def get_deterministic_random(self, block_value: int): |
| """Gets a random number between 0 & 1 given the block value.""" |
| zero_to_one = self.get_hash(block_value) / LONG_SCALE |
| return zero_to_one # should be between 0 and 1 |
| |
| def is_in_sample(self, task_id: str) -> bool: |
| """Determines if what we're tracking is considered in sample. |
| |
| This should only be used at the node level right now and is intended |
| for parallel blocks that could be quick large. |
| """ |
| if ( |
| self.special_parallel_sample_strategy is not None |
| and task_id is not None |
| and task_id.startswith("expand-") |
| and "block" in task_id |
| ): |
| in_sample = False |
| block_id = int(task_id.split(".")[1]) |
| if isinstance(self.special_parallel_sample_strategy, float): |
| # if it's a float we want to sample blocks |
| if self.get_deterministic_random(block_id) < self.special_parallel_sample_strategy: |
| in_sample = True |
| elif isinstance(self.special_parallel_sample_strategy, int): |
| # if it's an int we want to take the modulo of the block id so all the |
| # nodes for a block will be captured or not. |
| if block_id % self.special_parallel_sample_strategy == 0: |
| in_sample = True |
| else: |
| raise ValueError( |
| f"Unknown special_parallel_sample_strategy: " |
| f"{self.special_parallel_sample_strategy}" |
| ) |
| else: |
| in_sample = True |
| return in_sample |
| |
| def post_node_execute( |
| self, |
| run_id: str, |
| node_: node.Node, |
| kwargs: Dict[str, Any], |
| success: bool, |
| error: Optional[Exception], |
| result: Optional[Any], |
| task_id: Optional[str] = None, |
| ): |
| """Captures end of node execution.""" |
| logger.debug("post_node_execute %s %s", run_id, task_id) |
| task_run: TaskRun = self.task_runs[run_id][node_.name] |
| tracking_state = self.tracking_states[run_id] |
| |
| other_results = [] |
| if success: |
| task_run.status = Status.SUCCESS |
| task_run.result_type = type(result) |
| result_summary = runs.process_result(result, node_) |
| if result_summary is None: |
| result_summary = { |
| "observability_type": "observability_failure", |
| "observability_schema_version": "0.0.3", |
| "observability_value": { |
| "type": str(str), |
| "value": "Failed to process result.", |
| }, |
| } |
| # NOTE This is a temporary hack to make process_result() able to return |
| # more than one object that will be used as UI "task attributes". |
| # There's a conflict between `TaskRun.result_summary` that expect a single |
| # dict from process_result() and the `HamiltonTracker.post_node_execute()` |
| # that can more freely handle "stats" to create multiple "task attributes" |
| elif isinstance(result_summary, dict): |
| result_summary = result_summary |
| elif isinstance(result_summary, list): |
| other_results = [obj for obj in result_summary[1:]] |
| result_summary = result_summary[0] |
| else: |
| raise TypeError("`process_result()` needs to return a dict or sequence of dict") |
| |
| task_run.result_summary = result_summary |
| task_attr = dict( |
| node_name=get_node_name(node_, task_id), |
| name="result_summary", |
| type=task_run.result_summary["observability_type"], |
| # 0.0.3 -> 3 |
| schema_version=int( |
| task_run.result_summary["observability_schema_version"].split(".")[-1] |
| ), |
| value=task_run.result_summary["observability_value"], |
| attribute_role="result_summary", |
| ) |
| |
| else: |
| task_run.status = Status.FAILURE |
| task_run.is_in_sample = True # override any sampling |
| if isinstance(error, dq_base.DataValidationError): |
| task_run.error = runs.serialize_data_quality_error(error) |
| else: |
| task_run.error = traceback.format_exception(type(error), error, error.__traceback__) |
| task_attr = dict( |
| node_name=get_node_name(node_, task_id), |
| name="stack_trace", |
| type="error", |
| schema_version=1, |
| value={ |
| "stack_trace": task_run.error, |
| }, |
| attribute_role="error", |
| ) |
| |
| # `result_summary` or "error" is first because the order influences UI display order |
| attributes = [task_attr] |
| for i, other_result in enumerate(other_results): |
| other_attr = dict( |
| node_name=get_node_name(node_, task_id), |
| name=other_result.get("name", f"Attribute {i+1}"), # retrieve name if specified |
| type=other_result["observability_type"], |
| # 0.0.3 -> 3 |
| schema_version=int(other_result["observability_schema_version"].split(".")[-1]), |
| value=other_result["observability_value"], |
| attribute_role="result_summary", |
| ) |
| attributes.append(other_attr) |
| |
| task_run.end_time = datetime.datetime.now(timezone.utc) |
| tracking_state.update_task(node_.name, task_run) |
| task_update = dict( |
| node_template_name=node_.name, |
| node_name=get_node_name(node_, task_id), |
| realized_dependencies=[dep.name for dep in node_.dependencies], |
| status=task_run.status, |
| start_time=task_run.start_time, |
| end_time=task_run.end_time, |
| ) |
| self.client.update_tasks( |
| self.dw_run_ids[run_id], |
| attributes=attributes, |
| task_updates=[task_update for _ in attributes], |
| in_samples=[task_run.is_in_sample for _ in attributes], |
| ) |
| |
| def post_graph_execute( |
| self, |
| run_id: str, |
| graph: h_graph.FunctionGraph, |
| success: bool, |
| error: Optional[Exception], |
| results: Optional[Dict[str, Any]], |
| ): |
| """Captures end of DAG execution.""" |
| logger.debug("post_graph_execute %s", run_id) |
| dw_run_id = self.dw_run_ids[run_id] |
| tracking_state = self.tracking_states[run_id] |
| tracking_state.clock_end(status=Status.SUCCESS if success else Status.FAILURE) |
| finally_block_time = datetime.datetime.utcnow() |
| if tracking_state.status != Status.SUCCESS: |
| # TODO: figure out how to handle crtl+c stuff |
| # -- we are at the mercy of Hamilton here. |
| tracking_state.status = Status.FAILURE |
| # this assumes the task map only has things that have been processed, not |
| # nodes that have yet to be computed. |
| for task_name, task_run in tracking_state.task_map.items(): |
| if task_run.status != Status.SUCCESS: |
| task_run.status = Status.FAILURE |
| task_run.end_time = finally_block_time |
| if task_run.error is None: # we likely aborted it. |
| # Note if we start to do concurrent execution we'll likely |
| # need to adjust this. |
| task_run.error = ["Run was likely aborted."] |
| if task_run.end_time is None and task_run.status == Status.SUCCESS: |
| task_run.end_time = finally_block_time |
| |
| self.client.log_dag_run_end( |
| dag_run_id=dw_run_id, |
| status=tracking_state.status.value, |
| ) |
| logger.warning( |
| f"\nCaptured execution run. Results can be found at " |
| f"{self.hamilton_ui_url}/dashboard/project/{self.project_id}/runs/{dw_run_id}\n" |
| ) |
| |
| |
| class AsyncHamiltonTracker( |
| base.BasePostGraphConstructAsync, |
| base.BasePreGraphExecuteAsync, |
| base.BasePreNodeExecuteAsync, |
| base.BasePostNodeExecuteAsync, |
| base.BasePostGraphExecuteAsync, |
| ): |
| def __init__( |
| self, |
| project_id: int, |
| username: str, |
| dag_name: str, |
| tags: Dict[str, str] = None, |
| client_factory: Callable[ |
| [str, str, str], clients.BasicAsynchronousHamiltonClient |
| ] = clients.BasicAsynchronousHamiltonClient, |
| api_key: str = os.environ.get("HAMILTON_API_KEY", ""), |
| hamilton_api_url=os.environ.get("HAMILTON_API_URL", constants.HAMILTON_API_URL), |
| hamilton_ui_url=os.environ.get("HAMILTON_UI_URL", constants.HAMILTON_UI_URL), |
| ): |
| self.project_id = project_id |
| self.api_key = api_key |
| self.username = username |
| self.client = client_factory(api_key, username, hamilton_api_url) |
| self.initialized = False |
| self.project_version = None |
| self.base_tags = tags if tags is not None else {} |
| driver.validate_tags(self.base_tags) |
| self.dag_name = dag_name |
| self.hamilton_ui_url = hamilton_ui_url |
| self.dag_template_id_cache = {} |
| self.tracking_states = {} |
| self.dw_run_ids = {} |
| self.task_runs = {} |
| self.initialized = False |
| super().__init__() |
| |
| async def ainit(self): |
| if self.initialized: |
| return self |
| """You must call this to initialize the tracker.""" |
| logger.info("Validating authentication against Hamilton BE API...") |
| await self.client.validate_auth() |
| logger.info(f"Ensuring project {self.project_id} exists...") |
| try: |
| await self.client.project_exists(self.project_id) |
| except clients.UnauthorizedException: |
| logger.exception( |
| f"Authentication failed. Please check your username and try again. " |
| f"Username: {self.username}" |
| ) |
| raise |
| except clients.ResourceDoesNotExistException: |
| logger.error( |
| f"Project {self.project_id} does not exist/is accessible. Please create it first in the UI! " |
| f"You can do so at {self.hamilton_ui_url}/dashboard/projects" |
| ) |
| raise |
| logger.info("Initializing Hamilton tracker.") |
| await self.client.ainit() |
| logger.info("Initialized Hamilton tracker.") |
| self.initialized = True |
| return self |
| |
| async def post_graph_construct( |
| self, graph: h_graph.FunctionGraph, modules: List[ModuleType], config: Dict[str, Any] |
| ): |
| logger.debug("post_graph_construct") |
| fg_id = id(graph) |
| if fg_id in self.dag_template_id_cache: |
| logger.warning("Skipping creation of DAG template as it already exists.") |
| return |
| module_hash = driver._get_modules_hash(modules) |
| vcs_info = driver._derive_version_control_info(module_hash) |
| dag_hash = driver.hash_dag(graph) |
| code_hash = driver.hash_dag_modules(graph, modules) |
| dag_template_id = await self.client.register_dag_template_if_not_exists( |
| project_id=self.project_id, |
| dag_hash=dag_hash, |
| code_hash=code_hash, |
| name=self.dag_name, |
| nodes=driver._extract_node_templates_from_function_graph(graph), |
| code_artifacts=driver.extract_code_artifacts_from_function_graph( |
| graph, vcs_info, vcs_info.local_repo_base_path |
| ), |
| config=graph.config, |
| tags=self.base_tags, |
| code=driver._slurp_code(graph, vcs_info.local_repo_base_path), |
| vcs_info=vcs_info, |
| ) |
| self.dag_template_id_cache[fg_id] = dag_template_id |
| |
| async def pre_graph_execute( |
| self, |
| run_id: str, |
| graph: h_graph.FunctionGraph, |
| final_vars: List[str], |
| inputs: Dict[str, Any], |
| overrides: Dict[str, Any], |
| ): |
| logger.debug("pre_graph_execute %s", run_id) |
| fg_id = id(graph) |
| if fg_id in self.dag_template_id_cache: |
| dag_template_id = self.dag_template_id_cache[fg_id] |
| else: |
| raise ValueError("DAG template ID not found in cache. This should never happen.") |
| |
| tracking_state = TrackingState(run_id) |
| self.tracking_states[run_id] = tracking_state # cache |
| tracking_state.clock_start() |
| dw_run_id = await self.client.create_and_start_dag_run( |
| dag_template_id=dag_template_id, |
| tags=self.base_tags, |
| inputs=inputs if inputs is not None else {}, |
| outputs=final_vars, |
| ) |
| self.dw_run_ids[run_id] = dw_run_id |
| self.task_runs[run_id] = {} |
| |
| async def pre_node_execute( |
| self, run_id: str, node_: node.Node, kwargs: Dict[str, Any], task_id: Optional[str] = None |
| ): |
| logger.debug("pre_node_execute %s", run_id) |
| tracking_state = self.tracking_states[run_id] |
| if tracking_state.status == Status.UNINITIALIZED: # not thread safe? |
| tracking_state.update_status(Status.RUNNING) |
| |
| task_run = TaskRun(node_name=node_.name) |
| task_run.status = Status.RUNNING |
| task_run.start_time = datetime.datetime.now(timezone.utc) |
| tracking_state.update_task(node_.name, task_run) |
| self.task_runs[run_id][node_.name] = task_run |
| |
| task_update = dict( |
| node_template_name=node_.name, |
| node_name=get_node_name(node_, task_id), |
| realized_dependencies=[dep.name for dep in node_.dependencies], |
| status=task_run.status, |
| start_time=task_run.start_time, |
| end_time=None, |
| ) |
| await self.client.update_tasks( |
| self.dw_run_ids[run_id], |
| attributes=[], |
| task_updates=[task_update], |
| in_samples=[task_run.is_in_sample], |
| ) |
| |
| async def post_node_execute( |
| self, |
| run_id: str, |
| node_: node.Node, |
| success: bool, |
| error: Optional[Exception], |
| result: Any, |
| task_id: Optional[str] = None, |
| **future_kwargs, |
| ): |
| logger.debug("post_node_execute %s", run_id) |
| task_run = self.task_runs[run_id][node_.name] |
| tracking_state = self.tracking_states[run_id] |
| task_run.end_time = datetime.datetime.now(timezone.utc) |
| if success: |
| task_run.status = Status.SUCCESS |
| task_run.result_type = type(result) |
| task_run.result_summary = runs.process_result(result, node_) # add node |
| task_attr = dict( |
| node_name=get_node_name(node_, task_id), |
| name="result_summary", |
| type=task_run.result_summary["observability_type"], |
| # 0.0.3 -> 3 |
| schema_version=int( |
| task_run.result_summary["observability_schema_version"].split(".")[-1] |
| ), |
| value=task_run.result_summary["observability_value"], |
| attribute_role="result_summary", |
| ) |
| else: |
| task_run.status = Status.FAILURE |
| if isinstance(error, dq_base.DataValidationError): |
| task_run.error = runs.serialize_data_quality_error(error) |
| else: |
| task_run.error = traceback.format_exception(type(error), error, error.__traceback__) |
| task_attr = dict( |
| node_name=get_node_name(node_, task_id), |
| name="stack_trace", |
| type="error", |
| schema_version=1, |
| value={ |
| "stack_trace": task_run.error, |
| }, |
| attribute_role="error", |
| ) |
| tracking_state.update_task(get_node_name(node_, task_id), task_run) |
| task_update = dict( |
| node_template_name=node_.name, |
| node_name=get_node_name(node_, task_id), |
| realized_dependencies=[dep.name for dep in node_.dependencies], |
| status=task_run.status, |
| start_time=task_run.start_time, |
| end_time=task_run.end_time, |
| ) |
| await self.client.update_tasks( |
| self.dw_run_ids[run_id], |
| attributes=[task_attr], |
| task_updates=[task_update], |
| in_samples=[task_run.is_in_sample], |
| ) |
| |
| async def post_graph_execute( |
| self, |
| run_id: str, |
| graph: h_graph.FunctionGraph, |
| success: bool, |
| error: Optional[Exception], |
| results: Optional[Dict[str, Any]], |
| ): |
| logger.debug("post_graph_execute %s", run_id) |
| dw_run_id = self.dw_run_ids[run_id] |
| tracking_state = self.tracking_states[run_id] |
| tracking_state.clock_end(status=Status.SUCCESS if success else Status.FAILURE) |
| finally_block_time = datetime.datetime.utcnow() |
| if tracking_state.status != Status.SUCCESS: |
| # TODO: figure out how to handle crtl+c stuff |
| tracking_state.status = Status.FAILURE |
| # this assumes the task map only has things that have been processed, not |
| # nodes that have yet to be computed. |
| for task_name, task_run in tracking_state.task_map.items(): |
| if task_run.status != Status.SUCCESS: |
| task_run.status = Status.FAILURE |
| task_run.end_time = finally_block_time |
| if task_run.error is None: # we likely aborted it. |
| # Note if we start to do concurrent execution we'll likely |
| # need to adjust this. |
| task_run.error = ["Run was likely aborted."] |
| if task_run.end_time is None and task_run.status == Status.SUCCESS: |
| task_run.end_time = finally_block_time |
| |
| # TODO: only update things that have changed? |
| # self.client.update_tasks( |
| # dag_run_id=dw_run_id, |
| # attributes=driver.extract_attributes_from_tracking_state(tracking_state), |
| # task_updates=driver.extract_task_updates_from_tracking_state(tracking_state, graph), |
| # ) |
| await self.client.log_dag_run_end( |
| dag_run_id=dw_run_id, |
| status=tracking_state.status.value, |
| ) |
| logger.warning( |
| f"\nCaptured execution run. Results can be found at " |
| f"{self.hamilton_ui_url}/dashboard/project/{self.project_id}/runs/{dw_run_id}\n" |
| ) |