blob: f4f5de22e7ece75c8b2b513e273baebbf09632a2 [file] [log] [blame]
import collections
import logging
from typing import List, Optional
from common import django_utils
from ninja import Router
from ninja.errors import HttpError
from ninja.params import Query
from trackingserver_base.permissions.base import permission
from trackingserver_base.permissions.permissions import (
user_can_get_dag_runs,
user_can_get_latest_dag_runs,
user_can_get_project_by_id,
user_can_write_to_dag_run,
user_can_write_to_dag_template,
)
from trackingserver_run_tracking.models import DAGRun, NodeRun, NodeRunAttribute
from trackingserver_run_tracking.schema import (
CatalogZoomResponse,
DAGRunIn,
DAGRunOut,
DAGRunOutWithData,
DagRunsBulkRequest,
DAGRunUpdate,
NodeRunAttributeOut,
NodeRunIn,
NodeRunOutWithAttributes,
NodeRunOutWithExtraData,
)
from trackingserver_template.models import NodeTemplate
from trackingserver_template.schema import CodeArtifactOut, NodeTemplateOut
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Router on which every endpoint defined in this file is registered; tagged "run_tracking".
router = Router(tags=["run_tracking"])
@router.post("/v1/dag_runs", response=DAGRunOut, tags=["run_tracking"])
@permission(user_can_write_to_dag_template)
async def create_dag_run(request, dag_template_id: int, dag_run: DAGRunIn) -> DAGRunOut:
    """Create a new DAG run under an existing DAG template.

    Only the run record itself is created here; no node runs or attributes
    are written by this endpoint. The authenticated user is stored as the
    launcher of the run.

    @param request: Incoming request; ``request.auth`` is unpacked as ``(user, teams)``.
    @param dag_template_id: ID of the DAG template the new run belongs to.
    @param dag_run: Fields of the DAG run to create.
    @return: The newly created DAG run.
    """
    user, _teams = request.auth  # teams are not needed here
    logger.info(f"Creating DAG run for dag template: {dag_template_id} for user: {user.email}")
    run_fields = dag_run.dict()
    created = await DAGRun.objects.acreate(
        dag_template_id=dag_template_id,
        launched_by_id=user.id,
        **run_fields,
    )
    logger.info(f"Created DAG run for dag template: {dag_template_id}")
    return DAGRunOut.from_orm(created)
@router.get("/v1/dag_runs/latest/", response=List[DAGRunOut], tags=["run_tracking"])
@permission(user_can_get_latest_dag_runs)
async def get_latest_dag_runs(
    request,
    project_id: Optional[int] = None,
    dag_template_id: Optional[int] = None,
    limit: int = 100,
    offset: int = 0,
) -> List[DAGRunOut]:
    """Get the most recent DAG runs, newest first, for active DAG templates.

    At most one of ``project_id`` / ``dag_template_id`` may be given. If
    neither is given, this function applies no project/template filter of its
    own (NOTE(review): the permission check may already reject that case --
    confirm).

    @param request: Incoming request; ``request.auth`` is unpacked as ``(user, teams)``.
    @param project_id: Restrict to runs whose DAG template belongs to this project.
    @param dag_template_id: Restrict to runs of this DAG template.
    @param limit: Maximum number of runs to return; must be non-negative.
    @param offset: Number of runs to skip; must be non-negative.
    @return: List of DAG runs ordered by ``created_at`` descending.
    @raise HttpError: 422 if both filters are given or limit/offset is negative.
    """
    user, _teams = request.auth  # teams are not needed here
    logger.info(
        f"Getting DAG runs for user: {user.email}, "
        f"project: {project_id}, "
        f"dag template: {dag_template_id}"
    )
    if project_id is not None and dag_template_id is not None:
        raise HttpError(
            422,
            "Can only specify one of project_id/dag_template_id, "
            f"got: project_id={project_id}, dag_template_id={dag_template_id}",
        )
    # Django querysets do not support negative slicing; reject it as a client
    # error instead of letting the slice below fail with a server error.
    if limit < 0 or offset < 0:
        raise HttpError(
            422,
            f"limit and offset must be non-negative, got: limit={limit}, offset={offset}",
        )
    scope_filter = {}
    if project_id is not None:
        scope_filter["dag_template__project_id"] = project_id
    if dag_template_id is not None:
        scope_filter["dag_template_id"] = dag_template_id
    latest_runs = (
        DAGRun.objects.filter(dag_template__is_active=True, **scope_filter)
        .prefetch_related("launched_by")
        .order_by("-created_at")[offset : offset + limit]
    )
    dag_runs = await django_utils.amap(DAGRunOut.create_with_username, latest_runs)
    logger.info(f"Got {len(dag_runs)} DAG runs for user: {user.email}")
    return dag_runs
