blob: 775ac759ffb649ce0b761ad2b03da1250ba31f53 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import itertools
import logging
import multiprocessing
import os
import signal
import sys
import time
import warnings
from collections import Counter
from dataclasses import dataclass
from datetime import datetime, timedelta
from functools import lru_cache, partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable, Iterator
from sqlalchemy import and_, func, not_, or_, text
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Query, Session, load_only, make_transient, selectinload
from sqlalchemy.sql import expression
from airflow import settings
from airflow.callbacks.callback_requests import DagCallbackRequest, SlaCallbackRequest, TaskCallbackRequest
from airflow.callbacks.pipe_callback_sink import PipeCallbackSink
from airflow.configuration import conf
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.executors.executor_loader import ExecutorLoader
from airflow.jobs.base_job_runner import BaseJobRunner
from airflow.jobs.job import Job, perform_heartbeat
from airflow.models.dag import DAG, DagModel
from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun
from airflow.models.dataset import (
DagScheduleDatasetReference,
DatasetDagRunQueue,
DatasetEvent,
DatasetModel,
TaskOutletDatasetReference,
)
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey
from airflow.stats import Stats
from airflow.ti_deps.dependencies_states import EXECUTION_STATES
from airflow.timetables.simple import DatasetTriggeredTimetable
from airflow.utils import timezone
from airflow.utils.event_scheduler import EventScheduler
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.sqlalchemy import (
CommitProhibitorGuard,
is_lock_not_available_error,
prohibit_commit,
skip_locked,
tuple_in_condition,
with_row_locks,
)
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import DagRunType
if TYPE_CHECKING:
from types import FrameType
from airflow.dag_processing.manager import DagFileProcessorAgent
TI = TaskInstance
DR = DagRun
DM = DagModel
@dataclass
class ConcurrencyMap:
    """
    Task-instance counts aggregated at three levels of granularity.

    The three maps are:

    - ``dag_active_tasks_map``: ``dag_id`` -> number of task instances
    - ``task_concurrency_map``: ``(dag_id, task_id)`` -> number of task instances
    - ``task_dagrun_concurrency_map``: ``(dag_id, run_id, task_id)`` -> number of task instances
    """

    dag_active_tasks_map: dict[str, int]
    task_concurrency_map: dict[tuple[str, str], int]
    task_dagrun_concurrency_map: dict[tuple[str, str, str], int]

    @classmethod
    def from_concurrency_map(cls, mapping: dict[tuple[str, str, str], int]) -> ConcurrencyMap:
        """
        Build all three maps from the most fine-grained one.

        :param mapping: counts keyed by ``(dag_id, run_id, task_id)``
        :return: a ``ConcurrencyMap`` whose three maps are ``Counter`` objects, so a key
            that was never counted reads as 0
        """
        per_dag: Counter = Counter()
        per_task: Counter = Counter()
        # Roll the per-run counts up to the DAG level and to the (DAG, task) level.
        for (dag_id, _run_id, task_id), count in mapping.items():
            per_dag[dag_id] += count
            per_task[(dag_id, task_id)] += count
        return cls(per_dag, per_task, Counter(mapping))
def _is_parent_process() -> bool:
    """
    Tell whether the code is running in the parent process.

    The check is made on the name of the current ``multiprocessing`` process: it returns
    True when that name is ``"MainProcess"`` and False otherwise (e.g. in a child
    process started by multiprocessing).
    """
    current = multiprocessing.current_process()
    return current.name == "MainProcess"
