blob: f4ea4bdddf72bfe48300a40ed2ba3961f06ee335 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import TYPE_CHECKING, NamedTuple
from marshmallow import Schema, ValidationError, fields, validate, validates_schema
from marshmallow.utils import get_value
from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
from airflow.api_connexion.parameters import validate_istimezone
from airflow.api_connexion.schemas.common_schema import JsonObjectField
from airflow.api_connexion.schemas.enum_schemas import TaskInstanceStateField
from airflow.api_connexion.schemas.job_schema import JobSchema
from airflow.api_connexion.schemas.sla_miss_schema import SlaMissSchema
from airflow.api_connexion.schemas.trigger_schema import TriggerSchema
from airflow.models import TaskInstance
from airflow.utils.helpers import exactly_one
from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
from airflow.models import SlaMiss
class TaskInstanceSchema(SQLAlchemySchema):
    """Marshmallow schema for a task instance.

    Objects handed to this schema are indexed positionally by
    ``get_attribute``: element ``0`` supplies the task instance attributes
    and element ``1`` supplies the value dumped as ``sla_miss``.
    """

    class Meta:
        """Tie this schema to the ``TaskInstance`` model."""

        model = TaskInstance

    # Identification. ``run_id`` is exposed under the key ``dag_run_id``.
    task_id = auto_field()
    dag_id = auto_field()
    run_id = auto_field(data_key="dag_run_id")
    map_index = auto_field()

    # Timing and execution state.
    execution_date = auto_field()
    start_date = auto_field()
    end_date = auto_field()
    duration = auto_field()
    state = TaskInstanceStateField()
    try_number = auto_field()
    max_tries = auto_field()

    # Only ever dumped, never loaded.
    task_display_name = fields.String(attribute="task_display_name", dump_only=True)

    # Execution environment. ``queued_dttm`` is exposed as ``queued_when``.
    hostname = auto_field()
    unixname = auto_field()
    pool = auto_field()
    pool_slots = auto_field()
    queue = auto_field()
    priority_weight = auto_field()
    operator = auto_field()
    queued_dttm = auto_field(data_key="queued_when")
    pid = auto_field()
    executor_config = auto_field()
    note = auto_field()

    # ``sla_miss`` dumps as None and ``rendered_fields`` as {} when missing.
    sla_miss = fields.Nested(SlaMissSchema, dump_default=None)
    rendered_map_index = auto_field()
    rendered_fields = JsonObjectField(dump_default={})
    trigger = fields.Nested(TriggerSchema)
    triggerer_job = fields.Nested(JobSchema)

    def get_attribute(self, obj, attr, default):
        """Fetch the value for ``attr`` from the indexable ``obj``.

        :param obj: indexable pair; ``obj[0]`` is the task instance and
            ``obj[1]`` is the SLA miss.
        :param attr: name of the attribute being looked up.
        :param default: value handed to ``get_value`` as its fallback.
        :return: whatever ``get_value`` resolves for the requested attribute.
        """
        if attr == "sla_miss":
            # The SLA miss is the second element of the pair; wrap it in a
            # mapping so that get_value can resolve it by key.
            return get_value({"sla_miss": obj[1]}, attr, default)

        # Rendered fields are reached through a dotted path on the task
        # instance; every other attribute is looked up under its own name.
        lookup_path = (
            "rendered_task_instance_fields.rendered_fields" if attr == "rendered_fields" else attr
        )
        return get_value(obj[0], lookup_path, default)
class TaskInstanceCollection(NamedTuple):
    """List of task instances with metadata."""

    # Each entry pairs a TaskInstance with its SlaMiss, or with None.
    task_instances: list[tuple[TaskInstance, SlaMiss | None]]
    # Integer count reported alongside the list.
    # NOTE(review): presumably the total number of matching entries
    # regardless of paging -- confirm against the endpoint that builds this.
    total_entries: int
class TaskInstanceCollectionSchema(Schema):
    """Task instance collection schema."""

    # List whose items are each handled by a nested TaskInstanceSchema.
    task_instances = fields.List(fields.Nested(TaskInstanceSchema))
    # Integer count that accompanies the list.
    total_entries = fields.Int()