@router.get("/v1/dag_runs/{dag_run_ids}", response=List[DAGRunOutWithData], tags=["run_tracking"])
@permission(user_can_get_dag_runs)
async def get_dag_runs(
    request,
    dag_run_ids: str,
    attrs: List[str] = Query(default=None, alias="attr"),
) -> List[DAGRunOutWithData]:
    """Fetch one or more DAG runs together with their node runs and attributes.

    Only attributes whose ``type`` is listed in ``attrs`` are returned; if no
    ``attr`` query parameter is passed the filter is empty and node runs come
    back without attributes.

    @param request: Incoming request; ``request.auth`` is unpacked as ``(user, teams)``.
    @param dag_run_ids: Comma-separated list of integer DAG run IDs.
    @param attrs: Attribute types to include (query parameter ``attr``, repeatable).
    @return: One entry per requested ID, in the order the IDs were given.
    @raise HttpError: 422 if an ID is not an integer, 404 if a DAG run does not exist.
    """
    user, _teams = request.auth  # teams are not needed here
    if attrs is None:
        attrs = []
    logger.info(f"Getting DAG run(s): {dag_run_ids} for user: {user.email}")
    try:
        dag_run_ids_parsed = [int(dag_run_id) for dag_run_id in dag_run_ids.split(",")]
    except ValueError:
        # A malformed ID is a client error, not a server error.
        raise HttpError(
            422,
            f"dag_run_ids must be a comma-separated list of integers, got: {dag_run_ids}",
        ) from None

    dag_runs_by_id = {
        dag_run.id: dag_run
        async for dag_run in DAGRun.objects.filter(id__in=dag_run_ids_parsed)
    }
    # Fail on the first missing ID (in request order) before running the heavier queries.
    for dag_run_id in dag_run_ids_parsed:
        if dag_run_id not in dag_runs_by_id:
            logger.warning(f"DAG run with ID {dag_run_id} does not exist.")
            raise HttpError(404, f"DAG run with ID {dag_run_id} does not exist.")

    # Index nodes and attributes once so each requested run is assembled by
    # lookup rather than by rescanning every row.
    nodes_by_run = collections.defaultdict(list)
    for node in await django_utils.alist(
        NodeRun.objects.filter(dag_run_id__in=dag_run_ids_parsed)
    ):
        nodes_by_run[node.dag_run_id].append(node)

    attributes_by_run_and_node = collections.defaultdict(list)
    for attribute in await django_utils.amap(
        NodeRunAttributeOut.from_orm,
        NodeRunAttribute.objects.filter(dag_run_id__in=dag_run_ids_parsed, type__in=attrs),
    ):
        attributes_by_run_and_node[(attribute.dag_run, attribute.node_name)].append(attribute)

    all_results = []
    for dag_run_id in dag_run_ids_parsed:
        out = DAGRunOutWithData.from_data(dag_runs_by_id[dag_run_id], [])
        out.node_runs.extend(
            NodeRunOutWithAttributes.from_data(
                node,
                attributes=attributes_by_run_and_node.get((dag_run_id, node.node_name), []),
            )
            for node in nodes_by_run.get(dag_run_id, [])
        )
        all_results.append(out)
    logger.info(f"Got DAG run(s): {dag_run_ids}")
    return all_results