class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
"""
SchedulerJobRunner runs for a specific time interval and schedules jobs that are ready to run.
It figures out the latest runs for each task and sees if the dependencies
for the next schedules are met.
If so, it creates appropriate TaskInstances and sends run commands to the
executor. It does this for each task in each DAG and repeats.
:param subdir: directory containing Python files with Airflow DAG
definitions, or a specific path to a file
:param num_runs: The number of times to run the scheduling loop. If you
have a large number of DAG files this could complete before each file
has been parsed. -1 for unlimited times.
:param num_times_parse_dags: The number of times to try to parse each DAG file.
-1 for unlimited times.
:param scheduler_idle_sleep_time: The number of seconds to wait between
polls of running processors
:param do_pickle: once a DAG object is obtained by executing the Python
file, whether to serialize the DAG object to the DB
:param log: override the default Logger
"""
job_type = "SchedulerJob"
heartrate: int = conf.getint("scheduler", "SCHEDULER_HEARTBEAT_SEC")
def __init__(
    self,
    job: Job,
    subdir: str = settings.DAGS_FOLDER,
    num_runs: int = conf.getint("scheduler", "num_runs"),
    num_times_parse_dags: int = -1,
    scheduler_idle_sleep_time: float = conf.getfloat("scheduler", "scheduler_idle_sleep_time"),
    do_pickle: bool = False,
    log: logging.Logger | None = None,
    processor_poll_interval: float | None = None,
):
    """
    Store the constructor arguments and read scheduler settings from the configuration.

    :param job: job object handed to the base runner
    :param subdir: DAG folder (or file) used for the DagBag and the DAG file processor
    :param num_runs: number of scheduler loops to run; the loop only exits on this count
        when the value is greater than 0
    :param num_times_parse_dags: how many times each DAG file is parsed (testing aid)
    :param scheduler_idle_sleep_time: upper bound, in seconds, for the idle sleep between loops
    :param do_pickle: whether DAG pickling is requested
    :param log: optional logger that replaces the default one
    :param processor_poll_interval: deprecated alias for ``scheduler_idle_sleep_time``;
        when truthy it wins and a ``RemovedInAirflow3Warning`` is emitted
    """
    super().__init__(job)
    self.subdir = subdir
    self.num_runs = num_runs
    # In specific tests, we want to stop the parse loop after the _files_ have been parsed a certain
    # number of times. This is only to support testing, and isn't something a user is likely to want to
    # configure -- they'll want num_runs
    self.num_times_parse_dags = num_times_parse_dags

    if processor_poll_interval:
        # TODO: Remove in Airflow 3.0
        warnings.warn(
            "The 'processor_poll_interval' parameter is deprecated. "
            "Please use 'scheduler_idle_sleep_time'.",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        self._scheduler_idle_sleep_time = processor_poll_interval
    else:
        self._scheduler_idle_sleep_time = scheduler_idle_sleep_time

    # How many seconds do we wait for tasks to heartbeat before mark them as zombies.
    self._zombie_threshold_secs = conf.getint("scheduler", "scheduler_zombie_task_threshold")
    self._standalone_dag_processor = conf.getboolean("scheduler", "standalone_dag_processor")
    self._dag_stale_not_seen_duration = conf.getint("scheduler", "dag_stale_not_seen_duration")

    # stalled_task_timeout, task_adoption_timeout and worker_pods_pending_timeout are all
    # superseded by the single task_queued_timeout option, so they cannot be deprecated the
    # usual way. Each one is still read (warning when set) and the largest of all four values
    # is used, so a legitimate use of any legacy option is never undercut.
    # TODO: Remove in Airflow 3.0
    deprecated_timeout_options = (
        (
            "celery",
            "stalled_task_timeout",
            "The '[celery] stalled_task_timeout' config option is deprecated. "
            "Please update your config to use '[scheduler] task_queued_timeout' instead.",
        ),
        (
            "celery",
            "task_adoption_timeout",
            "The '[celery] task_adoption_timeout' config option is deprecated. "
            "Please update your config to use '[scheduler] task_queued_timeout' instead.",
        ),
        (
            "kubernetes_executor",
            "worker_pods_pending_timeout",
            "The '[kubernetes_executor] worker_pods_pending_timeout' config option is deprecated. "
            "Please update your config to use '[scheduler] task_queued_timeout' instead.",
        ),
    )
    timeout_candidates = []
    for section, option, deprecation_message in deprecated_timeout_options:
        legacy_value = conf.getfloat(section, option, fallback=0)
        if legacy_value:
            warnings.warn(deprecation_message, DeprecationWarning)
        timeout_candidates.append(legacy_value)
    timeout_candidates.append(conf.getfloat("scheduler", "task_queued_timeout"))
    self._task_queued_timeout = max(*timeout_candidates)

    self.do_pickle = do_pickle

    if log:
        self._log = log

    # Check what SQL backend we use
    sql_conn: str = conf.get_mandatory_value("database", "sql_alchemy_conn").lower()
    self.using_sqlite = sql_conn.startswith("sqlite")

    # Dag Processor agent - not used in Dag Processor standalone mode.
    self.processor_agent: DagFileProcessorAgent | None = None

    self.dagbag = DagBag(dag_folder=self.subdir, read_dags_from_db=True, load_op_links=False)
    self._paused_dag_without_running_dagruns: set = set()
@provide_session
def heartbeat_callback(self, session: Session = NEW_SESSION) -> None:
    """
    Emit the ``scheduler_heartbeat`` counter metric.

    The scheduler loop passes this method to ``perform_heartbeat`` as its
    ``heartbeat_callback``.

    :param session: session supplied through ``@provide_session``; the body does not use it
    """
    Stats.incr("scheduler_heartbeat", 1, 1)
def register_signals(self) -> None:
    """
    Install the scheduler's signal handlers.

    SIGINT and SIGTERM trigger a graceful exit that stops child processes; SIGUSR2
    triggers a debug dump.
    """
    handlers = (
        (signal.SIGINT, self._exit_gracefully),
        (signal.SIGTERM, self._exit_gracefully),
        (signal.SIGUSR2, self._debug_dump),
    )
    for signum, handler in handlers:
        signal.signal(signum, handler)
def _exit_gracefully(self, signum: int, frame: FrameType | None) -> None:
    """
    Signal handler: shut down the processor agent and exit so no orphan processes remain.

    :param signum: number of the received signal (only used for logging)
    :param frame: current stack frame as passed by the ``signal`` module; unused
    """
    # Only the parent process should perform the cleanup; children simply return.
    if _is_parent_process():
        self.log.info("Exiting gracefully upon receiving signal %s", signum)
        agent = self.processor_agent
        if agent:
            agent.end()
        sys.exit(os.EX_OK)
def _debug_dump(self, signum: int, frame: FrameType | None) -> None:
    """
    Signal handler: log a banner and ask the executor to dump its debug information.

    :param signum: number of the received signal; logged by name when it can be resolved
    :param frame: current stack frame as passed by the ``signal`` module; unused
    """
    # Only the parent process should perform the debug dump; children simply return.
    if _is_parent_process():
        try:
            sig_name = signal.Signals(signum).name
        except Exception:
            # Fall back to the raw number when the signal has no known name.
            sig_name = str(signum)

        separator = "-" * 80
        self.log.info("%s\n%s received, printing debug\n%s", separator, sig_name, separator)
        self.job.executor.debug_dump()
        self.log.info(separator)
def __get_concurrency_maps(self, states: Iterable[TaskInstanceState], session: Session) -> ConcurrencyMap:
    """
    Count task instances in the given states and aggregate them into a ConcurrencyMap.

    :param states: task instance states to count
    :param session: session used to run the count query
    :return: concurrency map built from the per-(dag_id, run_id, task_id) counts
    """
    grouping_columns = (TI.task_id, TI.run_id, TI.dag_id)
    count_rows: Query = (
        session.query(*grouping_columns, func.count("*"))
        .filter(TI.state.in_(states))
        .group_by(*grouping_columns)
    )
    counts: dict[tuple[str, str, str], int] = {}
    # Rows come back as (task_id, run_id, dag_id, count); the map is keyed (dag_id, run_id, task_id).
    for task_id, run_id, dag_id, count in count_rows:
        counts[(dag_id, run_id, task_id)] = count
    return ConcurrencyMap.from_concurrency_map(counts)
def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -> list[TI]:
    """
    Find TIs that are ready for execution based on conditions.

    Conditions include:

    - pool limits
    - DAG max_active_tasks
    - executor state
    - priority
    - max active tis per DAG
    - max active tis per DAG run

    The selected TIs are set to QUEUED in the database (with ``queued_dttm`` and
    ``queued_by_job_id``) and made transient before they are returned.

    :param max_tis: Maximum number of TIs to queue in this loop.
    :param session: session used for the row locks, the selection queries and the state update
    :return: list[airflow.models.TaskInstance]
    :raises OperationalError: on PostgreSQL when the advisory lock cannot be taken; errors
        raised by the locked selection query are re-raised as well
    """
    from airflow.models.pool import Pool
    from airflow.utils.db import DBLocks

    executable_tis: list[TI] = []

    if session.get_bind().dialect.name == "postgresql":
        # Optimization: to avoid littering the DB errors of "ERROR: canceling statement due to lock
        # timeout", try to take out a transactional advisory lock (unlocks automatically on
        # COMMIT/ROLLBACK)
        lock_acquired = session.execute(
            text("SELECT pg_try_advisory_xact_lock(:id)").bindparams(
                id=DBLocks.SCHEDULER_CRITICAL_SECTION.value
            )
        ).scalar()
        if not lock_acquired:
            # Throw an error like the one that would happen with NOWAIT
            raise OperationalError(
                "Failed to acquire advisory lock", params=None, orig=RuntimeError("55P03")
            )

    # Get the pool settings. We get a lock on the pool rows, treating this as a "critical section"
    # Throws an exception if lock cannot be obtained, rather than blocking
    pools = Pool.slots_stats(lock_rows=True, session=session)

    # If the pools are full, there is no point doing anything!
    # If _somehow_ the pool is overfull, don't let the limit go negative - it breaks SQL
    pool_slots_free = sum(max(0, pool["open"]) for pool in pools.values())

    if pool_slots_free == 0:
        self.log.debug("All pools are full!")
        return []

    # Never ask the database for more TIs than there are free pool slots overall.
    max_tis = min(max_tis, pool_slots_free)

    # Pools without any open slot are excluded from the selection query from the start.
    starved_pools = {pool_name for pool_name, stats in pools.items() if stats["open"] <= 0}

    # dag_id to # of running tasks and (dag_id, task_id) to # of running tasks.
    concurrency_map = self.__get_concurrency_maps(states=EXECUTION_STATES, session=session)

    # Number of tasks that cannot be scheduled because of no open slot in pool
    num_starving_tasks_total = 0

    # dag and task ids that can't be queued because of concurrency limits
    starved_dags: set[str] = set()
    starved_tasks: set[tuple[str, str]] = set()
    starved_tasks_task_dagrun_concurrency: set[tuple[str, str, str]] = set()

    pool_num_starving_tasks: dict[str, int] = Counter()

    # Each iteration re-runs the selection query with the exclusion filters collected so far.
    for loop_count in itertools.count(start=1):
        # Snapshot the filter sizes so we can tell below whether this pass added any filter.
        num_starved_pools = len(starved_pools)
        num_starved_dags = len(starved_dags)
        num_starved_tasks = len(starved_tasks)
        num_starved_tasks_task_dagrun_concurrency = len(starved_tasks_task_dagrun_concurrency)

        # Get task instances associated with scheduled
        # DagRuns which are not backfilled, in the given states,
        # and the dag is not paused
        query = (
            session.query(TI)
            .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
            .join(TI.dag_run)
            .filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state == DagRunState.RUNNING)
            .join(TI.dag_model)
            .filter(not_(DM.is_paused))
            .filter(TI.state == TaskInstanceState.SCHEDULED)
            .options(selectinload("dag_model"))
            .order_by(-TI.priority_weight, DR.execution_date, TI.map_index)
        )

        if starved_pools:
            query = query.filter(not_(TI.pool.in_(starved_pools)))

        if starved_dags:
            query = query.filter(not_(TI.dag_id.in_(starved_dags)))

        if starved_tasks:
            task_filter = tuple_in_condition((TI.dag_id, TI.task_id), starved_tasks)
            query = query.filter(not_(task_filter))

        if starved_tasks_task_dagrun_concurrency:
            task_filter = tuple_in_condition(
                (TI.dag_id, TI.run_id, TI.task_id),
                starved_tasks_task_dagrun_concurrency,
            )
            query = query.filter(not_(task_filter))

        query = query.limit(max_tis)

        # The timer is started and stopped by hand so the duration is only sent on success.
        timer = Stats.timer("scheduler.critical_section_query_duration")
        timer.start()

        try:
            task_instances_to_examine: list[TI] = with_row_locks(
                query,
                of=TI,
                session=session,
                **skip_locked(session=session),
            ).all()
            timer.stop(send=True)
        except OperationalError as e:
            timer.stop(send=False)
            raise e

        # TODO[HA]: This was wrong before anyway, as it only looked at a sub-set of dags, not everything.
        # Stats.gauge('scheduler.tasks.pending', len(task_instances_to_examine))

        if len(task_instances_to_examine) == 0:
            self.log.debug("No tasks to consider for execution.")
            break

        # Put one task instance on each line
        task_instance_str = "\n\t".join(repr(x) for x in task_instances_to_examine)
        self.log.info(
            "%s tasks up for execution:\n\t%s", len(task_instances_to_examine), task_instance_str
        )

        for task_instance in task_instances_to_examine:
            pool_name = task_instance.pool

            pool_stats = pools.get(pool_name)
            if not pool_stats:
                self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool_name)
                starved_pools.add(pool_name)
                continue

            # Make sure to emit metrics if pool has no starving tasks
            pool_num_starving_tasks.setdefault(pool_name, 0)

            pool_total = pool_stats["total"]
            open_slots = pool_stats["open"]

            if open_slots <= 0:
                self.log.info(
                    "Not scheduling since there are %s open slots in pool %s", open_slots, pool_name
                )
                # Can't schedule any more since there are no more open slots.
                pool_num_starving_tasks[pool_name] += 1
                num_starving_tasks_total += 1
                starved_pools.add(pool_name)
                continue

            if task_instance.pool_slots > pool_total:
                self.log.warning(
                    "Not executing %s. Requested pool slots (%s) are greater than "
                    "total pool slots: '%s' for pool: %s.",
                    task_instance,
                    task_instance.pool_slots,
                    pool_total,
                    pool_name,
                )

                pool_num_starving_tasks[pool_name] += 1
                num_starving_tasks_total += 1
                starved_tasks.add((task_instance.dag_id, task_instance.task_id))
                continue

            if task_instance.pool_slots > open_slots:
                self.log.info(
                    "Not executing %s since it requires %s slots "
                    "but there are %s open slots in the pool %s.",
                    task_instance,
                    task_instance.pool_slots,
                    open_slots,
                    pool_name,
                )
                pool_num_starving_tasks[pool_name] += 1
                num_starving_tasks_total += 1
                starved_tasks.add((task_instance.dag_id, task_instance.task_id))
                # Though we can execute tasks with lower priority if there's enough room
                continue

            # Check to make sure that the task max_active_tasks of the DAG hasn't been
            # reached.
            dag_id = task_instance.dag_id

            current_active_tasks_per_dag = concurrency_map.dag_active_tasks_map[dag_id]
            max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks
            self.log.info(
                "DAG %s has %s/%s running and queued tasks",
                dag_id,
                current_active_tasks_per_dag,
                max_active_tasks_per_dag_limit,
            )
            if current_active_tasks_per_dag >= max_active_tasks_per_dag_limit:
                self.log.info(
                    "Not executing %s since the number of tasks running or queued "
                    "from DAG %s is >= to the DAG's max_active_tasks limit of %s",
                    task_instance,
                    dag_id,
                    max_active_tasks_per_dag_limit,
                )
                starved_dags.add(dag_id)
                continue

            if task_instance.dag_model.has_task_concurrency_limits:
                # Many dags don't have a task_concurrency, so where we can avoid loading the full
                # serialized DAG the better.
                serialized_dag = self.dagbag.get_dag(dag_id, session=session)
                # If the dag is missing, fail the task and continue to the next task.
                if not serialized_dag:
                    self.log.error(
                        "DAG '%s' for task instance %s not found in serialized_dag table",
                        dag_id,
                        task_instance,
                    )
                    # Note: this fails every SCHEDULED TI of the DAG, not only the current one.
                    session.query(TI).filter(
                        TI.dag_id == dag_id, TI.state == TaskInstanceState.SCHEDULED
                    ).update({TI.state: TaskInstanceState.FAILED}, synchronize_session="fetch")
                    continue

                # Per-task limit across all runs of the DAG (max_active_tis_per_dag).
                task_concurrency_limit: int | None = None
                if serialized_dag.has_task(task_instance.task_id):
                    task_concurrency_limit = serialized_dag.get_task(
                        task_instance.task_id
                    ).max_active_tis_per_dag

                if task_concurrency_limit is not None:
                    current_task_concurrency = concurrency_map.task_concurrency_map[
                        (task_instance.dag_id, task_instance.task_id)
                    ]

                    if current_task_concurrency >= task_concurrency_limit:
                        self.log.info(
                            "Not executing %s since the task concurrency for"
                            " this task has been reached.",
                            task_instance,
                        )
                        starved_tasks.add((task_instance.dag_id, task_instance.task_id))
                        continue

                # Per-task limit within a single DAG run (max_active_tis_per_dagrun).
                task_dagrun_concurrency_limit: int | None = None
                if serialized_dag.has_task(task_instance.task_id):
                    task_dagrun_concurrency_limit = serialized_dag.get_task(
                        task_instance.task_id
                    ).max_active_tis_per_dagrun

                if task_dagrun_concurrency_limit is not None:
                    current_task_dagrun_concurrency = concurrency_map.task_dagrun_concurrency_map[
                        (task_instance.dag_id, task_instance.run_id, task_instance.task_id)
                    ]

                    if current_task_dagrun_concurrency >= task_dagrun_concurrency_limit:
                        self.log.info(
                            "Not executing %s since the task concurrency per DAG run for"
                            " this task has been reached.",
                            task_instance,
                        )
                        starved_tasks_task_dagrun_concurrency.add(
                            (task_instance.dag_id, task_instance.run_id, task_instance.task_id)
                        )
                        continue

            # All checks passed: select the TI and account for it in the in-memory counters so
            # the remaining candidates of this pass are checked against the updated numbers.
            executable_tis.append(task_instance)
            open_slots -= task_instance.pool_slots
            concurrency_map.dag_active_tasks_map[dag_id] += 1
            concurrency_map.task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1
            concurrency_map.task_dagrun_concurrency_map[
                (task_instance.dag_id, task_instance.run_id, task_instance.task_id)
            ] += 1

            pool_stats["open"] = open_slots

        # Done when something was selected, or when the query returned fewer rows than the limit
        # (so there is nothing further to fetch).
        is_done = executable_tis or len(task_instances_to_examine) < max_tis
        # Check this to avoid accidental infinite loops
        found_new_filters = (
            len(starved_pools) > num_starved_pools
            or len(starved_dags) > num_starved_dags
            or len(starved_tasks) > num_starved_tasks
            or len(starved_tasks_task_dagrun_concurrency) > num_starved_tasks_task_dagrun_concurrency
        )

        if is_done or not found_new_filters:
            break

        self.log.info(
            "Found no task instances to queue on query iteration %s "
            "but there could be more candidate task instances to check.",
            loop_count,
        )

    for pool_name, num_starving_tasks in pool_num_starving_tasks.items():
        Stats.gauge(f"pool.starving_tasks.{pool_name}", num_starving_tasks)

    Stats.gauge("scheduler.tasks.starving", num_starving_tasks_total)
    Stats.gauge("scheduler.tasks.executable", len(executable_tis))

    if len(executable_tis) > 0:
        task_instance_str = "\n\t".join(repr(x) for x in executable_tis)
        self.log.info("Setting the following tasks to queued state:\n\t%s", task_instance_str)

        # set TIs to queued state
        filter_for_tis = TI.filter_for_tis(executable_tis)
        session.query(TI).filter(filter_for_tis).update(
            # TODO[ha]: should we use func.now()? How does that work with DB timezone
            # on mysql when it's not UTC?
            {
                TI.state: TaskInstanceState.QUEUED,
                TI.queued_dttm: timezone.utcnow(),
                TI.queued_by_job_id: self.job.id,
            },
            synchronize_session=False,
        )

        for ti in executable_tis:
            ti.emit_state_change_metric(State.QUEUED)

    # Detach the returned objects from the session before handing them to the caller.
    for ti in executable_tis:
        make_transient(ti)
    return executable_tis