class TaskInstanceBatchFormSchema(Schema):
    """Schema for the request form passed to Task Instance Batch endpoint."""

    # Paging: the offset defaults to 0 and must be >= 0; the limit defaults
    # to 100 and must be >= 1.
    page_offset = fields.Int(load_default=0, validate=validate.Range(min=0))
    page_limit = fields.Int(load_default=100, validate=validate.Range(min=1))

    # Identifier filters; each is a list of strings and loads as None when
    # the key is absent.
    dag_ids = fields.List(fields.Str(), load_default=None)
    dag_run_ids = fields.List(fields.Str(), load_default=None)
    task_ids = fields.List(fields.Str(), load_default=None)

    # Date bounds; each loads as None when absent and is checked with
    # validate_istimezone.
    # NOTE(review): presumably that validator rejects naive datetimes --
    # confirm in airflow.api_connexion.parameters.
    execution_date_gte = fields.DateTime(load_default=None, validate=validate_istimezone)
    execution_date_lte = fields.DateTime(load_default=None, validate=validate_istimezone)
    start_date_gte = fields.DateTime(load_default=None, validate=validate_istimezone)
    start_date_lte = fields.DateTime(load_default=None, validate=validate_istimezone)
    end_date_gte = fields.DateTime(load_default=None, validate=validate_istimezone)
    end_date_lte = fields.DateTime(load_default=None, validate=validate_istimezone)

    # Integer duration bounds, None when absent.
    duration_gte = fields.Int(load_default=None)
    duration_lte = fields.Int(load_default=None)

    # Individual entries of the state list may themselves be None.
    state = fields.List(fields.Str(allow_none=True), load_default=None)
    pool = fields.List(fields.Str(), load_default=None)
    queue = fields.List(fields.Str(), load_default=None)
class ClearTaskInstanceFormSchema(Schema):
    """Request body schema for clearing the task instances of a Dag."""

    # dry_run and only_failed load as True when omitted; every other boolean
    # flag loads as False.
    dry_run = fields.Boolean(load_default=True)
    start_date = fields.DateTime(load_default=None, validate=validate_istimezone)
    end_date = fields.DateTime(load_default=None, validate=validate_istimezone)
    only_failed = fields.Boolean(load_default=True)
    only_running = fields.Boolean(load_default=False)
    include_subdags = fields.Boolean(load_default=False)
    include_parentdag = fields.Boolean(load_default=False)
    reset_dag_runs = fields.Boolean(load_default=False)
    # When supplied, the list must hold at least one task id.
    task_ids = fields.List(fields.String(), validate=validate.Length(min=1))
    dag_run_id = fields.Str(load_default=None)
    include_upstream = fields.Boolean(load_default=False)
    include_downstream = fields.Boolean(load_default=False)
    include_future = fields.Boolean(load_default=False)
    include_past = fields.Boolean(load_default=False)

    @validates_schema
    def validate_form(self, data, **kwargs):
        """Reject contradictory combinations in the loaded form.

        :param data: the loaded form values, read by key.
        :raises ValidationError: if ``only_failed`` and ``only_running`` are
            both set, if ``start_date`` is later than ``end_date``, or if
            ``dag_run_id`` is combined with ``start_date`` and/or ``end_date``.
        """
        if data["only_failed"] and data["only_running"]:
            raise ValidationError("only_failed and only_running both are set to True")

        begin = data["start_date"]
        finish = data["end_date"]
        if begin and finish and begin > finish:
            raise ValidationError("end_date is sooner than start_date")

        # With no date bound there is nothing dag_run_id could conflict with.
        if not (begin or finish):
            return
        if not data["dag_run_id"]:
            return

        # dag_run_id was given together with at least one date bound; report
        # the most specific conflict.
        if begin and finish:
            raise ValidationError("Exactly one of dag_run_id or (start_date and end_date) must be provided")
        if begin:
            raise ValidationError("Exactly one of dag_run_id or start_date must be provided")
        raise ValidationError("Exactly one of dag_run_id or end_date must be provided")