@router.put("/v1/dag_runs/{dag_run_id}/", response=DAGRunOut, tags=["run_tracking"])
@permission(user_can_write_to_dag_run)
async def update_dag_run(request, dag_run_id: int, dag_run: DAGRunUpdate) -> DAGRunOut:
    """Update an existing DAG run with the fields that were explicitly sent.

    Fields not set on the request body are left untouched. ``upsert_tags`` is
    special: its entries are merged into the run's existing ``tags`` (new keys
    added, existing keys overwritten) rather than replacing them.

    @param request: Incoming request; ``request.auth`` is unpacked as ``(user, teams)``.
    @param dag_run_id: ID of the DAG run to update.
    @param dag_run: Partial update; only explicitly set fields are applied.
    @return: The DAG run as stored after the update.
    @raise HttpError: 404 if the DAG run does not exist.
    """
    user, _teams = request.auth  # teams are not needed here
    logger.info(f"Updating DAG run: {dag_run_id} for user: {user.email}")
    try:
        dag_run_in_db = await DAGRun.objects.aget(id=dag_run_id)
    except DAGRun.DoesNotExist:
        raise HttpError(404, f"DAG run with ID {dag_run_id} does not exist.")
    for attr, value in dag_run.dict(exclude_unset=True).items():
        if attr == "upsert_tags":
            upsert_tags = {} if dag_run.upsert_tags is None else dag_run.upsert_tags
            # `or {}` guards against stored tags being None (model not visible here).
            dag_run_in_db.tags = {**(dag_run_in_db.tags or {}), **upsert_tags}
            # "upsert_tags" is a request-only instruction: do not also set it
            # as an attribute on the model instance.
            continue
        setattr(dag_run_in_db, attr, value)
    await dag_run_in_db.asave()
    logger.info(f"Updated DAG run: {dag_run_id}")
    # TODO -- unify this with the above, we're doing way too many db queries...
    # We should just be doing one, but as django doesn't appear to support postgres's
    # on_conflict for json fields and we want to upsert, we'll probably have to do 2
    saved = await DAGRun.objects.aget(id=dag_run_id)
    return DAGRunOut.from_orm(saved)
def process_task_updates(node_runs: List[NodeRunIn], dag_run_id: int) -> List[NodeRun]:
    """Collapse incoming node-run updates into one unsaved NodeRun per node name.

    Updates are applied in the order given. The first update for a node name
    creates the NodeRun; each later update for the same name overwrites the
    fields it explicitly sets on that NodeRun.

    TODO -- resolve conflicts between updates for the same node in a sane way
    (E.G. if one is SUCCESS and one is FAILURE). See TODO in README.

    @param node_runs: Incoming node-run updates.
    @param dag_run_id: ID of the DAG run that every resulting NodeRun is attached to.
    @return: The NodeRuns to save, ordered by first appearance of their node name.
    """
    merged_by_name = {}  # node_name -> NodeRun accumulated so far
    for update in node_runs:
        fields = {**update.dict(exclude_unset=True), "dag_run_id": dag_run_id}
        existing = merged_by_name.get(update.node_name)
        if existing is None:
            merged_by_name[update.node_name] = NodeRun(**fields)
            continue
        # Later update for an already-seen node: last writer wins per field.
        for field_name, field_value in fields.items():
            setattr(existing, field_name, field_value)
    return [*merged_by_name.values()]
@router.put("/v1/dag_runs_bulk", response=None, tags=["run_tracking"])
@permission(user_can_write_to_dag_run)
async def bulk_log(
    request,
    bulk_log_data: DagRunsBulkRequest,
    dag_run_id: int,
):
    """Record a batch of node-run updates and node-run attributes for one DAG run.

    Steps performed:
    1. Look up the DAG run, responding with 404 if it does not exist.
    2. Collapse the incoming task updates into one NodeRun per node name
       (see ``process_task_updates``).
    3. Bulk-upsert those NodeRuns: on a (dag_run_id, node_name) conflict only
       ``end_time`` and ``status`` are updated.
    4. Bulk-create one NodeRunAttribute per incoming attribute, attached to
       the DAG run via ``dag_run_id``.

    Steps 3 and 4 are separate writes; this function does not open a
    transaction around them.

    Design notes carried over from the original author: clients only need to
    know the DAG run ID (likely to become a client-generated UUID). The query
    patterns this is meant to serve are: all nodes/attributes for a DAG run,
    all attributes for a node, and all iterations of a previous attribute for
    that node.

    @param request: Incoming request.
    @param bulk_log_data: Payload carrying ``attributes`` and ``task_updates``.
    @param dag_run_id: ID of the DAG run being logged to.
    @return: None.
    @raise HttpError: 404 if the DAG run does not exist.
    """
    incoming_attributes = bulk_log_data.attributes
    incoming_task_updates = bulk_log_data.task_updates
    try:
        dag_run = await DAGRun.objects.aget(id=dag_run_id)
    except DAGRun.DoesNotExist:
        raise HttpError(404, f"DAG run with ID {dag_run_id} does not exist.")

    task_updates_to_save = process_task_updates(incoming_task_updates, dag_run_id=dag_run.id)
    # TODO -- determine if we can do this in one pass.
    # The upsert below is optimistic: it hopes a single pass over the list of
    # updates is enough. If it is not, the two fallback strategies are:
    #   > Query then merge
    #   > Break up by field (end_time/run_status are the only ones that should be updated)
    # Given that run_status goes RUNNING -> SUCCESS/FAILED, we know the status
    # transitions are safe, so updating everything in one pass may just work.
    # Possibly useful: https://github.com/SectorLabs/django-postgres-extra
    logger.info(f"Updating {len(task_updates_to_save)} task updates for dag run: {dag_run_id}")
    await NodeRun.objects.abulk_create(
        task_updates_to_save,
        update_conflicts=True,
        unique_fields=["dag_run_id", "node_name"],
        update_fields=["end_time", "status"],
    )
    logger.info(f"Updated {len(task_updates_to_save)} task updates for dag run: {dag_run_id}")

    node_attributes_to_create = [
        NodeRunAttribute(dag_run_id=dag_run.id, **attribute.dict())
        for attribute in incoming_attributes
    ]
    logger.info(
        f"Creating {len(node_attributes_to_create)} node attributes for dag run: {dag_run_id}"
    )
    await NodeRunAttribute.objects.abulk_create(node_attributes_to_create)
    logger.info(
        f"Created {len(node_attributes_to_create)} node attributes for dag run: {dag_run_id}"
    )