def _enqueue_task_instances_with_queued_state(self, task_instances: list[TI], session: Session) -> None:
    """
    Hand task instances that were already set to queued over to the executor.

    A task instance whose DAG run is in a finished state is not sent; its state is
    reset to ``State.NONE`` instead.

    :param task_instances: TaskInstances to enqueue
    :param session: The session object
    """
    for ti in task_instances:
        if ti.dag_run.state in State.finished:
            ti.set_state(State.NONE, session=session)
        else:
            # actually enqueue it
            command = ti.command_as_list(local=True, pickle_id=ti.dag_model.pickle_id)
            priority, queue = ti.priority_weight, ti.queue
            self.log.info("Sending %s to executor with priority %s and queue %s", ti.key, priority, queue)
            self.job.executor.queue_command(ti, command, priority=priority, queue=queue)
def _critical_section_enqueue_task_instances(self, session: Session) -> int:
    """
    Enqueues TaskInstances for execution.

    There are three steps:

    1. Pick TIs by priority with the constraint that they are in the expected states
       and that we do not exceed max_active_runs or pool limits.
    2. Change the state for the TIs above atomically.
    3. Enqueue the TIs in the executor.

    HA note: This function is a "critical section" meaning that only a single executor process can execute
    this function at the same time. This is achieved by doing ``SELECT ... from pool FOR UPDATE``. For DBs
    that support NOWAIT, a "blocked" scheduler will skip this and continue on with other tasks (creating
    new DAG runs, progressing TIs from None to SCHEDULED etc.); DBs that don't support this (such as
    MariaDB or MySQL 5.x) the other schedulers will wait for the lock before continuing.

    :param session: session used for the selection and the state change
    :return: Number of task instance with state changed.
    """
    per_query_cap = self.job.max_tis_per_query
    free_executor_slots = self.job.executor.slots_available
    # A cap of 0 means "no per-query cap": only the executor's free slots bound the batch.
    max_tis = free_executor_slots if per_query_cap == 0 else min(per_query_cap, free_executor_slots)

    queued_tis = self._executable_task_instances_to_queued(max_tis, session=session)
    self._enqueue_task_instances_with_queued_state(queued_tis, session=session)
    return len(queued_tis)