class SetTaskInstanceStateFormSchema(Schema):
    """Request body schema for setting the state of task instances of a DAG."""

    # Loads as True when the key is omitted.
    dry_run = fields.Boolean(load_default=True)
    task_id = fields.Str(required=True)
    # The run is selected by execution_date or by dag_run_id; validate_form
    # requires exactly one of the two.
    execution_date = fields.DateTime(validate=validate_istimezone)
    dag_run_id = fields.Str()
    # All four include_* flags are required.
    include_upstream = fields.Boolean(required=True)
    include_downstream = fields.Boolean(required=True)
    include_future = fields.Boolean(required=True)
    include_past = fields.Boolean(required=True)
    # Only SUCCESS, FAILED and SKIPPED are accepted as the new state.
    new_state = TaskInstanceStateField(
        required=True,
        validate=validate.OneOf(
            [TaskInstanceState.SUCCESS, TaskInstanceState.FAILED, TaskInstanceState.SKIPPED]
        ),
    )

    @validates_schema
    def validate_form(self, data, **kwargs):
        """Require exactly one of ``execution_date`` and ``dag_run_id``.

        :param data: the loaded form values; both keys are optional.
        :raises ValidationError: if ``exactly_one`` rejects the two values.
        """
        run_selectors = (data.get("execution_date"), data.get("dag_run_id"))
        if exactly_one(*run_selectors):
            return
        raise ValidationError("Exactly one of execution_date or dag_run_id must be provided")
class SetSingleTaskInstanceStateFormSchema(Schema):
    """Schema for handling the request of updating state of a single task instance."""

    # Loads as True when the key is omitted.
    dry_run = fields.Boolean(load_default=True)
    # Required; only SUCCESS, FAILED and SKIPPED pass validation.
    new_state = TaskInstanceStateField(
        required=True,
        validate=validate.OneOf(
            [TaskInstanceState.SUCCESS, TaskInstanceState.FAILED, TaskInstanceState.SKIPPED]
        ),
    )
class TaskInstanceReferenceSchema(Schema):
    """Schema for the task instance reference schema."""

    task_id = fields.Str()
    # The run_id attribute is exposed under the key "dag_run_id".
    run_id = fields.Str(data_key="dag_run_id")
    dag_id = fields.Str()
    execution_date = fields.DateTime()
class TaskInstanceReferenceCollection(NamedTuple):
    """List of objects with metadata about taskinstance and dag_run_id."""

    # Annotated as (TaskInstance, str) pairs.
    # NOTE(review): the str is presumably the dag_run_id mentioned in the
    # docstring -- confirm against the code that builds this collection.
    task_instances: list[tuple[TaskInstance, str]]
class TaskInstanceReferenceCollectionSchema(Schema):
    """Collection schema for task reference."""

    # List whose items are each handled by a nested TaskInstanceReferenceSchema.
    task_instances = fields.List(fields.Nested(TaskInstanceReferenceSchema))
class SetTaskInstanceNoteFormSchema(Schema):
    """Schema for setting a note for a TaskInstance."""

    # Note: We can't add map_index to the url as subpaths can't start with dashes.
    # An explicit null is not accepted for map_index.
    map_index = fields.Int(allow_none=False)
    # The note may be null; a string value is limited to 1000 characters.
    note = fields.String(allow_none=True, validate=validate.Length(max=1000))
# Module-level schema instances, one per schema class defined above; each is
# constructed once, when this module is imported.
task_instance_schema = TaskInstanceSchema()
task_instance_collection_schema = TaskInstanceCollectionSchema()
task_instance_batch_form = TaskInstanceBatchFormSchema()
clear_task_instance_form = ClearTaskInstanceFormSchema()
set_task_instance_state_form = SetTaskInstanceStateFormSchema()
set_single_task_instance_state_form = SetSingleTaskInstanceStateFormSchema()
task_instance_reference_schema = TaskInstanceReferenceSchema()
task_instance_reference_collection_schema = TaskInstanceReferenceCollectionSchema()
set_task_instance_note_form_schema = SetTaskInstanceNoteFormSchema()