@router.get(
    "/v1/node_runs/by_run_ids/{dag_run_ids}",
    response=List[Optional[NodeRunOutWithAttributes]],
    tags=["run_tracking", "cross-runs", "FE-only", "node_tracking"],
)
@permission(user_can_get_dag_runs)
async def get_node_run_for_dags(
    request,
    dag_run_ids: str,
    node_name: str,
) -> List[Optional[NodeRunOutWithAttributes]]:
    """Fetch the run of one named node, with all its attributes, across several DAG runs.

    This enables the DAG viz in the UI to compare a node across runs. No
    attribute-type filter is applied: every attribute of the node is returned.

    @param request: Incoming request; ``request.auth`` is unpacked as ``(user, teams)``.
    @param dag_run_ids: Comma-separated list of integer DAG run IDs.
    @param node_name: Name of the node.
    @return: One entry per requested DAG run ID, in request order; ``None``
        where that DAG run has no run of the node.
    @raise HttpError: 422 if an ID is not an integer.
    """
    user, _teams = request.auth  # teams are not needed here
    logger.info(
        f"Getting node runs for node: {node_name} for dag runs: {dag_run_ids} for all attributes for user: {user.email}"
    )
    try:
        dag_run_ids_parsed = [int(dag_run_id) for dag_run_id in dag_run_ids.split(",")]
    except ValueError:
        # A malformed ID is a client error, not a server error.
        raise HttpError(
            422,
            f"dag_run_ids must be a comma-separated list of integers, got: {dag_run_ids}",
        ) from None

    # Index attributes and nodes by DAG run once, instead of rescanning every
    # row for each requested ID.
    attributes_by_run = collections.defaultdict(list)
    for attribute in await django_utils.amap(
        NodeRunAttributeOut.from_orm,
        NodeRunAttribute.objects.filter(dag_run_id__in=dag_run_ids_parsed, node_name=node_name),
    ):
        attributes_by_run[attribute.dag_run].append(attribute)

    nodes_by_run = collections.defaultdict(list)
    for node in await django_utils.alist(
        NodeRun.objects.filter(dag_run_id__in=dag_run_ids_parsed, node_name=node_name)
    ):
        nodes_by_run[node.dag_run_id].append(node)

    out = []
    for dag_run_id in dag_run_ids_parsed:
        applicable_nodes = nodes_by_run.get(dag_run_id)
        if not applicable_nodes:
            out.append(None)
            continue
        # Exactly one node run per (DAG run, node name) is expected; the
        # unpacking raises if there is more than one.
        (node,) = applicable_nodes
        out.append(NodeRunOutWithAttributes.from_data(node, attributes_by_run.get(dag_run_id, [])))

    found = [item for item in out if item is not None]
    num_attributes = sum(len(item.attributes) for item in found)
    logger.info(
        f"Got {len(found)} nodes and {num_attributes} attributes for node: {node_name} "
        f"for dag runs: {dag_run_ids} for user: {user.email}"
    )
    return out