def _process_executor_events(self, session: Session) -> int:
    """
    Respond to executor events.

    Reads the executor's event buffer, records the external executor id for QUEUED events
    and, for FAILED/SUCCESS events whose TI still looks queued by this scheduler, treats the
    task as killed externally (failure callback request or ``handle_failure``).

    :param session: session used to load and row-lock the affected task instances
    :return: ``len(event_buffer)`` at the time of returning. Entries handled in the loop are
        popped from the buffer first, so on the non-early-return path this is the number of
        entries that were left unhandled.
    :raises ValueError: if no processor agent is set and the standalone DAG processor is
        not in use
    """
    if not self._standalone_dag_processor and not self.processor_agent:
        raise ValueError("Processor agent is not started.")
    ti_primary_key_to_try_number_map: dict[tuple[str, str, str, int], int] = {}
    event_buffer = self.job.executor.get_event_buffer()
    tis_with_right_state: list[TaskInstanceKey] = []

    # Report execution
    for ti_key, value in event_buffer.items():
        state: str
        state, _ = value
        # Map the TI primary key (``ti_key.primary``) to the try_number carried by the event.
        # NOTE(review): per the annotation above the key is a (str, str, str, int) tuple --
        # presumably (dag_id, task_id, run_id, map_index); confirm against TaskInstanceKey.
        ti_primary_key_to_try_number_map[ti_key.primary] = ti_key.try_number

        self.log.info("Received executor event with state %s for task instance %s", state, ti_key)
        if state in (State.FAILED, State.SUCCESS, State.QUEUED):
            tis_with_right_state.append(ti_key)

    # Return if no finished tasks
    if not tis_with_right_state:
        return len(event_buffer)

    # Check state of finished tasks
    filter_for_tis = TI.filter_for_tis(tis_with_right_state)
    query = session.query(TI).filter(filter_for_tis).options(selectinload("dag_model"))
    # row lock this entire set of taskinstances to make sure the scheduler doesn't fail when we have
    # multi-schedulers
    tis: Iterator[TI] = with_row_locks(
        query,
        of=TI,
        session=session,
        **skip_locked(session=session),
    )
    for ti in tis:
        # Rebuild the buffer key with the try_number the executor reported, then consume the event.
        try_number = ti_primary_key_to_try_number_map[ti.key.primary]
        buffer_key = ti.key.with_try_number(try_number)
        state, info = event_buffer.pop(buffer_key)

        if state == State.QUEUED:
            # For QUEUED events the info payload is stored as the external executor id.
            ti.external_executor_id = info
            self.log.info("Setting external_id for %s to %s", ti, info)
            continue

        msg = (
            "TaskInstance Finished: dag_id=%s, task_id=%s, run_id=%s, map_index=%s, "
            "run_start_date=%s, run_end_date=%s, "
            "run_duration=%s, state=%s, executor_state=%s, try_number=%s, max_tries=%s, job_id=%s, "
            "pool=%s, queue=%s, priority_weight=%d, operator=%s, queued_dttm=%s, "
            "queued_by_job_id=%s, pid=%s"
        )
        self.log.info(
            msg,
            ti.dag_id,
            ti.task_id,
            ti.run_id,
            ti.map_index,
            ti.start_date,
            ti.end_date,
            ti.duration,
            ti.state,
            state,
            try_number,
            ti.max_tries,
            ti.job_id,
            ti.pool,
            ti.queue,
            ti.priority_weight,
            ti.operator,
            ti.queued_dttm,
            ti.queued_by_job_id,
            ti.pid,
        )

        # There are two scenarios why the same TI with the same try_number is queued
        # after executor is finished with it:
        # 1) the TI was killed externally and it had no time to mark itself failed
        # - in this case we should mark it as failed here.
        # 2) the TI has been requeued after getting deferred - in this case either our executor has it
        # or the TI is queued by another job. Either ways we should not fail it.

        # All of this could also happen if the state is "running",
        # but that is handled by the zombie detection.

        ti_queued = ti.try_number == buffer_key.try_number and ti.state == TaskInstanceState.QUEUED
        ti_requeued = (
            ti.queued_by_job_id != self.job.id  # Another scheduler has queued this task again
            or self.job.executor.has_task(ti)  # This scheduler has this task already
        )

        if ti_queued and not ti_requeued:
            Stats.incr(
                "scheduler.tasks.killed_externally",
                tags={"dag_id": ti.dag_id, "task_id": ti.task_id},
            )
            msg = (
                "Executor reports task instance %s finished (%s) although the "
                "task says its %s. (Info: %s) Was the task killed externally?"
            )
            self.log.error(msg, ti, state, ti.state, info)

            # Get task from the Serialized DAG
            try:
                dag = self.dagbag.get_dag(ti.dag_id)
                task = dag.get_task(ti.task_id)
            except Exception:
                # The task could not be loaded: just record the state the executor reported.
                self.log.exception("Marking task instance %s as %s", ti, state)
                ti.set_state(state)
                continue
            ti.task = task
            if task.on_retry_callback or task.on_failure_callback:
                # Callbacks are defined on the task: send a callback request through the executor.
                request = TaskCallbackRequest(
                    full_filepath=ti.dag_model.fileloc,
                    simple_task_instance=SimpleTaskInstance.from_ti(ti),
                    msg=msg % (ti, state, ti.state, info),
                    processor_subdir=ti.dag_model.processor_subdir,
                )
                self.job.executor.send_callback(request)
            else:
                ti.handle_failure(error=msg % (ti, state, ti.state, info), session=session)

    return len(event_buffer)
def _execute(self) -> int | None:
    """
    Start the executor (and the DAG file processor agent, if used) and run the scheduler loop.

    The executor and the processor agent are always ended in the ``finally`` clause, whether
    the loop finished normally or raised.

    :return: always None
    """
    from airflow.dag_processing.manager import DagFileProcessorAgent

    self.log.info("Starting the scheduler")

    executor_class, _ = ExecutorLoader.import_default_executor_cls()

    # DAGs can be pickled for easier remote execution by some executors
    pickle_dags = self.do_pickle and executor_class.supports_pickling

    self.log.info("Processing each file at most %s times", self.num_times_parse_dags)

    # When using sqlite, we do not use async_mode
    # so the scheduler job and DAG parser don't access the DB at the same time.
    async_mode = not self.using_sqlite

    processor_timeout_seconds: int = conf.getint("core", "dag_file_processor_timeout")
    processor_timeout = timedelta(seconds=processor_timeout_seconds)
    # The agent is only created when the DAG processor is not run standalone.
    if not self._standalone_dag_processor:
        self.processor_agent = DagFileProcessorAgent(
            dag_directory=Path(self.subdir),
            max_runs=self.num_times_parse_dags,
            processor_timeout=processor_timeout,
            dag_ids=[],
            pickle_dags=pickle_dags,
            async_mode=async_mode,
        )

    try:
        self.job.executor.job_id = self.job.id
        # Pick the callback sink: a pipe to the processor agent when there is one,
        # otherwise the database-backed sink.
        if self.processor_agent:
            self.log.debug("Using PipeCallbackSink as callback sink.")
            self.job.executor.callback_sink = PipeCallbackSink(
                get_sink_pipe=self.processor_agent.get_callbacks_pipe
            )
        else:
            from airflow.callbacks.database_callback_sink import DatabaseCallbackSink

            self.log.debug("Using DatabaseCallbackSink as callback sink.")
            self.job.executor.callback_sink = DatabaseCallbackSink()

        self.job.executor.start()

        self.register_signals()

        if self.processor_agent:
            self.processor_agent.start()

        # Timestamp taken before the loop; used below as the cut-off for stale DAGs.
        execute_start_time = timezone.utcnow()

        self._run_scheduler_loop()

        if self.processor_agent:
            # Stop any processors
            self.processor_agent.terminate()

            # Verify that all files were processed, and if so, deactivate DAGs that
            # haven't been touched by the scheduler as they likely have been
            # deleted.
            if self.processor_agent.all_files_processed:
                self.log.info(
                    "Deactivating DAGs that haven't been touched since %s", execute_start_time.isoformat()
                )
                DAG.deactivate_stale_dags(execute_start_time)

        settings.Session.remove()  # type: ignore
    except Exception:
        self.log.exception("Exception when executing SchedulerJob._run_scheduler_loop")
        raise
    finally:
        # Shutdown errors are logged, not raised, so both shutdown steps are always attempted.
        try:
            self.job.executor.end()
        except Exception:
            self.log.exception("Exception when executing Executor.end")
        if self.processor_agent:
            try:
                self.processor_agent.end()
            except Exception:
                self.log.exception("Exception when executing DagFileProcessorAgent.end")
        self.log.info("Exited execute loop")
    return None
@provide_session
def _update_dag_run_state_for_paused_dags(self, session: Session = NEW_SESSION) -> None:
    """
    Update the state of running, non-backfill DAG runs that belong to paused DAGs.

    Only runs whose ``last_scheduling_decision`` is not newer than the latest ``updated_at``
    of their task instances are considered. Any exception is logged and swallowed so that
    this periodic task cannot fail the scheduler.

    :param session: session supplied through ``@provide_session``
    """
    try:
        candidate_runs = (
            session.query(DagRun)
            .join(DagRun.dag_model)
            .join(TI)
            .filter(DagModel.is_paused == expression.true())
            .filter(DagRun.state == DagRunState.RUNNING)
            .filter(DagRun.run_type != DagRunType.BACKFILL_JOB)
            .group_by(DagRun)
            .having(DagRun.last_scheduling_decision <= func.max(TI.updated_at))
        )
        for dag_run in candidate_runs:
            dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
            if dag is not None:
                dag_run.dag = dag
                _, pending_callback = dag_run.update_state(execute_callbacks=False, session=session)
                if pending_callback:
                    self._send_dag_callbacks_to_processor(dag, pending_callback)
    except Exception as e:  # should not fail the scheduler
        self.log.exception("Failed to update dag run state for paused dags due to %s", str(e))
def _run_scheduler_loop(self) -> None:
    """
    The actual scheduler loop.

    The main steps in the loop are:

    #. Harvest DAG parsing results through DagFileProcessorAgent
    #. Find and queue executable tasks

       #. Change task instance state in DB
       #. Queue tasks in executor

    #. Heartbeat executor

       #. Execute queued tasks in executor asynchronously
       #. Sync on the states of running tasks

    Following is a graphic representation of these steps.

    .. image:: ../docs/apache-airflow/img/scheduler_loop.jpg

    The loop ends when ``num_runs`` (if greater than 0) iterations have run, or when the
    processor agent reports that it is done.

    :raises ValueError: if no processor agent is set and the standalone DAG processor is
        not in use
    """
    if not self.processor_agent and not self._standalone_dag_processor:
        raise ValueError("Processor agent is not started.")
    is_unit_test: bool = conf.getboolean("core", "unit_test_mode")

    # Periodic housekeeping tasks are registered on this event scheduler and run from the loop.
    timers = EventScheduler()

    # Check on start up, then every configured interval
    self.adopt_or_reset_orphaned_tasks()

    timers.call_regular_interval(
        conf.getfloat("scheduler", "orphaned_tasks_check_interval", fallback=300.0),
        self.adopt_or_reset_orphaned_tasks,
    )

    timers.call_regular_interval(
        conf.getfloat("scheduler", "trigger_timeout_check_interval", fallback=15.0),
        self.check_trigger_timeouts,
    )

    timers.call_regular_interval(
        conf.getfloat("scheduler", "pool_metrics_interval", fallback=5.0),
        self._emit_pool_metrics,
    )

    timers.call_regular_interval(
        conf.getfloat("scheduler", "zombie_detection_interval", fallback=10.0),
        self._find_zombies,
    )

    timers.call_regular_interval(60.0, self._update_dag_run_state_for_paused_dags)

    timers.call_regular_interval(
        conf.getfloat("scheduler", "task_queued_timeout_check_interval"),
        self._fail_tasks_stuck_in_queued,
    )

    timers.call_regular_interval(
        conf.getfloat("scheduler", "parsing_cleanup_interval"),
        self._orphan_unreferenced_datasets,
    )

    # Stale-DAG cleanup is only scheduled here when the DAG processor runs standalone.
    if self._standalone_dag_processor:
        timers.call_regular_interval(
            conf.getfloat("scheduler", "parsing_cleanup_interval"),
            self._cleanup_stale_dags,
        )

    for loop_count in itertools.count(start=1):
        with Stats.timer("scheduler.scheduler_loop_duration") as timer:
            if self.using_sqlite and self.processor_agent:
                self.processor_agent.run_single_parsing_loop()
                # For the sqlite case w/ 1 thread, wait until the processor
                # is finished to avoid concurrent access to the DB.
                self.log.debug("Waiting for processors to finish since we're using sqlite")
                self.processor_agent.wait_until_finished()

            with create_session() as session:
                num_queued_tis = self._do_scheduling(session)

                self.job.executor.heartbeat()
                session.expunge_all()
                num_finished_events = self._process_executor_events(session=session)
            if self.processor_agent:
                self.processor_agent.heartbeat()

            # Heartbeat the scheduler periodically
            perform_heartbeat(
                job=self.job, heartbeat_callback=self.heartbeat_callback, only_if_necessary=True
            )

            # Run any pending timed events
            next_event = timers.run(blocking=False)
            self.log.debug("Next timed event is in %f", next_event)

        self.log.debug("Ran scheduling loop in %.2f seconds", timer.duration)

        if not is_unit_test and not num_queued_tis and not num_finished_events:
            # If the scheduler is doing things, don't sleep. This means when there is work to do, the
            # scheduler will run "as quick as possible", but when it's stopped, it can sleep, dropping CPU
            # usage when "idle"
            # The sleep is capped by the time until the next timed event (0 if that is falsy).
            time.sleep(min(self._scheduler_idle_sleep_time, next_event if next_event else 0))

        # Chained comparison: only exits on the run count when num_runs is greater than 0.
        if loop_count >= self.num_runs > 0:
            self.log.info(
                "Exiting scheduler loop as requested number of runs (%d - got to %d) has been reached",
                self.num_runs,
                loop_count,
            )
            break
        if self.processor_agent and self.processor_agent.done:
            self.log.info(
                "Exiting scheduler loop as requested DAG parse count (%d) has been reached after %d"
                " scheduler loops",
                self.num_times_parse_dags,
                loop_count,
            )
            break
