import asyncio
import logging
import time
from typing import List, Optional
from common.django_utils import alist
from django.db.models import F, Window
from django.db.models.functions import RowNumber
from ninja import Router
from ninja.errors import HttpError
from trackingserver_base import blob_storage
from trackingserver_base.permissions.base import permission
from trackingserver_base.permissions.permissions import (
user_can_get_dag_template,
user_can_get_dag_templates,
user_can_get_project_by_id,
user_can_update_dag_template,
user_can_write_to_project,
)
from trackingserver_template.models import (
CodeArtifact,
DAGTemplate,
NodeTemplate,
NodeTemplateCodeArtifactRelation,
)
from trackingserver_template.schema import (
CatalogResponse,
CodeArtifactOut,
DAGTemplateIn,
DAGTemplateOut,
DAGTemplateOutWithData,
DAGTemplateUpdate,
NodeTemplateOut,
)
logger = logging.getLogger(__name__)
router = Router(tags=["projects"])
blob_store = blob_storage.get_blob_store()
@router.post("/v1/dag_templates", response=DAGTemplateOut, tags=["projects", "templates"])
@permission(user_can_write_to_project)
async def create_dag_template(
request, project_id: int, dag_template: DAGTemplateIn
) -> DAGTemplateOut:
"""Creates a project version for a given project.
@param request:
@param project_version_id:
@param dag_template:
@return:
"""
user, orgs = request.auth
logger.info(f"Creating DAG template for project version: {project_id} for {user.email}")
code_log = dag_template.code_log
if code_log is not None:
logger.info(f"Saving code log for project {project_id} for {user.email}")
code_log_url = await blob_store.write_obj("project" + str(project_id), code_log.dict())
logger.info(f"Stored code for project {project_id} for {user.email} at {code_log_url}")
code_log_store = blob_store.store()
code_log_schema_version = 1
else:
code_log_url = None
code_log_store = "none"
code_log_schema_version = None
dag_template_created = await DAGTemplate.objects.acreate(
project_id=project_id,
name=dag_template.name,
template_type=dag_template.template_type,
config=dag_template.config,
dag_hash=dag_template.dag_hash,
is_active=True,
tags=dag_template.tags,
code_hash=dag_template.code_hash,
code_version_info_type=dag_template.code_version_info_type,
code_version_info=dag_template.code_version_info,
code_version_info_schema=dag_template.code_version_info_schema,
code_log_store=code_log_store,
code_log_url=code_log_url,
code_log_schema_version=code_log_schema_version,
)
logger.info(f"Created DAG template for project : {project_id} for {user.email}")
code_artifacts_created = await CodeArtifact.objects.abulk_create(
[
CodeArtifact(**code_artifact.dict(), dag_template_id=dag_template_created.id)
for code_artifact in dag_template.code_artifacts
],
ignore_conflicts=False,
)
code_artifacts_by_name = {item.name: item for item in code_artifacts_created}
node_templates_to_create = []
through_models = []
for node in dag_template.nodes:
node_template = NodeTemplate(
**{key: value for key, value in node.dict().items() if key != "code_artifact_pointers"},
dag_template_id=dag_template_created.id,
)
node_templates_to_create.append(node_template)
created_nodes = await NodeTemplate.objects.abulk_create(node_templates_to_create)
logger.info(
f"Created {len(created_nodes)} nodes for project version {project_id} for {user.email}"
)
nodes_created_by_name = {item.name: item for item in created_nodes}
for node in dag_template.nodes:
for i, code_pointer in enumerate(node.code_artifact_pointers):
# TODO -- figure out what to do if this isn't there -- this should be an error, no?
# If so, we should validate earlier...
# For now we're going to log a warning
code_artifact = code_artifacts_by_name.get(code_pointer)
if code_artifact is None:
logger.warning(
f"Code artifact with name: {code_pointer} not found for node: {node.name}"
)
continue
through_models.append(
NodeTemplateCodeArtifactRelation(
node_template_id=nodes_created_by_name[node.name].id,
code_artifact_id=code_artifact.id,
is_primary=i == 0, # First one will be primary
)
)
await NodeTemplateCodeArtifactRelation.objects.abulk_create(through_models)
logger.info(
f"Created {len(created_nodes)} nodes & {len(code_artifacts_created)} artifacts "
f"for project version {project_id} for {user.email}. Created {len(through_models)} relations."
)
return DAGTemplateOut.from_orm(dag_template_created)
@router.get(
"/v1/dag_templates/exists/", response=Optional[DAGTemplateOut], tags=["projects", "templates"]
)
@permission(user_can_get_project_by_id)
async def dag_template_exists(
request,
dag_hash: str,
code_hash: str,
dag_name: str,
project_id: int,
):
"""Checks if a DAG template exists for a given project version and DAG hash.
@param request: The request
@param dag_hash: The DAG hash to check for
@param project_id: The project version to check for
@return: True if the DAG template exists, False otherwise.
"""
try:
logger.info(
f"Checking if DAG template exists for project version: {project_id} with hash: {dag_hash}"
)
dag_template = await DAGTemplate.objects.aget(
project_id=project_id, dag_hash=dag_hash, name=dag_name, code_hash=code_hash
)
return DAGTemplateOut.from_orm(dag_template)
except DAGTemplate.DoesNotExist:
logger.info(f"DAG template does not exist for project: {project_id} with hash: {dag_hash}")
return None
@router.get(
"/v1/dag_templates/latest/", response=List[DAGTemplateOut], tags=["projects", "templates"]
)
@permission(user_can_get_dag_templates)
async def get_latest_dag_templates(
request,
project_id: int,
limit: int = 100,
offset: int = 0,
) -> List[DAGTemplateOut]:
"""Gets all DAG templates for a given project version.
Note that this does not return the nodes, just the templates, as this is a bulk query.
@param request: The request
@param project_id: The project to get DAG templates for
@param limit: The maximum number of DAG templates to return
@param offset: The offset to start at
@return: A list of DAG templates
"""
user, orgs = request.auth
logger.info(f"Getting all DAG templates for project version: {project_id} for {user.email}")
out = [
item
async for item in DAGTemplate.objects.filter(
project_id=project_id, is_active=True
).order_by("-created_at")[offset : offset + limit]
]
logger.info(
f"Got all DAG templates for project version: {project_id} for {user.email}, retrieved {len(out)}"
)
return [DAGTemplateOut.from_orm(dag_template) for dag_template in out]
@router.get(
"/v1/dag_templates/catalog/",
response=CatalogResponse,
tags=["projects", "templates", "FE-only"],
)
@permission(user_can_get_dag_templates)
async def get_dag_template_catalog(
request,
    project_id: Optional[int] = None,
offset: int = 0,
limit: int = 1000,
) -> CatalogResponse:
"""Gets a massive list of the previous node templates -- note that this is a WIP, we may look for it to be distinct, but for
now we just want to get the last few.
Then, when someone opens it up, we'll do a full query for the corresponding node runs to display in the catalog.
This should probably be replaced by a materialized view, which can aggregate distinct across every one, grouping and
getting the first. That said, that's a little complicated.
@param offset:
@param limit:
@param project_id: ID of the associated project
@return: A list of node templates that have been associated with DAG templates with that project.
"""
# TODO -- see how slow this is...
user, orgs = request.auth
t1 = time.time()
logger.info(
f"Getting last {limit} DAG template catalog for project ID: {project_id} for user: {user.email}"
)
    # Aggregations
    # These will likely do fine, but the best way to do this is to create a materialized view with a catalog table.
    # That would run in the background for each project, ideally continually updating.
    # For now we aggregate on the fly and see how this scales.
    # We considered a few options -- option 3 (window functions) is what we use below.
# Option 1:
# We don't have an index on created_at, so I'm just using -id
# Note that this is... suboptimal -- we shouldn't just rely on monotonic IDs, but it'll do what I want for now
# qs = (
# NodeTemplate
# .objects
# .filter(**filter_kwargs)
# .order_by(
# "name",
# "-created_at"
# ).distinct("name")[offset:limit + offset]
# )
# all_node_templates = [item async for item in qs]
# print(await qs.aexplain())
# Option 2:
# Find the latest id for each name
# query = (
# NodeTemplate
# .objects
# .filter(**filter_kwargs)
# .values("name")
# .annotate(max_id=Max('id'))
# .values_list('max_id', flat=True)[offset: limit+offset])
# latest_ids = [
# item async for item in query
# ]
# all_node_templates = [
# item async for item in NodeTemplate.objects.filter(id__in=latest_ids)
# ]
    # Note we could also do option 2 with the Subquery() function, although it's not clear that's faster.
    # Option 3: utilize window functions -- this is what we use for now
# Define the window function
window = Window(expression=RowNumber(), partition_by=F("name"), order_by=F("id").desc())
# Annotate each NodeTemplate with its row number within its name partition
queryset = (
NodeTemplate.objects.filter(
dag_template__project_id=project_id, dag_template__is_active=True
)
.annotate(row_num=window)
.filter(row_num=1)
.prefetch_related("code_artifacts")[offset : limit + offset]
)
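    # Roughly equivalent SQL for the deduplication above (a sketch only -- the exact query
    # the ORM emits depends on the database backend; table/column names are illustrative,
    # and the slicing/prefetch are not shown):
    #   SELECT * FROM (
    #       SELECT nt.*,
    #              ROW_NUMBER() OVER (PARTITION BY nt.name ORDER BY nt.id DESC) AS row_num
    #       FROM node_template nt
    #       JOIN dag_template dt ON nt.dag_template_id = dt.id
    #       WHERE dt.project_id = %s AND dt.is_active
    #   ) ranked
    #   WHERE ranked.row_num = 1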
all_node_templates = [item async for item in queryset]
# Fetch the result
# print(await queryset.aexplain())
t2 = time.time()
    logger.info(
        f"Retrieved last {limit} node templates for project ID: {project_id} for user: {user.email} "
        f"in {t2 - t1} seconds. Got {len(all_node_templates)} nodes."
    )
# TODO -- see how slow this is
node_templates = [
NodeTemplateOut.from_orm(node_template) for node_template in all_node_templates
]
code_artifacts = list(
{
code_artifact
for node_template in all_node_templates
for code_artifact in node_template.code_artifacts.all()
}
)
out = CatalogResponse(
nodes=node_templates,
code_artifacts=code_artifacts,
)
logger.info(
f"Parsed last {limit} project ID: {project_id} for user: {user.email} in {time.time() - t2} "
f"seconds. Got {len(out.nodes)} nodes and {len(out.code_artifacts)} code artifacts."
)
return out
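# Example (illustrative -- the final URL prefix depends on where this router is mounted):
#   GET /v1/dag_templates/catalog/?project_id=1&limit=1000&offset=0
# returns a CatalogResponse body with "nodes" and "code_artifacts" lists.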
@router.get(
"/v1/dag_templates/{str:dag_template_ids}",
response=List[DAGTemplateOutWithData],
tags=["templates"],
)
@permission(user_can_get_dag_template)
async def get_full_dag_templates(request, dag_template_ids: str) -> List[DAGTemplateOutWithData]:
"""Gets the full DAG template, joined with the created nodes, for the given DAG template ID.
@param request: The request
@return: The full DAG template (with nodes).
"""
dag_template_ids_parsed = [int(item) for item in dag_template_ids.split(",")]
out = []
all_nodes = [
item
async for item in NodeTemplate.objects.filter(dag_template_id__in=dag_template_ids_parsed)
.prefetch_related("code_artifacts")
.all()
]
all_code_artifacts = [
item
async for item in CodeArtifact.objects.filter(dag_template_id__in=dag_template_ids_parsed)
]
all_dag_templates_retrieved = await alist(
DAGTemplate.objects.filter(id__in=dag_template_ids_parsed, is_active=True)
)
for dag_template_id in dag_template_ids_parsed:
relevant_dag_templates = [
item for item in all_dag_templates_retrieved if item.id == dag_template_id
]
if len(relevant_dag_templates) == 0:
logger.exception(f"No DAG template with ID: {dag_template_id}")
raise HttpError(status_code=404, message=f"No DAG template with ID: {dag_template_id}")
(dag_template_retrieved,) = relevant_dag_templates
user, orgs = request.auth
logger.info(f"Getting full DAG template for {user.email} with ID: {dag_template_id}")
nodes_in_template = [item for item in all_nodes if item.dag_template_id == dag_template_id]
code_artifacts_in_template = [
item for item in all_code_artifacts if item.dag_template_id == dag_template_id
]
# TODO -- use a bulk query
nodes_out = [NodeTemplateOut.from_orm(node) for node in nodes_in_template]
code_artifacts_out = [
CodeArtifactOut.from_orm(code_artifact) for code_artifact in code_artifacts_in_template
]
dag_template = DAGTemplateOut.from_orm(dag_template_retrieved)
logger.info(
f"Retrieved full DAG template for {user.email} "
f"with ID: {dag_template_id} with {len(nodes_out)} "
f"nodes and {len(code_artifacts_in_template)} code artifacts"
)
        # This is a little hacky as we gather everything in one asyncio.gather call.
        # We bind these values as default arguments because Python closures are late-binding --
        # without the defaults, every coroutine would see the values from the last loop iteration.
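        # Illustration of the late-binding pitfall (a sketch, not executed here):
        #   funcs = [lambda: i for i in range(3)]
        #   [f() for f in funcs]            # -> [2, 2, 2], not [0, 1, 2]
        #   funcs = [lambda i=i: i for i in range(3)]
        #   [f() for f in funcs]            # -> [0, 1, 2]; defaults capture the current value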
async def load_dag_template(
dag_template=dag_template, code_artifacts_out=code_artifacts_out, nodes_out=nodes_out
) -> Optional[DAGTemplateOutWithData]:
if dag_template.code_log_store == "none":
return None
            # TODO -- assert that the blob store matches the one we have available
code_log = await blob_store.read_obj(dag_template.code_log_url)
return DAGTemplateOutWithData(
**dag_template.dict(),
# TODO -- fix this -- this is due to something weird with the ID names in from_orm
code_artifacts=code_artifacts_out,
nodes=nodes_out,
code=code_log,
)
out.append(load_dag_template())
final_result = list(await asyncio.gather(*out))
return final_result
# This is due to django ninja being weird
# TODO -- find the right way to do this so we can have the same endpoint prefix
@router.put(
"/v1/update_dag_templates/{dag_template_id}",
response=DAGTemplateOut,
tags=["projects", "templates"],
)
@permission(user_can_update_dag_template)
async def update_dag_template(
request, dag_template_id: int, dag_template: DAGTemplateUpdate
) -> DAGTemplateOut:
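    """Updates a DAG template. Currently only the is_active flag is applied
    (e.g. to archive/unarchive a template).
    @param request: The request
    @param dag_template_id: ID of the DAG template to update
    @param dag_template: Fields to update
    @return: The updated DAG template
    """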
existing_template = await DAGTemplate.objects.aget(id=dag_template_id)
existing_template.is_active = dag_template.is_active
await existing_template.asave()
return DAGTemplateOut.from_orm(existing_template)