blob: 8a02d053578e37fe00753f6d7026411096b2d1f2 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import datetime
from traceback import format_exception
from typing import Any, Iterable
from sqlalchemy import Column, Integer, String, func, or_
from sqlalchemy.orm import Session, joinedload, relationship
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.models.base import Base
from airflow.models.taskinstance import TaskInstance
from airflow.triggers.base import BaseTrigger
from airflow.utils import timezone
from airflow.utils.retries import run_with_db_retries
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, with_row_locks
from airflow.utils.state import TaskInstanceState
class Trigger(Base):
    """
    Triggers are a workload that run in an asynchronous event loop shared with
    other Triggers, and fire off events that will unpause deferred Tasks,
    start linked DAGs, etc.

    They are persisted into the database and then re-hydrated into a
    "triggerer" process, where many are run at once. We model it so that
    there is a many-to-one relationship between Task and Trigger, for future
    deduplication logic to use.

    Rows will be evicted from the database when the triggerer detects no
    active Tasks/DAGs using them. Events are not stored in the database;
    when an Event is fired, the triggerer will directly push its data to the
    appropriate Task/DAG.
    """

    __tablename__ = "trigger"

    id = Column(Integer, primary_key=True)
    # Import path and constructor kwargs of the trigger class, as returned by
    # BaseTrigger.serialize() (see from_object).
    classpath = Column(String(1000), nullable=False)
    kwargs = Column(ExtendedJSON, nullable=False)
    created_date = Column(UtcDateTime, nullable=False)
    # Id of the triggerer Job this trigger is assigned to; NULL while unassigned.
    triggerer_id = Column(Integer, nullable=True)

    triggerer_job = relationship(
        "Job",
        primaryjoin="Job.id == Trigger.triggerer_id",
        foreign_keys=triggerer_id,
        uselist=False,
    )

    task_instance = relationship("TaskInstance", back_populates="trigger", lazy="joined", uselist=False)

    def __init__(
        self,
        classpath: str,
        kwargs: dict[str, Any],
        created_date: datetime.datetime | None = None,
    ) -> None:
        super().__init__()
        self.classpath = classpath
        self.kwargs = kwargs
        # Default to "now" (UTC) when no creation time is supplied.
        self.created_date = created_date or timezone.utcnow()

    @classmethod
    @internal_api_call
    def from_object(cls, trigger: BaseTrigger) -> Trigger:
        """
        Alternative constructor that creates a trigger row based directly
        off of a Trigger object.

        :param trigger: the trigger to persist; its ``serialize()`` output
            provides the stored classpath and kwargs.
        """
        classpath, kwargs = trigger.serialize()
        return cls(classpath=classpath, kwargs=kwargs)

    @classmethod
    @internal_api_call
    @provide_session
    def bulk_fetch(cls, ids: Iterable[int], session: Session = NEW_SESSION) -> dict[int, Trigger]:
        """
        Fetches all the Triggers by ID and returns a dict mapping
        ID -> Trigger instance.

        Each trigger's task instance, that task instance's trigger, and that
        trigger's triggerer job are eager-loaded in the same query.

        :param ids: ids of the triggers to fetch; ids with no row are omitted
            from the result.
        :param session: database session to use.
        """
        query = (
            session.query(cls)
            .filter(cls.id.in_(ids))
            .options(
                # Attribute-based loader chain: string relationship paths are
                # deprecated in SQLAlchemy 1.4 and removed in 2.0. Chaining
                # joinedload eager-loads every hop along the path.
                joinedload(cls.task_instance)
                .joinedload(TaskInstance.trigger)
                .joinedload(cls.triggerer_job)
            )
        )
        return {obj.id: obj for obj in query}

    @classmethod
    @internal_api_call
    @provide_session
    def clean_unused(cls, session: Session = NEW_SESSION) -> None:
        """Deletes all triggers that have no tasks dependent on them.

        Triggers have a one-to-many relationship to task instances, so we need
        to clean those up first. Afterwards we can drop the triggers not
        referenced by anyone.
        """
        # Update all task instances with trigger IDs that are not DEFERRED to remove them
        for attempt in run_with_db_retries():
            with attempt:
                session.query(TaskInstance).filter(
                    TaskInstance.state != TaskInstanceState.DEFERRED, TaskInstance.trigger_id.isnot(None)
                ).update({TaskInstance.trigger_id: None})
        # Get all triggers that have no task instances depending on them...
        ids = [
            trigger_id
            for (trigger_id,) in (
                session.query(cls.id)
                .join(TaskInstance, cls.id == TaskInstance.trigger_id, isouter=True)
                .group_by(cls.id)
                .having(func.count(TaskInstance.trigger_id) == 0)
            )
        ]
        # ...and delete them (we can't do this in one query due to MySQL).
        # Skip the DELETE entirely when there is nothing to remove.
        if ids:
            session.query(cls).filter(cls.id.in_(ids)).delete(synchronize_session=False)

    @classmethod
    @internal_api_call
    @provide_session
    def submit_event(cls, trigger_id: int, event, session: Session = NEW_SESSION) -> None:
        """
        Takes an event from an instance of itself, and triggers all dependent
        tasks to resume.

        Every DEFERRED task instance pointing at ``trigger_id`` gets
        ``event.payload`` stored under the ``"event"`` key of its
        ``next_kwargs``, is detached from the trigger and is set to SCHEDULED.
        """
        for task_instance in session.query(TaskInstance).filter(
            TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED
        ):
            # Add the event's payload into the kwargs for the task. Build a new
            # dict rather than mutating the loaded one in place, so the ORM
            # records the change even if the column does not track in-place
            # mutation.
            task_instance.next_kwargs = {**(task_instance.next_kwargs or {}), "event": event.payload}
            # Remove ourselves as its trigger
            task_instance.trigger_id = None
            # Finally, mark it as scheduled so it gets re-queued
            task_instance.state = TaskInstanceState.SCHEDULED

    @classmethod
    @internal_api_call
    @provide_session
    def submit_failure(
        cls, trigger_id: int, exc: BaseException | None = None, session: Session = NEW_SESSION
    ) -> None:
        """
        Called when a trigger has failed unexpectedly, and we need to mark
        everything that depended on it as failed. Notably, we have to actually
        run the failure code from a worker as it may have linked callbacks, so
        hilariously we have to re-schedule the task instances to a worker just
        so they can then fail.

        We use a special __fail__ value for next_method to achieve this that
        the runtime code understands as immediate-fail, and pack the error into
        next_kwargs.

        TODO: Once we have shifted callback (and email) handling to run on
        workers as first-class concepts, we can run the failure code here
        in-process, but we can't do that right now.

        :param trigger_id: id of the trigger that failed.
        :param exc: the exception raised by the trigger, if any; its formatted
            traceback (a list of strings) is stored in next_kwargs.
        :param session: database session to use.
        """
        # The traceback is identical for every dependent task, so format it once.
        formatted_traceback = (
            format_exception(type(exc), exc, exc.__traceback__) if exc is not None else None
        )
        for task_instance in session.query(TaskInstance).filter(
            TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED
        ):
            # Add the error and set the next_method to the fail state
            task_instance.next_method = "__fail__"
            task_instance.next_kwargs = {"error": "Trigger failure", "traceback": formatted_traceback}
            # Remove ourselves as its trigger
            task_instance.trigger_id = None
            # Finally, mark it as scheduled so it gets re-queued
            task_instance.state = TaskInstanceState.SCHEDULED

    @classmethod
    @internal_api_call
    @provide_session
    def ids_for_triggerer(cls, triggerer_id: int, session: Session = NEW_SESSION) -> list[int]:
        """Retrieves the ids of all triggers currently assigned to ``triggerer_id``."""
        return [
            trigger_id
            for (trigger_id,) in session.query(cls.id).filter(cls.triggerer_id == triggerer_id)
        ]

    @classmethod
    @internal_api_call
    @provide_session
    def assign_unassigned(cls, triggerer_id: int, capacity: int, session: Session = NEW_SESSION) -> None:
        """
        Takes a triggerer_id and the capacity for that triggerer and assigns unassigned
        triggers until that capacity is reached, or there are no more unassigned triggers.

        Triggers held by a triggerer that is no longer alive (see below) are
        treated as unassigned and may be taken over.
        """
        from airflow.jobs.job import Job  # To avoid circular import

        # Subtract what this triggerer already holds from its capacity.
        count = session.query(func.count(cls.id)).filter(cls.triggerer_id == triggerer_id).scalar()
        capacity -= count

        if capacity <= 0:
            return

        # A triggerer counts as alive if its job has not ended and it has
        # heartbeated within the last 30 seconds (hard-coded window).
        alive_triggerer_ids = [
            row[0]
            for row in session.query(Job.id).filter(
                Job.end_date.is_(None),
                Job.latest_heartbeat > timezone.utcnow() - datetime.timedelta(seconds=30),
                Job.job_type == "TriggererJob",
            )
        ]

        # Find triggers who do NOT have an alive triggerer_id, and then assign
        # up to `capacity` of those to us.
        trigger_ids_query = cls.get_sorted_triggers(
            capacity=capacity, alive_triggerer_ids=alive_triggerer_ids, session=session
        )
        if trigger_ids_query:
            session.query(cls).filter(cls.id.in_([i.id for i in trigger_ids_query])).update(
                {cls.triggerer_id: triggerer_id},
                synchronize_session=False,
            )
        # Commit even when nothing was assigned: this ends the transaction and
        # should release any row locks taken via with_row_locks in
        # get_sorted_triggers.
        session.commit()

    @classmethod
    def get_sorted_triggers(cls, capacity: int, alive_triggerer_ids: list[int], session: Session):
        """
        Return up to ``capacity`` trigger-id rows that are claimable, oldest first.

        A trigger is claimable when it has no triggerer or its triggerer is not
        in ``alive_triggerer_ids``. Rows are ordered by ``created_date`` and
        returned as query result rows with an ``id`` column. The query goes
        through ``with_row_locks`` with ``skip_locked=True``; presumably this
        keeps concurrent triggerers from claiming the same rows.
        """
        return with_row_locks(
            session.query(cls.id)
            # NULL must be tested separately: NOT IN on a NULL value is never true.
            .filter(or_(cls.triggerer_id.is_(None), cls.triggerer_id.notin_(alive_triggerer_ids)))
            .order_by(cls.created_date)
            .limit(capacity),
            session,
            skip_locked=True,
        ).all()