def _do_scheduling(self, session: Session) -> int:
    """
    Run one pass of the main scheduling decisions.

    It:

    - Creates any necessary DAG runs by examining the next_dagrun_create_after column of DagModel

      Since creating Dag Runs is a relatively time consuming process, we select only 10 dags by default
      (configurable via ``scheduler.max_dagruns_to_create_per_loop`` setting) - putting this higher will
      mean one scheduler could spend a chunk of time creating dag runs, and not ever get around to
      scheduling tasks.

    - Finds the "next n oldest" running DAG Runs to examine for scheduling (n=20 by default, configurable
      via ``scheduler.max_dagruns_per_loop_to_schedule`` config setting) and tries to progress state (TIs
      to SCHEDULED, or DagRuns to SUCCESS/FAILURE etc)

      By "next oldest", we mean the runs that have gone the longest without being examined/scheduled.

      We don't select all dagruns at once, because the rows are selected with row locks, meaning
      that only one scheduler can "process them", even if it is waiting behind other dags. Increasing this
      limit will allow more throughput for smaller DAGs but will likely slow down throughput for larger
      (>500 tasks.) DAGs

    - Then, via a Critical Section (locking the rows of the Pool model) we queue tasks, and then send them
      to the executor.

      See docs of _critical_section_enqueue_task_instances for more.

    :param session: Session used for every query and commit in this pass
    :return: Number of TIs enqueued in this iteration (0 when the executor has no free slots or the
        critical-section lock could not be obtained)
    """
    # Put a check in place to make sure we don't commit unexpectedly
    with prohibit_commit(session) as guard:
        if settings.USE_JOB_SCHEDULE:
            self._create_dagruns_for_dags(guard, session)

        self._start_queued_dagruns(session)
        guard.commit()
        dag_runs = self._get_next_dagruns_to_examine(DagRunState.RUNNING, session)
        # Make the scheduling decisions for every examined run; the callbacks collected here are only
        # dispatched once this guarded block has finished.
        # NOTE(review): an earlier comment here mentioned bulk-fetching active dag runs, but no such
        # fetch is visible in this method - confirm it is no longer relevant.

        callback_tuples = self._schedule_all_dag_runs(guard, dag_runs, session)

    # Send the callbacks after we commit to ensure the context is up to date when it gets run
    # cache saves time during scheduling of many dag_runs for same dag
    cached_get_dag: Callable[[str], DAG | None] = lru_cache()(
        partial(self.dagbag.get_dag, session=session)
    )
    for dag_run, callback_to_run in callback_tuples:
        dag = cached_get_dag(dag_run.dag_id)

        if not dag:
            self.log.error("DAG '%s' not found in serialized_dag table", dag_run.dag_id)
            continue
        # Sending callbacks here because in standalone_dag_processor mode they are added to the
        # database, so it must be done outside of prohibit_commit.
        self._send_dag_callbacks_to_processor(dag, callback_to_run)

    with prohibit_commit(session) as guard:
        # Without this, the session has an invalid view of the DB
        session.expunge_all()
        # END: schedule TIs

        if self.job.executor.slots_available <= 0:
            # We know we can't do anything here, so don't even try!
            self.log.debug("Executor full, skipping critical section")
            num_queued_tis = 0
        else:
            try:
                timer = Stats.timer("scheduler.critical_section_duration")
                timer.start()

                # Find any TIs in state SCHEDULED, try to QUEUE them (send them to the executor)
                num_queued_tis = self._critical_section_enqueue_task_instances(session=session)

                # Make sure we only send this metric if we obtained the lock, otherwise we'll skew the
                # metric, way down
                timer.stop(send=True)
            except OperationalError as e:
                # Stop the timer without publishing: this attempt never completed the critical section.
                timer.stop(send=False)

                if is_lock_not_available_error(error=e):
                    self.log.debug("Critical section lock held by another Scheduler")
                    Stats.incr("scheduler.critical_section_busy")
                    session.rollback()
                    return 0
                # Any other database error is not expected here, so let it propagate.
                raise

        guard.commit()

    return num_queued_tis
@retry_db_transaction
def _get_next_dagruns_to_examine(self, state: DagRunState, session: Session) -> Query:
    """
    Return the query selecting the next DagRuns in ``state`` to examine.

    Thin wrapper around ``DagRun.next_dagruns_to_examine`` so that the lookup
    goes through ``retry_db_transaction``.

    :param state: DagRun state to select
    :param session: Session the query is bound to
    """
    next_runs = DagRun.next_dagruns_to_examine(state, session)
    return next_runs
@retry_db_transaction
def _create_dagruns_for_dags(self, guard: CommitProhibitorGuard, session: Session) -> None:
    """
    Create DagRuns for every DagModel that currently needs one.

    The DagModels reported by ``DagModel.dags_needing_dagruns`` are split into
    dataset-triggered ones (their dag_id appears in the returned dataset info)
    and all the others; each group is handed to its own creation routine.
    The whole method is wrapped in ``retry_db_transaction``.

    :param guard: Commit guard used to commit once the runs have been created
    :param session: Session used for the queries
    """
    dags_query, dataset_info = DagModel.dags_needing_dagruns(session)
    pending_dag_models = set(dags_query.all())

    dataset_scheduled = [dm for dm in pending_dag_models if dm.dag_id in dataset_info]
    time_scheduled = pending_dag_models.difference(dataset_scheduled)

    self._create_dag_runs(time_scheduled, session)
    if dataset_scheduled:
        self._create_dag_runs_dataset_triggered(dataset_scheduled, dataset_info, session)

    # Committing here releases the write lock on the DagModel table.
    guard.commit()
    # END: create dagruns
def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) -> None:
    """
    Create a scheduled DagRun for each given DagModel and update its next-run fields.

    For every DagModel whose DAG can be loaded, a QUEUED run of type SCHEDULED is
    created for ``dag_model.next_dagrun`` unless a run with the same
    ``(dag_id, execution_date)`` already exists. Afterwards the DagModel's fields
    controlling if/when the next DagRun should be created are recalculated,
    unless the DAG has reached ``max_active_runs``.

    :param dag_models: DagModels that need a new run; may be empty
    :param session: Session used for the queries and for creating the runs
    """
    if not dag_models:
        # Nothing to create: skip the two bulk queries below entirely.
        return

    # Bulk fetch DagRuns with dag_id and execution_date same
    # as DagModel.dag_id and DagModel.next_dagrun
    # This set is used to verify if the DagRun already exists so that we don't attempt to create
    # duplicate dag runs. Plain tuples in a set give constant-time membership checks in the loop.
    existing_dagruns: set[tuple[str, datetime]] = {
        (dag_id, execution_date)
        for dag_id, execution_date in session.query(DagRun.dag_id, DagRun.execution_date).filter(
            tuple_in_condition(
                (DagRun.dag_id, DagRun.execution_date),
                ((dm.dag_id, dm.next_dagrun) for dm in dag_models),
            ),
        )
    }

    active_runs_of_dags = Counter(
        DagRun.active_runs_of_dags(dag_ids=(dm.dag_id for dm in dag_models), session=session),
    )

    for dag_model in dag_models:
        dag = self.dagbag.get_dag(dag_model.dag_id, session=session)
        if not dag:
            self.log.error("DAG '%s' not found in serialized_dag table", dag_model.dag_id)
            continue

        dag_hash = self.dagbag.dags_hash.get(dag.dag_id)

        data_interval = dag.get_next_data_interval(dag_model)
        # Explicitly check if the DagRun already exists. This is an edge case
        # where a Dag Run is created but `DagModel.next_dagrun` and `DagModel.next_dagrun_create_after`
        # are not updated.
        # We opted to check DagRun existence instead
        # of catching an Integrity error and rolling back the session i.e
        # we need to set dag.next_dagrun_info if the Dag Run already exists or if we
        # create a new one. This is so that in the next Scheduling loop we try to create new runs
        # instead of falling in a loop of Integrity Error.
        if (dag.dag_id, dag_model.next_dagrun) not in existing_dagruns:
            dag.create_dagrun(
                run_type=DagRunType.SCHEDULED,
                execution_date=dag_model.next_dagrun,
                state=DagRunState.QUEUED,
                data_interval=data_interval,
                external_trigger=False,
                session=session,
                dag_hash=dag_hash,
                creating_job_id=self.job.id,
            )
            # Count the run we just created so the max_active_runs check below sees it.
            active_runs_of_dags[dag.dag_id] += 1
        if self._should_update_dag_next_dagruns(dag, dag_model, active_runs_of_dags[dag.dag_id]):
            dag_model.calculate_dagrun_date_fields(dag, data_interval)
    # TODO[HA]: Should we do a session.flush() so we don't have to keep lots of state/object in
    # memory for larger dags? or expunge_all()
