# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Base executor - this is the base class for all the implemented executors."""
from __future__ import annotations
import logging
import sys
import warnings
from collections import OrderedDict, defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple
import pendulum
from airflow.configuration import conf
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import State
PARALLELISM: int = conf.getint("core", "PARALLELISM")
if TYPE_CHECKING:
from airflow.callbacks.base_callback_sink import BaseCallbackSink
from airflow.callbacks.callback_requests import CallbackRequest
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
# Command to execute - list of strings
# the first element is always "airflow".
# It should be the result of the TaskInstance.generate_command method.
CommandType = List[str]
# Task that is queued. It contains all the information that is
# needed to run the task.
#
# Tuple of: command, priority, queue name, TaskInstance
QueuedTaskInstanceType = Tuple[CommandType, int, Optional[str], TaskInstance]
# Event_buffer dict value type
# Tuple of: state, info
EventBufferValueType = Tuple[Optional[str], Any]
# Task tuple to send for execution
TaskTuple = Tuple[TaskInstanceKey, CommandType, Optional[str], Optional[Any]]
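# Illustrative sketch of how these aliases fit together (the dag/task/run ids below are
# hypothetical; the real command comes from TaskInstance.generate_command):
#   command: CommandType = ["airflow", "tasks", "run", "example_dag", "example_task", "manual__2023-01-01"]
#   queued_ti: QueuedTaskInstanceType = (command, 1, "default", ti)
#   event: EventBufferValueType = ("success", None)
#   task_tuple: TaskTuple = (ti.key, command, "default", ti.executor_config)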
log = logging.getLogger(__name__)
@dataclass
class RunningRetryAttemptType:
"""
Keep track of attempts to re-queue a task while the executor still reports it as running.
We don't want to slow down the main loop, so we don't block; instead we allow the task
to be re-checked for at least MIN_SECONDS seconds.
"""
MIN_SECONDS = 10
total_tries: int = field(default=0, init=False)
tries_after_min: int = field(default=0, init=False)
first_attempt_time: datetime = field(default_factory=lambda: pendulum.now("UTC"), init=False)
@property
def elapsed(self):
"""Seconds since first attempt."""
return (pendulum.now("UTC") - self.first_attempt_time).total_seconds()
def can_try_again(self):
"""
Return False if at least one attempt has been made more than MIN_SECONDS after the
first attempt; otherwise return True.
"""
if self.tries_after_min > 0:
return False
self.total_tries += 1
elapsed = self.elapsed
if elapsed > self.MIN_SECONDS:
self.tries_after_min += 1
log.debug("elapsed=%s tries=%s", elapsed, self.total_tries)
return True
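# Minimal usage sketch (mirrors how BaseExecutor.trigger_tasks below uses this class; the
# surrounding loop is illustrative only):
#   attempts = defaultdict(RunningRetryAttemptType)
#   attempt = attempts[key]
#   if attempt.can_try_again():
#       pass  # leave the task queued and re-check on the next scheduler loop iteration
#   else:
#       pass  # give up: more than MIN_SECONDS have elapsed since the first attempt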
class BaseExecutor(LoggingMixin):
"""
Base class to derive from when implementing concrete executors,
such as the Celery, Kubernetes, Local, and Sequential executors.
:param parallelism: how many jobs should run at one time. Set to
``0`` for infinity.
"""
supports_ad_hoc_ti_run: bool = False
supports_pickling: bool = True
supports_sentry: bool = False
is_local: bool = False
is_single_threaded: bool = False
is_production: bool = True
change_sensor_mode_to_reschedule: bool = False
serve_logs: bool = False
job_id: None | int | str = None
callback_sink: BaseCallbackSink | None = None
def __init__(self, parallelism: int = PARALLELISM):
super().__init__()
self.parallelism: int = parallelism
self.queued_tasks: OrderedDict[TaskInstanceKey, QueuedTaskInstanceType] = OrderedDict()
self.running: set[TaskInstanceKey] = set()
self.event_buffer: dict[TaskInstanceKey, EventBufferValueType] = {}
self.attempts: dict[TaskInstanceKey, RunningRetryAttemptType] = defaultdict(RunningRetryAttemptType)
def __repr__(self):
return f"{self.__class__.__name__}(parallelism={self.parallelism})"
def start(self): # pragma: no cover
"""Executors may need to get things started."""
def queue_command(
self,
task_instance: TaskInstance,
command: CommandType,
priority: int = 1,
queue: str | None = None,
):
"""Queues command to task."""
if task_instance.key not in self.queued_tasks:
self.log.info("Adding to queue: %s", command)
self.queued_tasks[task_instance.key] = (command, priority, queue, task_instance)
else:
self.log.error("could not queue task %s", task_instance.key)
def queue_task_instance(
self,
task_instance: TaskInstance,
mark_success: bool = False,
pickle_id: int | None = None,
ignore_all_deps: bool = False,
ignore_depends_on_past: bool = False,
wait_for_past_depends_before_skipping: bool = False,
ignore_task_deps: bool = False,
ignore_ti_state: bool = False,
pool: str | None = None,
cfg_path: str | None = None,
) -> None:
"""Queues task instance."""
pool = pool or task_instance.pool
# TODO (edgarRd): AIRFLOW-1985:
# cfg_path is needed to propagate the config values if using impersonation
# (run_as_user), given that there are different code paths running tasks.
# For a long term solution we need to address AIRFLOW-1986
command_list_to_run = task_instance.command_as_list(
local=True,
mark_success=mark_success,
ignore_all_deps=ignore_all_deps,
ignore_depends_on_past=ignore_depends_on_past,
wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
ignore_task_deps=ignore_task_deps,
ignore_ti_state=ignore_ti_state,
pool=pool,
pickle_id=pickle_id,
cfg_path=cfg_path,
)
self.log.debug("created command %s", command_list_to_run)
self.queue_command(
task_instance,
command_list_to_run,
priority=task_instance.task.priority_weight_total,
queue=task_instance.task.queue,
)
def has_task(self, task_instance: TaskInstance) -> bool:
"""
Checks if a task is either queued or running in this executor.
:param task_instance: TaskInstance
:return: True if the task is known to this executor
"""
return task_instance.key in self.queued_tasks or task_instance.key in self.running
def sync(self) -> None:
"""
Sync will get called periodically by the heartbeat method.
Executors should override this to perform status gathering.
"""
def heartbeat(self) -> None:
"""Heartbeat sent to trigger new jobs."""
if not self.parallelism:
open_slots = len(self.queued_tasks)
else:
open_slots = self.parallelism - len(self.running)
num_running_tasks = len(self.running)
num_queued_tasks = len(self.queued_tasks)
self.log.debug("%s running task instances", num_running_tasks)
self.log.debug("%s in queue", num_queued_tasks)
self.log.debug("%s open slots", open_slots)
Stats.gauge(
"executor.open_slots", value=open_slots, tags={"status": "open", "name": self.__class__.__name__}
)
Stats.gauge(
"executor.queued_tasks",
value=num_queued_tasks,
tags={"status": "queued", "name": self.__class__.__name__},
)
Stats.gauge(
"executor.running_tasks",
value=num_running_tasks,
tags={"status": "running", "name": self.__class__.__name__},
)
self.trigger_tasks(open_slots)
# Calling child class sync method
self.log.debug("Calling the %s sync method", self.__class__)
self.sync()
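# Worked example with hypothetical numbers: parallelism=32 and 10 running task instances
# gives open_slots = 32 - 10 = 22; with parallelism=0 every queued task is eligible,
# i.e. open_slots == len(self.queued_tasks).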
def order_queued_tasks_by_priority(self) -> list[tuple[TaskInstanceKey, QueuedTaskInstanceType]]:
"""
Orders the queued tasks by priority.
:return: List of tuples from the queued_tasks according to the priority.
"""
return sorted(
self.queued_tasks.items(),
key=lambda x: x[1][1],
reverse=True,
)
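# Ordering sketch (hypothetical entries): the sort key x[1][1] is the priority element of
# QueuedTaskInstanceType, and reverse=True puts the highest priority first, e.g.
#   {key_a: (cmd_a, 1, None, ti_a), key_b: (cmd_b, 5, None, ti_b)}
#   -> [(key_b, (cmd_b, 5, None, ti_b)), (key_a, (cmd_a, 1, None, ti_a))]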
def trigger_tasks(self, open_slots: int) -> None:
"""
Initiates async execution of the queued tasks, up to the number of available slots.
:param open_slots: Number of open slots
"""
sorted_queue = self.order_queued_tasks_by_priority()
task_tuples = []
for _ in range(min(open_slots, len(self.queued_tasks))):
key, (command, _, queue, ti) = sorted_queue.pop(0)
# If a task makes it here but is still understood by the executor
# to be running, it generally means that the task has been killed
# externally and not yet been marked as failed.
#
# However, when a task is deferred, there is also a possibility of
# a race condition where a task might be scheduled again during
# trigger processing, even before we are able to register that the
# deferred task has completed. In this case and for this reason,
# we make a small number of attempts to see if the task has been
# removed from the running set in the meantime.
if key in self.running:
attempt = self.attempts[key]
if attempt.can_try_again():
# if not much time has passed since the first check, let it be checked again next time
self.log.info("queued but still running; attempt=%s task=%s", attempt.total_tries, key)
continue
# Otherwise, we give up and remove the task from the queue.
self.log.error(
"could not queue task %s (still running after %d attempts)", key, attempt.total_tries
)
del self.attempts[key]
del self.queued_tasks[key]
else:
if key in self.attempts:
del self.attempts[key]
task_tuples.append((key, command, queue, ti.executor_config))
if task_tuples:
self._process_tasks(task_tuples)
def _process_tasks(self, task_tuples: list[TaskTuple]) -> None:
for key, command, queue, executor_config in task_tuples:
del self.queued_tasks[key]
self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)
self.running.add(key)
def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
"""
Changes state of the task.
:param info: Executor information for the task instance
:param key: Unique key for the task instance
:param state: State to set for the task.
"""
self.log.debug("Changing state: %s", key)
try:
self.running.remove(key)
except KeyError:
self.log.debug("Could not find key: %s", str(key))
self.event_buffer[key] = state, info
def fail(self, key: TaskInstanceKey, info=None) -> None:
"""
Set fail state for the event.
:param info: Executor information for the task instance
:param key: Unique key for the task instance
"""
self.change_state(key, State.FAILED, info)
def success(self, key: TaskInstanceKey, info=None) -> None:
"""
Set success state for the event.
:param info: Executor information for the task instance
:param key: Unique key for the task instance
"""
self.change_state(key, State.SUCCESS, info)
def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey, EventBufferValueType]:
"""
Return and flush the event buffer.
If dag_ids is specified, it will only return and flush events
for those dag_ids. Otherwise, it returns and flushes all events.
:param dag_ids: the dag_ids to return events for; returns all if given ``None``.
:return: a dict of events
"""
cleared_events: dict[TaskInstanceKey, EventBufferValueType] = {}
if dag_ids is None:
cleared_events = self.event_buffer
self.event_buffer = {}
else:
for ti_key in list(self.event_buffer.keys()):
if ti_key.dag_id in dag_ids:
cleared_events[ti_key] = self.event_buffer.pop(ti_key)
return cleared_events
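# Usage sketch (the calling loop is hypothetical): the buffer is consumed destructively, so a
# second call returns an empty dict unless new events have arrived in the meantime:
#   for ti_key, (state, info) in executor.get_event_buffer(dag_ids={"example_dag"}).items():
#       pass  # reconcile the stored TaskInstance state with `state`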
def execute_async(
self,
key: TaskInstanceKey,
command: CommandType,
queue: str | None = None,
executor_config: Any | None = None,
) -> None: # pragma: no cover
"""
This method will execute the command asynchronously.
:param key: Unique key for the task instance
:param command: Command to run
:param queue: name of the queue
:param executor_config: Configuration passed to the executor.
"""
raise NotImplementedError()
def get_task_log(self, ti: TaskInstance, try_number: int) -> tuple[list[str], list[str]]:
"""
This method can be implemented by any child class to return the task logs.
:param ti: A TaskInstance object
:param try_number: current try_number to read log from
:return: tuple of logs and messages
"""
return [], []
def end(self) -> None: # pragma: no cover
"""Wait synchronously for the previously submitted job to complete."""
raise NotImplementedError()
def terminate(self):
"""This method is called when the daemon receives a SIGTERM."""
raise NotImplementedError()
def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]: # pragma: no cover
"""
Handle the remnants of tasks that were failed because they were stuck in the queued state.
Tasks can get stuck in the queued state. If such a task is detected, it will be marked as
`UP_FOR_RETRY` if the task instance has remaining retries, or marked as `FAILED`
if it doesn't.
:param tis: List of Task Instances to clean up
:return: List of readable task instances for a warning message
"""
raise NotImplementedError()
def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
"""
Try to adopt running task instances that have been abandoned by a SchedulerJob dying.
Anything that is not adopted will be cleared by the scheduler (and then become eligible for
re-scheduling).
:return: any TaskInstances that were unable to be adopted
"""
# By default, assume Executors cannot adopt tasks, so just say we failed to adopt anything.
# Subclasses can do better!
return tis
@property
def slots_available(self):
"""Number of new tasks this executor instance can accept."""
if self.parallelism:
return self.parallelism - len(self.running) - len(self.queued_tasks)
else:
return sys.maxsize
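# Example with hypothetical numbers: parallelism=32 with 10 running and 5 queued task
# instances gives slots_available == 32 - 10 - 5 == 17; parallelism=0 yields sys.maxsize.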
@staticmethod
def validate_command(command: list[str]) -> None:
"""
Back-compat method to check whether the command to execute is an Airflow command.
:param command: command to check
"""
warnings.warn(
"""
The `validate_command` method is deprecated. Please use ``validate_airflow_tasks_run_command``
""",
RemovedInAirflow3Warning,
stacklevel=2,
)
BaseExecutor.validate_airflow_tasks_run_command(command)
@staticmethod
def validate_airflow_tasks_run_command(command: list[str]) -> tuple[str | None, str | None]:
"""
Check whether the command to execute is an Airflow command.
Return a tuple (dag_id, task_id) retrieved from the command, with None values for any part that is missing.
"""
if command[0:3] != ["airflow", "tasks", "run"]:
raise ValueError('The command must start with ["airflow", "tasks", "run"].')
if len(command) > 3 and "--help" not in command:
dag_id: str | None = None
task_id: str | None = None
for arg in command[3:]:
if not arg.startswith("--"):
if dag_id is None:
dag_id = arg
else:
task_id = arg
break
return dag_id, task_id
return None, None
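# Example (hypothetical command): the first two positional arguments after
# "airflow tasks run" are interpreted as dag_id and task_id:
#   BaseExecutor.validate_airflow_tasks_run_command(
#       ["airflow", "tasks", "run", "example_dag", "example_task", "--local"]
#   )  # -> ("example_dag", "example_task")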
def debug_dump(self):
"""Called in response to SIGUSR2 by the scheduler."""
self.log.info(
"executor.queued (%d)\n\t%s",
len(self.queued_tasks),
"\n\t".join(map(repr, self.queued_tasks.items())),
)
self.log.info("executor.running (%d)\n\t%s", len(self.running), "\n\t".join(map(repr, self.running)))
self.log.info(
"executor.event_buffer (%d)\n\t%s",
len(self.event_buffer),
"\n\t".join(map(repr, self.event_buffer.items())),
)
def send_callback(self, request: CallbackRequest) -> None:
"""Sends callback for execution.
Provides a default implementation which sends the callback to the `callback_sink` object.
:param request: Callback request to be executed.
"""
if not self.callback_sink:
raise ValueError("Callback sink is not ready.")
self.callback_sink.send(request)
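# Minimal subclass sketch ("MyShellExecutor" and its blocking subprocess strategy are
# hypothetical, roughly in the spirit of a sequential executor, and not an executor shipped
# with this module). A concrete executor mainly needs execute_async(), sync(), end() and
# terminate():
#   import subprocess
#
#   class MyShellExecutor(BaseExecutor):
#       def execute_async(self, key, command, queue=None, executor_config=None):
#           # a real executor would submit the command without blocking the scheduler loop
#           returncode = subprocess.call(command, close_fds=True)
#           if returncode == 0:
#               self.success(key)
#           else:
#               self.fail(key)
#
#       def sync(self):
#           pass  # nothing to poll: results are reported directly from execute_async
#
#       def end(self):
#           self.heartbeat()
#
#       def terminate(self):
#           pass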