@router.get(
    "/v1/node_runs/{node_template_name}",
    response=CatalogZoomResponse,
    tags=["run_tracking", "cross-runs", "FE-only"],
)
@permission(user_can_get_project_by_id)
async def get_latest_template_runs(
    request,
    node_template_name: str,
    project_id: int,
    limit: int = 100,
    offset: int = 0,
) -> CatalogZoomResponse:
    """Get the latest runs of a node template, for expanding on the catalog view.

    Returns the most recent node runs with the given template name in the
    project (active DAG templates only), plus the matching node templates and
    their de-duplicated code artifacts.

    @param request: Incoming request; ``request.auth`` is unpacked as ``(user, teams)``.
    @param node_template_name: Name of the node template.
    @param project_id: Project to search in.
    @param limit: Maximum number of node runs to return; must be non-negative.
    @param offset: Number of node runs to skip; must be non-negative.
    @return: Node runs (newest first), node templates, and code artifacts.
    @raise HttpError: 422 if limit or offset is negative.
    """
    user, _teams = request.auth  # teams are not needed here
    logger.info(
        f"Getting latest runs for node template: {node_template_name} for user: {user.email} for project: {project_id}"
    )
    # Django querysets do not support negative slicing; reject it as a client
    # error instead of letting the slice below fail with a server error.
    if limit < 0 or offset < 0:
        raise HttpError(
            422,
            f"limit and offset must be non-negative, got: limit={limit}, offset={offset}",
        )
    node_runs = await django_utils.alist(
        NodeRun.objects.filter(
            node_template_name=node_template_name,
            dag_run__dag_template__project_id=project_id,
            dag_run__dag_template__is_active=True,
        )
        .prefetch_related("dag_run__dag_template", "dag_run")
        .order_by("-created_at")[offset : offset + limit]
    )
    all_dag_versions = [node_run.dag_run.dag_template_id for node_run in node_runs]
    # Breaking it into two cause our data model isn't perfect yet
    # consider making the foreign key to node template?
    node_templates = await django_utils.alist(
        NodeTemplate.objects.filter(
            dag_template_id__in=all_dag_versions, name=node_template_name
        ).prefetch_related("code_artifacts")
    )
    logger.info(
        f"Got latest runs for node template: {node_template_name} for user: {user.email} for project: {project_id}, "
        f"got {len(node_runs)} node runs and {len(node_templates)} node templates."
    )
    # De-duplicate artifacts shared by several templates. dict.fromkeys keeps
    # first-seen order, so the response ordering is stable (a set would not be).
    code_artifacts = list(
        dict.fromkeys(
            code_artifact
            for node_template in node_templates
            for code_artifact in node_template.code_artifacts.all()
        )
    )
    return CatalogZoomResponse(
        node_runs=[
            NodeRunOutWithExtraData.from_orm(node_run, dag_template_id=dag_template_id)
            for node_run, dag_template_id in zip(node_runs, all_dag_versions)
        ],
        node_templates=[
            NodeTemplateOut.from_orm(node_template) for node_template in node_templates
        ],
        code_artifacts=[
            CodeArtifactOut.from_orm(code_artifact) for code_artifact in code_artifacts
        ],
    )
# return node_runs
# @router.get(
# "/v1/node_run_attributes",
# response=DAGRunOut,
# tags=["run_tracking"]
# )
# async def historical_node_attributes(
# node_name: str,
# attribute_names: List[str],
# limit: int = 1000,
# offset: int = 0
# ) -> List[NodeRunAttributeOut]:
# """Queries prior attributes for a specific node.
# This is an historical query that can query all previous
# ones to give, say, a profile/histogram.
#
# @param node_name:
# @param attribute_name:
# @param limit:
# @param offset:
# @return:
# """
#
# logger.info(f"Getting {len(attribute_names)} attributes for node: {node_name}")
# attributes = await django_utils.amap(
# NodeRunAttributeOut.from_orm,
# NodeRunAttribute.objects.filter(
# node_name=node_name,
# name_in=attribute_names).order_by("-created_at")[offset:limit + offset].all()
# )
# logger.info(f"Got {len(attribute_names)} attributes for node: {node_name}")
# return attributes