def _create_dag_runs_dataset_triggered(
    self,
    dag_models: Collection[DagModel],
    dataset_triggered_dag_info: dict[str, tuple[datetime, datetime]],
    session: Session,
) -> None:
    """
    For DAGs that are triggered by datasets, create dag runs.

    :param dag_models: DagModels of the dataset-triggered DAGs to create runs for
    :param dataset_triggered_dag_info: Mapping of dag_id to a pair of datetimes; only the second
        element is used here, as the execution date of the run to create
    :param session: Session used for the queries and for creating the runs
    """
    # Bulk fetch DagRuns whose (dag_id, execution_date) matches a dag_id and the execution date
    # derived for it from ``dataset_triggered_dag_info``.
    # This set is used to verify if the DagRun already exists so that we don't attempt to create
    # duplicate dag runs
    exec_dates = {
        dag_id: timezone.coerce_datetime(last_time)
        for dag_id, (_, last_time) in dataset_triggered_dag_info.items()
    }
    existing_dagruns: set[tuple[str, timezone.DateTime]] = set(
        session.query(DagRun.dag_id, DagRun.execution_date).filter(
            tuple_in_condition((DagRun.dag_id, DagRun.execution_date), exec_dates.items())
        )
    )

    for dag_model in dag_models:
        dag = self.dagbag.get_dag(dag_model.dag_id, session=session)
        if not dag:
            self.log.error("DAG '%s' not found in serialized_dag table", dag_model.dag_id)
            continue

        if not isinstance(dag.timetable, DatasetTriggeredTimetable):
            self.log.error(
                "DAG '%s' was dataset-scheduled, but didn't have a DatasetTriggeredTimetable!",
                dag_model.dag_id,
            )
            continue

        dag_hash = self.dagbag.dags_hash.get(dag.dag_id)

        # Explicitly check if the DagRun already exists. This is an edge case
        # where a Dag Run is created but `DagModel.next_dagrun` and `DagModel.next_dagrun_create_after`
        # are not updated.
        # We opted to check DagRun existence instead
        # of catching an Integrity error and rolling back the session i.e
        # we need to set dag.next_dagrun_info if the Dag Run already exists or if we
        # create a new one. This is so that in the next Scheduling loop we try to create new runs
        # instead of falling in a loop of Integrity Error.
        exec_date = exec_dates[dag.dag_id]
        if (dag.dag_id, exec_date) not in existing_dagruns:
            # Most recent dataset-triggered run of this DAG before ``exec_date``, if any; it bounds
            # the window of dataset events attributed to the new run.
            previous_dag_run = (
                session.query(DagRun)
                .filter(
                    DagRun.dag_id == dag.dag_id,
                    DagRun.execution_date < exec_date,
                    DagRun.run_type == DagRunType.DATASET_TRIGGERED,
                )
                .order_by(DagRun.execution_date.desc())
                .first()
            )
            dataset_event_filters = [
                DagScheduleDatasetReference.dag_id == dag.dag_id,
                DatasetEvent.timestamp <= exec_date,
            ]
            if previous_dag_run:
                # Only events newer than the previous run's execution date are selected.
                dataset_event_filters.append(DatasetEvent.timestamp > previous_dag_run.execution_date)

            dataset_events = (
                session.query(DatasetEvent)
                .join(
                    DagScheduleDatasetReference,
                    DatasetEvent.dataset_id == DagScheduleDatasetReference.dataset_id,
                )
                .join(DatasetEvent.source_dag_run)
                .filter(*dataset_event_filters)
                .all()
            )

            data_interval = dag.timetable.data_interval_for_events(exec_date, dataset_events)
            run_id = dag.timetable.generate_run_id(
                run_type=DagRunType.DATASET_TRIGGERED,
                logical_date=exec_date,
                data_interval=data_interval,
                session=session,
                events=dataset_events,
            )

            dag_run = dag.create_dagrun(
                run_id=run_id,
                run_type=DagRunType.DATASET_TRIGGERED,
                execution_date=exec_date,
                data_interval=data_interval,
                state=DagRunState.QUEUED,
                external_trigger=False,
                session=session,
                dag_hash=dag_hash,
                creating_job_id=self.job.id,
            )
            Stats.incr("dataset.triggered_dagruns")
            dag_run.consumed_dataset_events.extend(dataset_events)
            # Delete every DatasetDagRunQueue row targeting this DAG now that the run exists.
            session.query(DatasetDagRunQueue).filter(
                DatasetDagRunQueue.target_dag_id == dag_run.dag_id
            ).delete()
def _should_update_dag_next_dagruns(self, dag: DAG, dag_model: DagModel, total_active_runs: int) -> bool:
"""Check if the dag's next_dagruns_create_after should be updated."""
if total_active_runs >= dag.max_active_runs:
self.log.info(
"DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs",
dag_model.dag_id,
total_active_runs,
dag.max_active_runs,
)
dag_model.next_dagrun_create_after = None
return False
return True
def _start_queued_dagruns(self, session: Session) -> None:
    """
    Find DagRuns in queued state and decide whether to move them to running state.

    A queued run is moved to RUNNING unless its DAG already has
    ``max_active_runs`` running runs. Runs whose DAG cannot be loaded are
    logged and skipped.

    :param session: Session used for the queries
    """
    # Materialise the result once: it is iterated twice below (to collect the dag_ids and then
    # to process each run), and iterating the Query itself would execute it each time.
    dag_runs: list[DagRun] = list(self._get_next_dagruns_to_examine(DagRunState.QUEUED, session))

    active_runs_of_dags = Counter(
        DagRun.active_runs_of_dags((dr.dag_id for dr in dag_runs), only_running=True, session=session),
    )

    def _update_state(dag: DAG, dag_run: DagRun):
        """Mark ``dag_run`` as running and, for periodic timetables, publish its schedule delay."""
        dag_run.state = DagRunState.RUNNING
        dag_run.start_date = timezone.utcnow()
        if dag.timetable.periodic:
            # TODO: Logically, this should be DagRunInfo.run_after, but the
            # information is not stored on a DagRun, only before the actual
            # execution on DagModel.next_dagrun_create_after. We should add
            # a field on DagRun for this instead of relying on the run
            # always happening immediately after the data interval.
            expected_start_date = dag.get_run_data_interval(dag_run).end
            schedule_delay = dag_run.start_date - expected_start_date
            # Publish metrics twice with backward compatible name, and then with tags
            Stats.timing(f"dagrun.schedule_delay.{dag.dag_id}", schedule_delay)
            Stats.timing(
                "dagrun.schedule_delay",
                schedule_delay,
                tags={"dag_id": dag.dag_id},
            )

    # cache saves time during scheduling of many dag_runs for same dag
    cached_get_dag: Callable[[str], DAG | None] = lru_cache()(
        partial(self.dagbag.get_dag, session=session)
    )

    for dag_run in dag_runs:
        dag = dag_run.dag = cached_get_dag(dag_run.dag_id)

        if not dag:
            self.log.error("DAG '%s' not found in serialized_dag table", dag_run.dag_id)
            continue
        active_runs = active_runs_of_dags[dag_run.dag_id]

        if dag.max_active_runs and active_runs >= dag.max_active_runs:
            self.log.debug(
                "DAG %s already has %d active runs, not moving any more runs to RUNNING state %s",
                dag.dag_id,
                active_runs,
                dag_run.execution_date,
            )
        else:
            # Count this run immediately so later queued runs of the same DAG see the new total.
            active_runs_of_dags[dag_run.dag_id] += 1
            _update_state(dag, dag_run)
            dag_run.notify_dagrun_state_changed()
@retry_db_transaction
def _schedule_all_dag_runs(
    self,
    guard: CommitProhibitorGuard,
    dag_runs: Iterable[DagRun],
    session: Session,
) -> list[tuple[DagRun, DagCallbackRequest | None]]:
    """
    Make scheduling decisions for every run in ``dag_runs`` and commit.

    :param guard: Commit guard used to commit once all runs have been handled
    :param dag_runs: DagRuns to schedule
    :param session: Session passed on to the per-run scheduling
    :return: One ``(dag_run, callback)`` pair per run, where ``callback`` is
        whatever ``_schedule_dag_run`` returned for it (possibly ``None``)
    """
    decisions: list[tuple[DagRun, DagCallbackRequest | None]] = []
    for dag_run in dag_runs:
        pending_callback = self._schedule_dag_run(dag_run, session=session)
        decisions.append((dag_run, pending_callback))

    guard.commit()
    return decisions
