import datetime
import hashlib
import inspect
import json
import linecache
import logging
import operator
import os
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from hamilton_sdk.api.clients import UnauthorizedException
from hamilton import base, driver, graph, node
from hamilton.driver import Variable
from hamilton.io import materialization
from hamilton.lifecycle.base import BaseDoNodeExecute
from hamilton.node import Node
try:
import git
except ImportError:
git = None
from hamilton_sdk.api import clients, constants
from hamilton_sdk.api.projecttypes import GitInfo
from hamilton_sdk.tracking.runs import Status, TrackingState, monkey_patch_adapter
logger = logging.getLogger(__name__)
def _hash_module(
module: ModuleType, hash_object: hashlib.sha256, seen_modules: Set[ModuleType]
) -> hashlib.sha256:
"""Generate a hash of the specified module and its imports.
It will recursively hash the contents of the modules and their imports, and only does so
if the import is from the same package. This is to avoid hashing the entire python
environment...
:param module: the python module to hash and then crawl.
:param hash_object: the object to update.
:param seen_modules: the python modules we've already hashed.
:return: the updated hash object
"""
# Check if we've already hashed this module
if module in seen_modules:
return hash_object
else:
seen_modules.add(module)
# Update the hash with the module's source code
if hasattr(module, "__file__") and module.__file__ is not None:
with open(module.__file__, "rb") as f:
hash_object.update(f.read())
else:
logger.debug(
"Skipping hash for module %s because it has no __file__ attribute or it is None.",
module,
)
# Loop through the module's attributes
for name, value in inspect.getmembers(module):
# Check if the attribute is a module
if inspect.ismodule(value):
if value.__package__ is None:
logger.info(
f"Skipping hash for module {value.__name__} because it has no __package__ "
f"attribute or it is None. This happens with lazy loaders."
)
continue
# Modules imported in a temporary module have no `__package__` attribute
if module.__package__:
# Check if the module is in the same top level package
if value.__package__ != module.__package__ and not value.__package__.startswith(
module.__package__
):
logger.debug(
f"Skipping hash for module {value.__name__} because it is in a different "
f"package {value.__package__} than {module.__package__}"
)
continue
# Recursively hash the sub-module
hash_object = _hash_module(value, hash_object, seen_modules)
# Return the hash object
return hash_object
def _get_modules_hash(modules: Tuple[ModuleType, ...]) -> str:
"""Generate a hash of the contents of the specified modules.
It recursively hashes the contents of the modules and their imports, and only does so
if the import is from the same package. This is to avoid hashing the entire python
environment...
:param modules: python modules to hash
:return: the hex digest of the hash
"""
# Create a hash object
h = hashlib.sha256()
seen_modules = set()
# Loop through each module name
for module in modules:
# Update the hash with the module's source code
h = _hash_module(module, h, seen_modules)
# Return the hex digest of the hash
return h.hexdigest()
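# A minimal usage sketch (not executed here; `my_dag_module` is a hypothetical module
# of Hamilton functions): hashing the driver's modules yields a stable fingerprint of
# the code, which also serves as the fallback "commit hash" when no git repo is found.
#
#   import my_dag_module
#   code_version = _get_modules_hash((my_dag_module,))
#   # Re-running against unchanged source files yields the same hex digest.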
def _derive_version_control_info(module_hash: str) -> GitInfo:
"""Derive the git info for the current project.
Currently, this checks whether we're in a git repository and falls back to the
module hash and defaults otherwise. This is not going to work for everything,
but we'll see what the customers want. We might end up having to pass this data in...
"""
default = GitInfo(
branch="unknown",
commit_hash=module_hash,
committed=False,
repository="Error: No repository to link to.",
local_repo_base_path=os.getcwd(),
)
if git is None:
return default
try:
repo = git.Repo(".", search_parent_directories=True)
except git.exc.InvalidGitRepositoryError:
logger.warning(
"Warning: We are not currently in a git repository. We recommend using that as a "
"way to version the "
"project *if* your hamilton code lives within this repository too. If it does not,"
" then we'll try to "
"version code based on the python modules passed to the Driver. "
"Incase you want to get set up with git quickly you can run:\n "
"git init && git add . && git commit -m 'Initial commit'\n"
"Still have questions? Reach out to stefan @ dagworks.io, elijah @ dagworks.io "
"and we'll try to help you as soon as possible."
)
return default
if "COLAB_RELEASE_TAG" in os.environ:
logger.warning(
"We currently do not support logging version information inside a google"
"colab notebook. This is something we are planning to do. "
"If you have any questions, please reach out to support@dagworks.io"
"and we'll try to help you as soon as possible."
)
return default
commit = repo.head.commit
try:
repo_url = repo.remote().url
except ValueError:
# TODO: change this to point to our docs on what to do.
repo_url = "Error: No repository to link to."
try:
branch_name = repo.active_branch.name
except TypeError:
branch_name = "unknown" # detached head
logger.warning(
"Warning: we are unable to determine the branch name. "
"This is likely because you are in a detached head state. "
"If you are in a detached head state, you can check out a "
"branch by running `git checkout -b <branch_name>`. "
"If you intend to be (if you are using some sort of CI"
"system that checks out a detached head) then you can ignore this."
)
return GitInfo(
branch=branch_name,
commit_hash=commit.hexsha,
committed=not repo.is_dirty(),
repository=repo_url,
local_repo_base_path=repo.working_dir,
)
def filter_json_dict_to_serializable(
dict_to_filter: Dict[str, Any], curr_result: Dict[str, Any] = None
):
"""Recursively filters a dictionary down to JSON-serializable values.
Nested dictionaries are recursed into; any other non-serializable value is stringified.
"""
if curr_result is None:
curr_result = {}
if dict_to_filter is None:
dict_to_filter = {}
for key, value in dict_to_filter.items():
try:
json.dumps(value)
curr_result[key] = value
except TypeError:
if isinstance(value, dict):
new_result = {}
filter_json_dict_to_serializable(value, new_result)
curr_result[key] = new_result
else:
curr_result[key] = str(value)
return curr_result
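# An illustrative sketch (hypothetical values; not executed here): nested dicts are
# recursed into, while non-serializable leaf values are stringified.
#
#   filter_json_dict_to_serializable({"a": 1, "b": {"c": object()}})
#   # -> {"a": 1, "b": {"c": "<object object at 0x...>"}}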
def validate_tags(tags: Any):
"""Validates that tags are a dictionary of strings to strings.
:param tags: Tags to validate
:raises ValueError: If tags are not a dictionary of strings to strings
"""
if not isinstance(tags, dict):
raise ValueError(f"Tags must be a dictionary, but got {tags}")
for key, value in tags.items():
if not isinstance(key, str):
raise ValueError(f"Tag keys must be strings, but got {key}")
if not isinstance(value, str):
raise ValueError(f"Tag values must be strings, but got {value}")
def safe_len(x):
return len(x) if x is not None else 0
# Placeholder so we can monkey-patch later (see monkey_patch_adapter)
class DefaultExecutionMethod(BaseDoNodeExecute):
def do_node_execute(
self,
*,
run_id: str,
node_: node.Node,
kwargs: Dict[str, Any],
task_id: Optional[str] = None,
) -> Any:
return node_(**kwargs)
class Driver(driver.Driver):
def __init__(
self,
config: Dict[str, Any],
*modules: ModuleType,
project_id: int,
api_key: str,
username: str,
dag_name: str,
tags: Dict[str, str] = None,
client_factory: Callable[
[str, str, str], clients.HamiltonClient
] = clients.BasicSynchronousHamiltonClient,
adapter: base.HamiltonGraphAdapter = None,
dagworks_api_url=os.environ.get("DAGWORKS_API_URL", constants.HAMILTON_API_URL),
dagworks_ui_url=os.environ.get("DAGWORKS_UI_URL", constants.HAMILTON_UI_URL),
):
"""Instantiates a DAGWorks driver. This:
1. Requires a project to exist. Create one via https://app.dagworks.io/dashboard/projects.
2. Sends over the shape of the DAG.
3. Sets up execute() run-tracking.
:param config: Configuration to use, same as standard Hamilton driver.
:param modules: Modules to use, same as standard Hamilton driver.
:param project_id: Identifier for the project to use to store this DAG under.
:param api_key: API key to use for authentication. Remember not to save this in plaintext!
:param username: email address to use for authentication.
:param dag_name: name for this DAG. You will use this for top level curation of DAGs
within a project.
:param tags: Optional key value string pairs to help identify and curate this instance of
the DAG and subsequent execution runs. E.g. {"environment": "production"}.
Currently all .execute() runs will be tagged with these.
:param client_factory: Optional. Advanced use. Factory to use to create the underlying
client.
:param adapter: Optional. Adapter to use, same as standard Hamilton driver.
:param dagworks_api_url: Optional. URL to use for the DAGWorks API.
:param dagworks_ui_url: Optional. URL to use for the DAGWorks UI.
"""
if adapter is None:
adapter = base.SimplePythonGraphAdapter(result_builder=base.DictResult())
super(Driver, self).__init__(config, *modules, adapter=adapter)
self.config = config
self.project = project_id
self.api_key = api_key
self.username = username
# TODO -- figure out how to pass any additional configuration to the client if needed
self.client = client_factory(api_key, username, dagworks_api_url)
self.module_hash = _get_modules_hash(modules)
self.vcs_info = _derive_version_control_info(self.module_hash)
self.initialized = False
self.modules = modules
self.project_version = None
self.run_tags = tags if tags is not None else {}
validate_tags(self.run_tags)
self.dag_name = dag_name
self.dagworks_ui_url = dagworks_ui_url
# reassign the graph executor with all the information we have
self.graph_executor = DAGWorksGraphExecutor(
self.graph_executor,
self.client,
self.run_tags,
self.dagworks_ui_url,
self.project,
self.vcs_info.local_repo_base_path,
self.vcs_info,
self.dag_name,
list(self.graph_modules),
self.graph,
)
self.initialize()
def set_name(self, new_name: str):
"""Sets a name for the driver. This allows you to force a change in the name/version of the
DAG so the next run logs a new one.
:param new_name:
"""
self.dag_name = new_name
self.graph_executor.dag_name = new_name
def initialize(self):
"""Initializes the driver. This:
1. Validates authentication
2. Creates a project if it does not already exist
3. Sets initialization as true
Note this is idempotent -- it can be called by the user to test, but will get called when
the driver runs.
"""
logger.debug("Validating authentication against DAGWorks API...")
self.client.validate_auth()
logger.debug("Authentication successful!")
logger.debug(f"Ensuring project {self.project} exists...")
try:
exists = self.client.project_exists(self.project)
except UnauthorizedException:
logger.exception(
f"Authentication failed. Please check your credentials and try again. "
f"Username: {self.username}, API key: {self.api_key[0:6]}..."
)
raise
if not exists:
logger.error(
f"Project {self.project} does not exist. Please create it first in the UI! "
f"You can do so at https://www.app.dagworks.io/dashboard/projects"
)
self.graph_executor.initialize()
self.initialized = True
def execute(
self,
final_vars: List[Union[str, Callable]],
overrides: Dict[str, Any] = None,
display_graph: bool = False,
inputs: Dict[str, Any] = None,
) -> Any:
logger.warning(
f"\nCapturing execution run. All runs for project can be found at "
f"{self.dagworks_ui_url}/dashboard/project/{self.project}/runs"
)
return super(Driver, self).execute(final_vars, overrides, display_graph, inputs)
def raw_execute(
self,
final_vars: List[str],
overrides: Dict[str, Any] = None,
display_graph: bool = False,
inputs: Dict[str, Any] = None,
_fn_graph: graph.FunctionGraph = None,
) -> Dict[str, Any]:
return super(Driver, self).raw_execute(
final_vars, overrides, display_graph, inputs, _fn_graph
)
def materialize(
self,
*materializers: materialization.MaterializerFactory,
additional_vars: List[Union[str, Callable, Variable]] = None,
overrides: Dict[str, Any] = None,
inputs: Dict[str, Any] = None,
) -> Tuple[Any, Dict[str, Any]]:
return super(Driver, self).materialize(
*materializers, additional_vars=additional_vars, overrides=overrides, inputs=inputs
)
def _get_fully_qualified_function_path(fn: Callable) -> str:
"""Gets the fully qualified path of a function.
:param fn: Function to get the path of
:return: Fully qualified path of the function
"""
module = inspect.getmodule(fn)
fn_name = fn.__name__
if module is not None:
fn_name = f"{module.__name__}.{fn_name}"
return fn_name
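# For illustration (not executed here): a function from an importable module resolves
# to "<module>.<name>", e.g. _get_fully_qualified_function_path(json.dumps) returns
# "json.dumps"; if the module cannot be resolved, the bare function name is used.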
def hash_dag(dag: graph.FunctionGraph) -> str:
"""Hashes a DAG.
:param dag: DAG to hash
:return: Hash of the DAG
"""
digest = hashlib.sha256()
hashing_node_fields = {
"name": str,
"type": str, # Hash it to a string for the type
"node_role": str,
"dependencies": lambda x: " ".join(
[item.name for item in sorted(x, key=operator.attrgetter("name"))]
), # Sort dependencies
}
for node_ in sorted(dag.nodes.values(), key=operator.attrgetter("name")):
for field, serializer in hashing_node_fields.items():
digest.update(json.dumps(serializer(getattr(node_, field))).encode())
if node_.originating_functions is None:
continue
for fn in node_.originating_functions:
fn_name = _get_fully_qualified_function_path(fn)
digest.update(fn_name.encode())
return digest.hexdigest()
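# Note: nodes and their dependencies are sorted by name before hashing, so the digest
# is insensitive to iteration order -- two drivers built from the same modules and
# config should produce the same DAG hash.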
def hash_dag_modules(dag: graph.FunctionGraph, modules: List[ModuleType]) -> str:
"""Hashes the source code of the modules that the DAG's functions come from.
:param dag: DAG whose originating functions' modules should be hashed.
:param modules: Modules passed to the Driver, also included in the hash.
:return: Hex digest of the hash.
"""
modules_by_path = {}
for module in modules:
if hasattr(module, "__file__") and module.__file__ is not None:
try:
modules_by_path[module.__file__] = inspect.getsource(module)
except OSError:
logger.warning(
f"Skipping hashing of module {module.__name__} because we could not read the source code."
)
digest = hashlib.sha256()
for node_ in sorted(dag.nodes.values(), key=operator.attrgetter("name")):
if node_.originating_functions is None:
continue
for fn in node_.originating_functions:
module = inspect.getmodule(fn)
if hasattr(module, "__file__") and module.__file__ is not None:
modules_by_path[module.__file__] = module
for module_path, module_contents in sorted(modules_by_path.items()):
# if the filename is tmpXXXXXXXX.py assume it's a temporary file and skip hashing the name
# this could be in a jupyter context in which case this will cause different code
# versions when in fact there are none.
file_name = os.path.basename(module_path)
if file_name.startswith("tmp") and len(file_name) == 14:
pass
else:
digest.update(module_path.encode())
digest.update(_get_modules_hash((module_contents,)).encode())
return digest.hexdigest()
def _convert_node_dependencies(node_: Node) -> dict:
# note: the parameter is named node_ to avoid shadowing the imported `node` module
dependencies = []
dependency_specs = []
dependency_specs_type = "python_type"
dependency_specs_schema_version = 1
for dep in node_.dependencies:
dependencies.append(dep.name)
dependency_specs.append({"type_name": str(dep.type)})
return {
"dependencies": dependencies,
"dependency_specs": dependency_specs,
"dependency_specs_type": dependency_specs_type,
"dependency_specs_schema_version": dependency_specs_schema_version,
}
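# Shape of the returned payload, as a sketch (hypothetical node with dependencies
# "a" and "b"; not executed here):
#
#   {
#       "dependencies": ["a", "b"],
#       "dependency_specs": [{"type_name": "int"}, {"type_name": "str"}],
#       "dependency_specs_type": "python_type",
#       "dependency_specs_schema_version": 1,
#   }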
def _convert_classifications(node_: Node) -> List[str]:
out = []
if (
node_.tags.get("hamilton.data_loader")
and node_.tags.get("hamilton.data_loader.has_metadata") is not False
):
out.append("data_loader")
elif node_.tags.get("hamilton.data_saver"):
out.append("data_saver")
elif node_.user_defined:
out.append("input")
else:
out.append("transform")
return out
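# For example (hypothetical tags; not executed here): a node tagged
# "hamilton.data_loader" (unless "hamilton.data_loader.has_metadata" is False) is
# classified as ["data_loader"], a user-provided input as ["input"], and everything
# else as ["transform"].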
def _extract_node_templates_from_function_graph(fn_graph: graph.FunctionGraph) -> List[dict]:
"""Converts a function graph to a list of nodes that the DAGWorks graph can understand.
@param fn_graph: Function graph to convert
@return: A list of node template dicts
"""
node_templates = []
for node_ in fn_graph.nodes.values():
code_artifact_pointers = (
[]
if (node_.originating_functions is None or len(node_.originating_functions) == 0)
else [_get_fully_qualified_function_path(fn) for fn in node_.originating_functions]
)
node_templates.append(
dict(
name=node_.name,
output={"type_name": str(node_.type)},
output_type="python_type",
output_schema_version=1, # TODO -- merge this with _convert_node_dependencies
documentation=node_.documentation,
tags=node_.tags, # TODO -- ensure serializable
classifications=_convert_classifications(node_), # TODO -- manage classifications
code_artifact_pointers=code_artifact_pointers,
**_convert_node_dependencies(node_),
)
)
return node_templates
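# Each node template is a dict of roughly this shape (hypothetical values; not
# executed here); the dependency fields sketched above for _convert_node_dependencies()
# are merged in as well:
#
#   {
#       "name": "my_node",
#       "output": {"type_name": "int"},
#       "output_type": "python_type",
#       "output_schema_version": 1,
#       "documentation": "Docstring of my_node.",
#       "tags": {"some_tag": "some_value"},
#       "classifications": ["transform"],
#       "code_artifact_pointers": ["my_dag_module.my_node"],
#       # plus "dependencies", "dependency_specs", etc. from _convert_node_dependencies()
#   }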
def _derive_url(vcs_info: GitInfo, path: str, line: int) -> str:
"""Derives a URL from a VCS info, a path, and a line number.
@param vcs_info: VCS info
@param path: Path
@param line: Line number
@return: A URL
"""
if vcs_info.repository == "Error: No repository to link to.":
return "Error: No repository to link to."
if vcs_info.repository.endswith(".git"):
repo_url = vcs_info.repository[:-4]
else:
repo_url = vcs_info.repository
return f"{repo_url}/blob/{vcs_info.commit_hash}/{path}#L{line}"
def getsourcelines(object, stop: Callable = None) -> tuple:
"""Adding this here incase we want to pull decorator code too.
This is modification of the underlying inspect function.
Return a list of source lines and starting line number for an object.
The argument may be a module, class, method, function, traceback, frame,
or code object. The source code is returned as a list of the lines
corresponding to the object and the line number indicates where in the
original source file the first line of code was found. An OSError is
raised if the source code cannot be retrieved.
This will return the decorator code, or the underlying wrapped function code.
"""
object = inspect.unwrap(object, stop=stop)
lines, lnum = inspect.findsource(object)
if inspect.istraceback(object):
object = object.tb_frame
# for module or frame that corresponds to module, return all source lines
if inspect.ismodule(object) or (
inspect.isframe(object) and object.f_code.co_name == "<module>"
):
return lines, 0
else:
return inspect.getblock(lines[lnum:]), lnum + 1
def extract_code_artifacts_from_function_graph(
fn_graph: graph.FunctionGraph, vcs_info: GitInfo, repo_base_path: str
) -> List[dict]:
"""Converts a function graph to a list of code artifacts that the function graph uses.
@param fn_graph: Function graph to convert.
@param vcs_info: Version control info used to derive code URLs.
@param repo_base_path: Base path of the repository, used to relativize source file paths.
@return: A list of code artifact dicts.
"""
seen = set()
out = []
for node_ in fn_graph.nodes.values():
originating_functions = node_.originating_functions
if originating_functions is None:
continue
for fn in originating_functions:
fn_name = _get_fully_qualified_function_path(fn)
if fn_name not in seen:
seen.add(fn_name)
# need to handle decorators -- they will return the wrong sourcefile.
unwrapped_fn = inspect.unwrap(fn)
if unwrapped_fn != fn:
# TODO: pull decorator stuff too
source_file = inspect.getsourcefile(unwrapped_fn)
else:
source_file = inspect.getsourcefile(fn)
if source_file is not None:
path = os.path.relpath(source_file, repo_base_path)
else:
path = ""
try:
source_lines = inspect.getsourcelines(fn)
out.append(
dict(
name=fn_name,
type="p_function",
path=path,
start=source_lines[1] - 1,
end=source_lines[1] - 1 + len(source_lines[0]),
url=_derive_url(vcs_info, path, source_lines[1]),
)
)
except OSError:
# This is an error state where somehow we don't have
# source code.
out.append(
dict(
name=fn_name,
type="p_function",
path=path,
start=0,
end=0,
url=_derive_url(vcs_info, path, 0),
)
)
return out
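# Each code artifact is a dict of roughly this shape (hypothetical values; not
# executed here):
#
#   {
#       "name": "my_dag_module.my_function",
#       "type": "p_function",
#       "path": "dags/my_dag_module.py",
#       "start": 12,   # 0-indexed first line of the function's source
#       "end": 20,     # start + number of source lines
#       "url": "https://github.com/org/repo/blob/<commit>/dags/my_dag_module.py#L13",
#   }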
def extract_attributes_from_tracking_state(tracking_state: TrackingState) -> List[dict]:
"""Extracts attributes from tracking state. We'll likely rewrite this shortly --
this is just a bridge so we can get the client out. Next, we'll want it to put
items on a queue and send them over in batches. The tracking state is a hack
and we'll get rid of it.
@param tracking_state: Tracking state
@return: A list of attributes
"""
# This just bridges some of the old code so we can move quickly
# TODO -- fix so we don't have to do that
out = []
dag_run = tracking_state.get()
for task in dag_run.tasks:
if task.error is not None:
out.append(
dict(
node_name=task.node_name,
name="stack_trace",
type="error",
schema_version=1,
value={
"stack_trace": task.error,
},
attribute_role="error",
)
)
if task.result_summary is not None:
out.append(
dict(
node_name=task.node_name,
name="result_summary",
type=task.result_summary["observability_type"],
# 0.0.3 -> 3
schema_version=int(
task.result_summary["observability_schema_version"].split(".")[-1]
),
value=task.result_summary["observability_value"],
attribute_role="result_summary",
)
)
return out
def extract_task_updates_from_tracking_state(
tracking_state: TrackingState, fg: graph.FunctionGraph
) -> List[dict]:
"""Extracts task updates from tracking state. We'll likely rewrite this shortly --
this is a hack (using the tracking state) -- we'll want to extract these as we go along,
put them on a queue, and send them over in batches.
@param tracking_state: Tracking state to pull task runs from.
@param fg: Function graph used to look up each task's node dependencies.
@return: A list of task update dicts.
"""
# TODO -- do the tracking state in a cleaner way
# This is left over from the old way we were doing things
dag_run = tracking_state.get()
out = []
for task in dag_run.tasks:
node_ = fg.nodes[task.node_name]
out.append(
dict(
node_template_name=task.node_name,
node_name=task.node_name,
realized_dependencies=[dep.name for dep in node_.dependencies],
status=task.status,
start_time=task.start_time,
end_time=task.end_time,
)
)
return out
def _slurp_code(fg: graph.FunctionGraph, repo_base: str) -> List[dict]:
"""Get the source code from modules. Returns a list with a dictionary for each module.
The `path` attribute needs to match the `path` of code artifacts generated by
`extract_code_artifacts_from_function_graph()`
"""
modules = set()
for node_ in fg.nodes.values():
originating_functions = node_.originating_functions
if originating_functions is None:
continue
for fn in originating_functions:
module = inspect.getmodule(fn)
modules.add(module)
out = []
for module in modules:
if hasattr(module, "__file__") and module.__file__ is not None:
module_path = os.path.relpath(module.__file__, repo_base)
with open(module.__file__, "r") as f:
out.append({"path": module_path, "contents": f.read()})
# for temporary modules registered via `module_from_source`
else:
# get source code from the linecache; returns a tuple (size, mtime, lines, fullname)
source_lines = linecache.cache[module.__name__][2]
source = "".join(source_lines)
# the path won't have a `.py` suffix, to match `extract_code_artifacts_from_function_graph()`
module_path = os.path.relpath(module.__name__, repo_base)
out.append({"path": module_path, "contents": source})
return out
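# The result is a list of {"path": ..., "contents": ...} dicts, e.g. (hypothetical;
# not executed here):
#
#   [{"path": "dags/my_dag_module.py", "contents": "import pandas as pd\n..."}]
#
# with paths chosen to line up with the code artifacts extracted above.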
class DAGWorksGraphExecutor(driver.GraphExecutor):
def __init__(
self,
wrapping_executor: driver.GraphExecutor,
client: clients.HamiltonClient,
run_tags: Dict[str, str],
dagworks_ui_url: str,
project_id: int,
repo_base: str,
vcs_info: GitInfo,
dag_name: str,
graph_modules: List[ModuleType],
initial_graph: graph.FunctionGraph,
):
self.executor = wrapping_executor
self.client = client
self.run_tags = run_tags
self.dagworks_ui_url = dagworks_ui_url
self.project_id = project_id
self.repo_base = repo_base
self.vcs_info = vcs_info
self.dag_name = dag_name
self.graph_modules = graph_modules
self.dag_template_id_cache = {}
self.initial_graph = initial_graph
def initialize(self):
self._register_or_query(self.initial_graph)
def _register_or_query(self, fg: graph.FunctionGraph) -> int:
"""Creates a DAG template if the funtion graph doesn't exist. Otherwise
we use the cache we see. Note that if this already
@param fg:
@return:
"""
# Quick way to bypass anything complicated
fg_id = id(fg)
if fg_id in self.dag_template_id_cache:
return self.dag_template_id_cache[fg_id]
dag_hash = hash_dag(fg)
code_hash = hash_dag_modules(fg, self.graph_modules)
dag_template_id = self.client.register_dag_template_if_not_exists(
project_id=self.project_id,
dag_hash=dag_hash,
code_hash=code_hash,
name=self.dag_name,
nodes=_extract_node_templates_from_function_graph(fg),
code_artifacts=extract_code_artifacts_from_function_graph(
fg, self.vcs_info, self.repo_base
),
config=fg.config,
tags=self.run_tags,
code=_slurp_code(fg, self.repo_base),
vcs_info=self.vcs_info,
)
self.dag_template_id_cache[fg_id] = dag_template_id
return dag_template_id
def execute(
self,
fg: graph.FunctionGraph,
final_vars: List[Union[str, Callable, Variable]],
overrides: Dict[str, Any],
inputs: Dict[str, Any],
run_id: str,
) -> Dict[str, Any]:
"""Executes a graph in a blocking function.
:param fg: Graph to execute
:param final_vars: Variables we want
:param overrides: Overrides --- these short-circuit computation
:param inputs: Inputs to the Graph.
:param run_id: ID identifying this run.
:return: The output of the final variables, in dictionary form.
"""
logger.info(f"Logging code version for DAG {self.dag_name}...")
dag_template_id = self._register_or_query(fg)
tracking_state = TrackingState(run_id)
with monkey_patch_adapter(fg.adapter, tracking_state):
tracking_state.clock_start()
dag_run_id = self.client.create_and_start_dag_run(
dag_template_id=dag_template_id,
tags=self.run_tags,
inputs=inputs if inputs is not None else {},
outputs=final_vars,
)
try:
out = self.executor.execute(fg, final_vars, overrides, inputs, run_id=dag_run_id)
tracking_state.clock_end(status=Status.SUCCESS)
return out
except Exception:
tracking_state.clock_end(status=Status.FAILURE)
raise
finally:
finally_block_time = datetime.datetime.utcnow()
if tracking_state.status != Status.SUCCESS:
tracking_state.status = Status.FAILURE
# this assumes the task map only has things that have been processed, not
# nodes that have yet to be computed.
for task_name, task_run in tracking_state.task_map.items():
if task_run.status != Status.SUCCESS:
task_run.status = Status.FAILURE
task_run.end_time = finally_block_time
if task_run.error is None: # we likely aborted it.
# Note if we start to do concurrent execution we'll likely
# need to adjust this.
task_run.error = ["Run was likely aborted."]
if task_run.end_time is None and task_run.status == Status.SUCCESS:
task_run.end_time = finally_block_time
self.client.update_tasks(
dag_run_id=dag_run_id,
attributes=extract_attributes_from_tracking_state(tracking_state),
task_updates=extract_task_updates_from_tracking_state(tracking_state, fg),
)
self.client.log_dag_run_end(
dag_run_id=dag_run_id,
status=tracking_state.status.value,
)
logger.warning(
f"\nCaptured execution run. Results can be found at "
f"{self.dagworks_ui_url}/dashboard/project/{self.project_id}/runs/{dag_run_id}\n"
)
def validate(self, nodes_to_execute: List[node.Node]):
pass
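# A minimal end-to-end usage sketch (hypothetical project id, credentials, and module;
# not executed here):
#
#   import my_dag_module  # hypothetical module of Hamilton functions
#
#   dr = Driver(
#       {},                       # config
#       my_dag_module,
#       project_id=1,
#       api_key="...",            # keep API keys out of source control
#       username="me@example.com",
#       dag_name="my_dag",
#       tags={"environment": "dev"},
#   )
#   results = dr.execute(["output_node"], inputs={"raw_input": 5})
#   # Each execute() call is captured as a run under the project in the DAGWorks UI.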