def _schedule_dag_run(
    self,
    dag_run: DagRun,
    session: Session,
) -> DagCallbackRequest | None:
    """
    Make scheduling decisions about an individual dag run.

    A run that has exceeded its DAG's ``dagrun_timeout`` is failed and its unfinished task
    instances are set to SKIPPED. Otherwise the run's state is updated and its schedulable
    task instances are scheduled.

    :param dag_run: The DagRun to schedule
    :param session: Session used for all queries and updates
    :return: Callback that needs to be executed, or None when there is none
    """
    callback: DagCallbackRequest | None = None

    dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
    dag_model = DM.get_dagmodel(dag_run.dag_id, session)

    if not dag or not dag_model:
        self.log.error("Couldn't find DAG %s in DAG bag or database!", dag_run.dag_id)
        return callback

    # The run started longer ago than the DAG's dagrun_timeout allows: fail it.
    if (
        dag_run.start_date
        and dag.dagrun_timeout
        and dag_run.start_date < timezone.utcnow() - dag.dagrun_timeout
    ):
        dag_run.set_state(DagRunState.FAILED)
        unfinished_task_instances = (
            session.query(TI)
            .filter(TI.dag_id == dag_run.dag_id)
            .filter(TI.run_id == dag_run.run_id)
            .filter(TI.state.in_(State.unfinished))
        )
        for task_instance in unfinished_task_instances:
            task_instance.state = TaskInstanceState.SKIPPED
            session.merge(task_instance)
        session.flush()
        self.log.info("Run %s of %s has timed-out", dag_run.run_id, dag_run.dag_id)

        active_runs = dag.get_num_active_runs(only_running=False, session=session)
        # Work out if we should allow creating a new DagRun now?
        if self._should_update_dag_next_dagruns(dag, dag_model, active_runs):
            dag_model.calculate_dagrun_date_fields(dag, dag.get_run_data_interval(dag_run))

        # Failure callback request for the timed-out run; returned to the caller for sending.
        callback_to_execute = DagCallbackRequest(
            full_filepath=dag.fileloc,
            dag_id=dag.dag_id,
            run_id=dag_run.run_id,
            is_failure_callback=True,
            processor_subdir=dag_model.processor_subdir,
            msg="timed_out",
        )

        dag_run.notify_dagrun_state_changed()
        # NOTE(review): presumably end_date was populated by set_state(FAILED) above - confirm.
        duration = dag_run.end_date - dag_run.start_date
        # Publish the duration under the per-DAG metric name and under the tagged name.
        Stats.timing(f"dagrun.duration.failed.{dag_run.dag_id}", duration)
        Stats.timing("dagrun.duration.failed", duration, tags={"dag_id": dag_run.dag_id})
        return callback_to_execute

    if dag_run.execution_date > timezone.utcnow() and not dag.allow_future_exec_dates:
        self.log.error("Execution date is in future: %s", dag_run.execution_date)
        return callback

    if not self._verify_integrity_if_dag_changed(dag_run=dag_run, session=session):
        self.log.warning("The DAG disappeared before verifying integrity: %s. Skipping.", dag_run.dag_id)
        return callback
    # TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else?
    schedulable_tis, callback_to_run = dag_run.update_state(session=session, execute_callbacks=False)

    if dag_run.state in State.finished:
        active_runs = dag.get_num_active_runs(only_running=False, session=session)
        # Work out if we should allow creating a new DagRun now?
        if self._should_update_dag_next_dagruns(dag, dag_model, active_runs):
            dag_model.calculate_dagrun_date_fields(dag, dag.get_run_data_interval(dag_run))
    # This will do one query per dag run. We "could" build up a complex
    # query to update all the TIs across all the execution dates and dag
    # IDs in a single query, but it turns out that can be _very very slow_
    # see #11147/commit ee90807ac for more details
    dag_run.schedule_tis(schedulable_tis, session, max_tis_per_query=self.job.max_tis_per_query)

    return callback_to_run
def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session: Session) -> bool:
    """
    Run ``DagRun.verify_integrity`` only when the serialized DAG has changed, since it is slow.

    The run's ``dag_hash`` is compared with the latest serialized version hash;
    on a mismatch the hash is updated, the DAG is reloaded and integrity is
    verified.

    :param dag_run: DagRun to check
    :param session: Session used for the lookups
    :return: True if we determine that the DAG still exists, False if it could
        not be reloaded
    """
    current_hash = SerializedDagModel.get_latest_version_hash(dag_run.dag_id, session=session)

    if dag_run.dag_hash != current_hash:
        dag_run.dag_hash = current_hash

        # Refresh the DAG
        dag_run.dag = self.dagbag.get_dag(dag_id=dag_run.dag_id, session=session)
        if not dag_run.dag:
            return False

        # Verify integrity also takes care of session.flush
        dag_run.verify_integrity(session=session)
        return True

    self.log.debug("DAG %s not changed structure, skipping dagrun.verify_integrity", dag_run.dag_id)
    return True
def _send_dag_callbacks_to_processor(self, dag: DAG, callback: DagCallbackRequest | None = None) -> None:
self._send_sla_callbacks_to_processor(dag)
if callback:
self.job.executor.send_callback(callback)
else:
self.log.debug("callback is empty")
def _send_sla_callbacks_to_processor(self, dag: DAG) -> None:
    """
    Send an SLA callback request for ``dag`` to the executor.

    Nothing is sent when SLA checking is disabled (``settings.CHECK_SLAS``),
    when no task of the DAG has a ``timedelta`` SLA, when the DAG's timetable is
    not periodic, or when the DagModel cannot be found.

    :param dag: DAG to request an SLA check for
    """
    if not settings.CHECK_SLAS:
        return

    tasks_have_sla = any(isinstance(task.sla, timedelta) for task in dag.tasks)
    if not tasks_have_sla:
        self.log.debug("Skipping SLA check for %s because no tasks in DAG have SLAs", dag)
        return

    if not dag.timetable.periodic:
        self.log.debug("Skipping SLA check for %s because DAG is not scheduled", dag)
        return

    dag_model = DagModel.get_dagmodel(dag.dag_id)
    if dag_model:
        sla_request = SlaCallbackRequest(
            full_filepath=dag.fileloc,
            dag_id=dag.dag_id,
            processor_subdir=dag_model.processor_subdir,
        )
        self.job.executor.send_callback(sla_request)
    else:
        self.log.error("Couldn't find DAG %s in database!", dag.dag_id)
@provide_session
def _fail_tasks_stuck_in_queued(self, session: Session = NEW_SESSION) -> None:
    """
    Mark tasks stuck in queued for longer than `task_queued_timeout` as failed.

    Tasks can get stuck in queued for a wide variety of reasons (e.g. celery loses
    track of a task, a cluster can't further scale up its workers, etc.), but tasks
    should not be stuck in queued for a long time. Task instances queued by this job
    for longer than `self._task_queued_timeout` seconds are handed to the executor's
    ``cleanup_stuck_queued_tasks``. If the task has available retries, it will be retried.

    Executors that do not implement the cleanup (signalled by ``NotImplementedError``)
    are skipped.

    :param session: Session used to look up the stuck task instances
    """
    self.log.debug("Calling SchedulerJob._fail_tasks_stuck_in_queued method")

    tasks_stuck_in_queued = (
        session.query(TI)
        .filter(
            TI.state == State.QUEUED,
            TI.queued_dttm < (timezone.utcnow() - timedelta(seconds=self._task_queued_timeout)),
            TI.queued_by_job_id == self.job.id,
        )
        .all()
    )
    try:
        tis_for_warning_message = self.job.executor.cleanup_stuck_queued_tasks(tis=tasks_stuck_in_queued)
    except NotImplementedError:
        self.log.debug("Executor doesn't support cleanup of stuck queued tasks. Skipping.")
        return

    if tis_for_warning_message:
        task_instance_str = "\n\t".join(tis_for_warning_message)
        self.log.warning(
            "Marked the following %s task instances stuck in queued as failed. "
            "If the task instance has available retries, it will be retried.\n\t%s",
            # Count what the executor reported, so the number always matches the list that
            # follows even if the executor did not handle every candidate.
            len(tis_for_warning_message),
            task_instance_str,
        )
@provide_session
def _emit_pool_metrics(self, session: Session = NEW_SESSION) -> None:
    """
    Publish open/queued/running slot gauges for every pool.

    Each value is emitted twice: once under a metric name that embeds the pool
    name, and once under a generic name carrying a ``pool_name`` tag.

    :param session: Session used to read the pool slot statistics
    """
    from airflow.models.pool import Pool

    for pool_name, slot_stats in Pool.slots_stats(session=session).items():
        open_slots = slot_stats["open"]
        queued_slots = slot_stats["queued"]
        running_slots = slot_stats["running"]

        # metrics with the pool name embedded in the metric name
        Stats.gauge(f"pool.open_slots.{pool_name}", open_slots)
        Stats.gauge(f"pool.queued_slots.{pool_name}", queued_slots)
        Stats.gauge(f"pool.running_slots.{pool_name}", running_slots)

        # tagged metrics
        Stats.gauge("pool.open_slots", open_slots, tags={"pool_name": pool_name})
        Stats.gauge("pool.queued_slots", queued_slots, tags={"pool_name": pool_name})
        Stats.gauge("pool.running_slots", running_slots, tags={"pool_name": pool_name})
@provide_session
def adopt_or_reset_orphaned_tasks(self, session: Session = NEW_SESSION) -> int:
    """
    Adopt or reset TaskInstances in QUEUED or RUNNING state whose SchedulerJob is no longer running.

    SchedulerJobs that are still marked running but whose latest heartbeat is older than
    ``[scheduler] scheduler_health_check_threshold`` seconds are first marked as failed. Then the
    task instances of running, non-backfill DagRuns that have no ``queued_by_job`` or whose
    queuing job is not running are offered to the executor for adoption; the ones the executor
    returns are reset to state ``None``, the rest are re-assigned to this job.

    :param session: Session used for the queries and updates
    :return: the number of TIs reset
    """
    self.log.info("Resetting orphaned tasks for active dag runs")
    timeout = conf.getint("scheduler", "scheduler_health_check_threshold")

    for attempt in run_with_db_retries(logger=self.log):
        with attempt:
            self.log.debug(
                "Running SchedulerJob.adopt_or_reset_orphaned_tasks with retries. Try %d of %d",
                attempt.retry_state.attempt_number,
                MAX_DB_RETRIES,
            )
            self.log.debug("Calling SchedulerJob.adopt_or_reset_orphaned_tasks method")
            try:
                # Fail scheduler jobs that claim to be running but have not heartbeated recently.
                num_failed = (
                    session.query(Job)
                    .filter(
                        Job.job_type == "SchedulerJob",
                        Job.state == State.RUNNING,
                        Job.latest_heartbeat < (timezone.utcnow() - timedelta(seconds=timeout)),
                    )
                    .update({"state": State.FAILED})
                )

                if num_failed:
                    self.log.info("Marked %d SchedulerJob instances as failed", num_failed)
                    Stats.incr(self.__class__.__name__.lower() + "_end", num_failed)

                resettable_states = [TaskInstanceState.QUEUED, TaskInstanceState.RUNNING]
                query = (
                    session.query(TI)
                    .filter(TI.state.in_(resettable_states))
                    # outerjoin is because we didn't use to have queued_by_job
                    # set, so we need to pick up anything pre upgrade. This (and the
                    # "or queued_by_job_id IS NONE") can go as soon as scheduler HA is
                    # released.
                    .outerjoin(TI.queued_by_job)
                    .filter(or_(TI.queued_by_job_id.is_(None), Job.state != State.RUNNING))
                    .join(TI.dag_run)
                    .filter(
                        DagRun.run_type != DagRunType.BACKFILL_JOB,
                        DagRun.state == State.RUNNING,
                    )
                    .options(load_only(TI.dag_id, TI.task_id, TI.run_id))
                )

                # Lock these rows, so that another scheduler can't try and adopt these too
                tis_to_reset_or_adopt = with_row_locks(
                    query, of=TI, session=session, **skip_locked(session=session)
                ).all()
                # The executor returns the task instances it could NOT adopt; those get reset.
                to_reset = self.job.executor.try_adopt_task_instances(tis_to_reset_or_adopt)

                reset_tis_message = []
                for ti in to_reset:
                    reset_tis_message.append(repr(ti))
                    ti.state = State.NONE
                    ti.queued_by_job_id = None

                # Everything that was not handed back is treated as adopted by this job.
                for ti in set(tis_to_reset_or_adopt) - set(to_reset):
                    ti.queued_by_job_id = self.job.id

                Stats.incr("scheduler.orphaned_tasks.cleared", len(to_reset))
                Stats.incr("scheduler.orphaned_tasks.adopted", len(tis_to_reset_or_adopt) - len(to_reset))

                if to_reset:
                    task_instance_str = "\n\t".join(reset_tis_message)
                    self.log.info(
                        "Reset the following %s orphaned TaskInstances:\n\t%s",
                        len(to_reset),
                        task_instance_str,
                    )

                # Issue SQL/finish "Unit of Work", but let @provide_session
                # commit (or if passed a session, let caller decide when to commit)
                session.flush()
            except OperationalError:
                # Roll back before re-raising so the surrounding retry loop starts from a clean session.
                session.rollback()
                raise

    return len(to_reset)
@provide_session
def check_trigger_timeouts(self, session: Session = NEW_SESSION) -> None:
    """
    Time out deferred task instances whose trigger timeout has passed.

    Matching task instances are moved back to SCHEDULED with ``next_method`` set
    to ``__fail__`` and their trigger detached, so that they fail themselves
    when they next run rather than being failed inside the scheduler.

    :param session: Session used for the bulk update
    """
    expired_deferred = session.query(TI).filter(
        TI.state == TaskInstanceState.DEFERRED,
        TI.trigger_timeout < timezone.utcnow(),
    )
    # We have to schedule these to fail themselves so it doesn't
    # happen inside the scheduler.
    fail_on_resume = {
        "state": TaskInstanceState.SCHEDULED,
        "next_method": "__fail__",
        "next_kwargs": {"error": "Trigger/execution timeout"},
        "trigger_id": None,
    }
    timed_out_count = expired_deferred.update(fail_on_resume)

    if not timed_out_count:
        return
    self.log.info("Timed out %i deferred tasks without fired triggers", timed_out_count)
def _find_zombies(self) -> None:
    """
    Find zombie task instances and request failure handling for them.

    A zombie is a RUNNING task instance queued by this job whose LocalTaskJob
    is no longer running or has not heartbeated within
    ``self._zombie_threshold_secs`` seconds. For each one a TaskCallbackRequest
    is sent through the executor to be handled by the DAG processor.
    """
    from airflow.jobs.job import Job

    self.log.debug("Finding 'running' jobs without a recent heartbeat")
    heartbeat_cutoff = timezone.utcnow() - timedelta(seconds=self._zombie_threshold_secs)
    job_is_dead = or_(
        Job.state != State.RUNNING,
        Job.latest_heartbeat < heartbeat_cutoff,
    )

    with create_session() as session:
        zombie_query = (
            session.query(TI, DM.fileloc, DM.processor_subdir)
            .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
            .join(Job, TI.job_id == Job.id)
            .join(DM, TI.dag_id == DM.dag_id)
            .filter(
                TI.state == TaskInstanceState.RUNNING,
                job_is_dead,
                Job.job_type == "LocalTaskJob",
                TI.queued_by_job_id == self.job.id,
            )
        )
        zombies: list[tuple[TI, str, str]] = zombie_query.all()

    if not zombies:
        return

    self.log.warning("Failing (%s) jobs without heartbeat after %s", len(zombies), heartbeat_cutoff)

    for zombie_ti, dag_fileloc, dag_processor_subdir in zombies:
        details = self._generate_zombie_message_details(zombie_ti)
        zombie_request = TaskCallbackRequest(
            full_filepath=dag_fileloc,
            processor_subdir=dag_processor_subdir,
            simple_task_instance=SimpleTaskInstance.from_ti(zombie_ti),
            msg=str(details),
        )
        self.log.error("Detected zombie job: %s", zombie_request)
        self.job.executor.send_callback(zombie_request)
        Stats.incr("zombies_killed", tags={"dag_id": zombie_ti.dag_id, "task_id": zombie_ti.task_id})
@staticmethod
def _generate_zombie_message_details(ti: TI) -> dict[str, Any]:
zombie_message_details = {
"DAG Id": ti.dag_id,
"Task Id": ti.task_id,
"Run Id": ti.run_id,
}
if ti.map_index != -1:
zombie_message_details["Map Index"] = ti.map_index
if ti.hostname:
zombie_message_details["Hostname"] = ti.hostname
if ti.external_executor_id:
zombie_message_details["External Executor Id"] = ti.external_executor_id
return zombie_message_details
@provide_session
def _cleanup_stale_dags(self, session: Session = NEW_SESSION) -> None:
    """
    Mark DAGs that were not parsed recently as inactive and drop their serialized form.

    An active DAG counts as stale when its ``last_parsed_time`` is older than
    ``self._dag_stale_not_seen_duration`` seconds. This covers the case where
    one of several DagProcessors (one per dag folder) is stopped: its DAGs
    would otherwise never be marked as inactive.

    Executed on schedule only if [scheduler]standalone_dag_processor is True.

    :param session: Session used for the lookup and updates
    """
    self.log.debug("Checking dags not parsed within last %s seconds.", self._dag_stale_not_seen_duration)
    parsed_cutoff = timezone.utcnow() - timedelta(seconds=self._dag_stale_not_seen_duration)

    stale_query = session.query(DagModel).filter(
        DagModel.is_active, DagModel.last_parsed_time < parsed_cutoff
    )
    stale_dag_models = stale_query.all()

    if stale_dag_models:
        self.log.info(
            "Found (%d) stales dags not parsed after %s.", len(stale_dag_models), parsed_cutoff
        )
        for stale_dag_model in stale_dag_models:
            stale_dag_model.is_active = False
            SerializedDagModel.remove_dag(dag_id=stale_dag_model.dag_id, session=session)
        session.flush()
    else:
        self.log.debug("Not stale dags found.")
def _set_orphaned(self, dataset: DatasetModel) -> int:
    """
    Flag ``dataset`` as orphaned.

    :param dataset: Dataset to mark
    :return: Always 1
    """
    orphaned_flag = expression.true()
    self.log.info("Orphaning unreferenced dataset '%s'", dataset.uri)
    dataset.is_orphaned = orphaned_flag
    return 1
@provide_session
def _orphan_unreferenced_datasets(self, session: Session = NEW_SESSION) -> None:
    """
    Flag datasets that nothing references any more as orphaned.

    A dataset is unreferenced when it has neither a DAG schedule reference nor a
    task outlet reference. Each such dataset gets its ``is_orphaned`` flag set,
    and the number of datasets flagged in this pass is published as the
    ``dataset.orphaned`` gauge.

    :param session: Session used for the query and updates
    """
    # MSSQL doesn't like it when we select a column that we haven't grouped by. All other DBs let us
    # group by id and select all columns.
    on_mssql = session.get_bind().dialect.name == "mssql"
    group_key = DatasetModel if on_mssql else DatasetModel.id

    has_no_references = and_(
        func.count(DagScheduleDatasetReference.dag_id) == 0,
        func.count(TaskOutletDatasetReference.dag_id) == 0,
    )

    unreferenced_datasets = (
        session.query(DatasetModel)
        .join(DagScheduleDatasetReference, isouter=True)
        .join(TaskOutletDatasetReference, isouter=True)
        .group_by(group_key)
        .having(has_no_references)
    )

    orphaned_total = 0
    for dataset in unreferenced_datasets:
        orphaned_total += self._set_orphaned(dataset)

    Stats.gauge("dataset.orphaned", orphaned_total)