#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""CeleryExecutor.
.. seealso::
For more information on how the CeleryExecutor works, take a look at the guide:
:doc:`/celery_executor`
"""
from __future__ import annotations
import logging
import math
import operator
import time
from collections import Counter
from concurrent.futures import ProcessPoolExecutor
from importlib.metadata import version as importlib_version
from multiprocessing import cpu_count
from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple
from celery import states as celery_states
from packaging.version import Version
try:
from airflow.cli.cli_config import (
ARG_DAEMON,
ARG_LOG_FILE,
ARG_PID,
ARG_SKIP_SERVE_LOGS,
ARG_STDERR,
ARG_STDOUT,
ARG_VERBOSE,
ActionCommand,
Arg,
GroupCommand,
lazy_load_command,
)
except ImportError:
import importlib.metadata
import packaging.version
from airflow.exceptions import AirflowOptionalProviderFeatureException
airflow_version = importlib.metadata.version("apache-airflow")
base_version = packaging.version.parse(airflow_version).base_version
if packaging.version.parse(base_version) < packaging.version.parse("2.7.0"):
raise AirflowOptionalProviderFeatureException(
"Celery Executor from Celery Provider should only be used with Airflow 2.7.0+.\n"
f"This is Airflow {airflow_version} and Celery and CeleryKubernetesExecutor are "
f"available in the 'airflow.executors' package. You should not use "
f"the provider's executors in this version of Airflow."
)
raise
from airflow.configuration import conf
from airflow.exceptions import AirflowTaskTimeout
from airflow.executors.base_executor import BaseExecutor
from airflow.stats import Stats
from airflow.utils.state import TaskInstanceState
log = logging.getLogger(__name__)
CELERY_SEND_ERR_MSG_HEADER = "Error sending Celery task"
if TYPE_CHECKING:
import argparse
from celery import Task
from airflow.executors.base_executor import CommandType, TaskTuple
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
# Task instance that is sent over Celery queues
# TaskInstanceKey, Command, queue_name, CallableTask
TaskInstanceInCelery = Tuple[TaskInstanceKey, CommandType, Optional[str], Task]
# PEP562
def __getattr__(name):
# This allows us to make the Celery app accessible through the
# celery_executor module without the time cost of its import and
# construction
if name == "app":
from airflow.providers.celery.executors.celery_executor_utils import app
return app
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
"""
To start the celery worker, run the command:
airflow celery worker
"""
# flower cli args
ARG_BROKER_API = Arg(("-a", "--broker-api"), help="Broker API")
ARG_FLOWER_HOSTNAME = Arg(
("-H", "--hostname"),
default=conf.get("celery", "FLOWER_HOST"),
help="Set the hostname on which to run the server",
)
ARG_FLOWER_PORT = Arg(
("-p", "--port"),
default=conf.getint("celery", "FLOWER_PORT"),
type=int,
help="The port on which to run the server",
)
ARG_FLOWER_CONF = Arg(("-c", "--flower-conf"), help="Configuration file for flower")
ARG_FLOWER_URL_PREFIX = Arg(
("-u", "--url-prefix"),
default=conf.get("celery", "FLOWER_URL_PREFIX"),
help="URL prefix for Flower",
)
ARG_FLOWER_BASIC_AUTH = Arg(
("-A", "--basic-auth"),
default=conf.get("celery", "FLOWER_BASIC_AUTH"),
help=(
"Securing Flower with Basic Authentication. "
"Accepts user:password pairs separated by a comma. "
"Example: flower_basic_auth = user1:password1,user2:password2"
),
)
# worker cli args
ARG_AUTOSCALE = Arg(("-a", "--autoscale"), help="Minimum and maximum number of workers to autoscale")
ARG_QUEUES = Arg(
("-q", "--queues"),
help="Comma delimited list of queues to serve",
default=conf.get("operators", "DEFAULT_QUEUE"),
)
ARG_CONCURRENCY = Arg(
("-c", "--concurrency"),
type=int,
help="The number of worker processes",
default=conf.getint("celery", "worker_concurrency"),
)
ARG_CELERY_HOSTNAME = Arg(
("-H", "--celery-hostname"),
help="Set the hostname of celery worker if you have multiple workers on a single machine",
)
ARG_UMASK = Arg(
("-u", "--umask"),
help="Set the umask of celery worker in daemon mode",
)
ARG_WITHOUT_MINGLE = Arg(
("--without-mingle",),
default=False,
help="Don't synchronize with other workers at start-up",
action="store_true",
)
ARG_WITHOUT_GOSSIP = Arg(
("--without-gossip",),
default=False,
help="Don't subscribe to other workers events",
action="store_true",
)
AIRFLOW_VERSION = Version(importlib_version("apache-airflow"))
CELERY_CLI_COMMAND_PATH = (
"airflow.providers.celery.cli.celery_command"
if AIRFLOW_VERSION >= Version("2.8.0")
else "airflow.cli.commands.celery_command"
)
CELERY_COMMANDS = (
ActionCommand(
name="worker",
help="Start a Celery worker node",
func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.worker"),
args=(
ARG_QUEUES,
ARG_CONCURRENCY,
ARG_CELERY_HOSTNAME,
ARG_PID,
ARG_DAEMON,
ARG_UMASK,
ARG_STDOUT,
ARG_STDERR,
ARG_LOG_FILE,
ARG_AUTOSCALE,
ARG_SKIP_SERVE_LOGS,
ARG_WITHOUT_MINGLE,
ARG_WITHOUT_GOSSIP,
ARG_VERBOSE,
),
),
ActionCommand(
name="flower",
help="Start a Celery Flower",
func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.flower"),
args=(
ARG_FLOWER_HOSTNAME,
ARG_FLOWER_PORT,
ARG_FLOWER_CONF,
ARG_FLOWER_URL_PREFIX,
ARG_FLOWER_BASIC_AUTH,
ARG_BROKER_API,
ARG_PID,
ARG_DAEMON,
ARG_STDOUT,
ARG_STDERR,
ARG_LOG_FILE,
ARG_VERBOSE,
),
),
ActionCommand(
name="stop",
help="Stop the Celery worker gracefully",
func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.stop_worker"),
args=(ARG_PID, ARG_VERBOSE),
),
)
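# Illustrative invocations of the commands registered above (queue names and values are
# placeholders; the flags shown are the ones defined by the Arg objects in this module):
#
#   airflow celery worker --queues default,heavy --concurrency 16
#   airflow celery flower --port 5555
#   airflow celery stop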
class CeleryExecutor(BaseExecutor):
"""
CeleryExecutor is recommended for production use of Airflow.
It allows distributing the execution of task instances to multiple worker nodes.
Celery is a simple, flexible and reliable distributed system to process
vast amounts of messages, while providing operations with the tools
required to maintain such a system.
"""
supports_ad_hoc_ti_run: bool = True
supports_sentry: bool = True
def __init__(self):
super().__init__()
# Celery doesn't support bulk sending the tasks (which can become a bottleneck on bigger clusters),
# so we use a multiprocessing pool to speed this up.
# The same setting controls how many worker processes are created for checking celery task state.
self._sync_parallelism = conf.getint("celery", "SYNC_PARALLELISM")
if self._sync_parallelism == 0:
self._sync_parallelism = max(1, cpu_count() - 1)
from airflow.providers.celery.executors.celery_executor_utils import BulkStateFetcher
self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism)
self.tasks = {}
self.task_publish_retries: Counter[TaskInstanceKey] = Counter()
self.task_publish_max_retries = conf.getint("celery", "task_publish_max_retries")
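# A configuration sketch for the settings read above (values are illustrative, not
# necessarily the defaults): in airflow.cfg
#
#   [celery]
#   sync_parallelism = 0              # 0 -> use max(1, cpu_count() - 1) sync processes
#   task_publish_max_retries = 3      # retries on AirflowTaskTimeout while publishing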
def start(self) -> None:
self.log.debug("Starting Celery Executor using %s processes for syncing", self._sync_parallelism)
def _num_tasks_per_send_process(self, to_send_count: int) -> int:
"""
Determine how many Celery tasks each worker process should send.
:return: Number of tasks that should be sent per process
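For example, with 100 tasks to send and ``sync_parallelism`` of 8, each sender
process is handed ``max(1, ceil(100 / 8)) == 13`` tasks.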
"""
return max(1, math.ceil(to_send_count / self._sync_parallelism))
def _process_tasks(self, task_tuples: list[TaskTuple]) -> None:
from airflow.providers.celery.executors.celery_executor_utils import execute_command
task_tuples_to_send = [task_tuple[:3] + (execute_command,) for task_tuple in task_tuples]
first_task = next(t[3] for t in task_tuples_to_send)
# Celery state queries will get stuck if we do not use the same backend
# for all tasks.
cached_celery_backend = first_task.backend
key_and_async_results = self._send_tasks_to_celery(task_tuples_to_send)
self.log.debug("Sent all tasks.")
from airflow.providers.celery.executors.celery_executor_utils import ExceptionWithTraceback
for key, _, result in key_and_async_results:
if isinstance(result, ExceptionWithTraceback) and isinstance(
result.exception, AirflowTaskTimeout
):
retries = self.task_publish_retries[key]
if retries < self.task_publish_max_retries:
Stats.incr("celery.task_timeout_error")
self.log.info(
"[Try %s of %s] Task Timeout Error for Task: (%s).",
self.task_publish_retries[key] + 1,
self.task_publish_max_retries,
key,
)
self.task_publish_retries[key] = retries + 1
continue
self.queued_tasks.pop(key)
self.task_publish_retries.pop(key, None)
if isinstance(result, ExceptionWithTraceback):
self.log.error("%s: %s\n%s\n", CELERY_SEND_ERR_MSG_HEADER, result.exception, result.traceback)
self.event_buffer[key] = (TaskInstanceState.FAILED, None)
elif result is not None:
result.backend = cached_celery_backend
self.running.add(key)
self.tasks[key] = result
# Store the Celery task_id in the event buffer. This will get "overwritten" if the task
# has another event, but that is fine, because the only other events are success/failed at
# which point we don't need the ID anymore anyway
self.event_buffer[key] = (TaskInstanceState.QUEUED, result.task_id)
# If the task runs _really quickly_ we may already have a result!
self.update_task_state(key, result.state, getattr(result, "info", None))
def _send_tasks_to_celery(self, task_tuples_to_send: list[TaskInstanceInCelery]):
from airflow.providers.celery.executors.celery_executor_utils import send_task_to_executor
if len(task_tuples_to_send) == 1 or self._sync_parallelism == 1:
# Only one tuple to send, or at most one process allowed -> send it in the main process.
return list(map(send_task_to_executor, task_tuples_to_send))
# Use chunks instead of a work queue to reduce context switching
# since tasks are roughly uniform in size
chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send))
num_processes = min(len(task_tuples_to_send), self._sync_parallelism)
with ProcessPoolExecutor(max_workers=num_processes) as send_pool:
key_and_async_results = list(
send_pool.map(send_task_to_executor, task_tuples_to_send, chunksize=chunksize)
)
return key_and_async_results
def sync(self) -> None:
if not self.tasks:
self.log.debug("No task to query celery, skipping sync")
return
self.update_all_task_states()
def debug_dump(self) -> None:
"""Debug dump; called in response to SIGUSR2 by the scheduler."""
super().debug_dump()
self.log.info(
"executor.tasks (%d)\n\t%s", len(self.tasks), "\n\t".join(map(repr, self.tasks.items()))
)
def update_all_task_states(self) -> None:
"""Update states of the tasks."""
self.log.debug("Inquiring about %s celery task(s)", len(self.tasks))
state_and_info_by_celery_task_id = self.bulk_state_fetcher.get_many(self.tasks.values())
self.log.debug("Inquiries completed.")
for key, async_result in list(self.tasks.items()):
state, info = state_and_info_by_celery_task_id.get(async_result.task_id)
if state:
self.update_task_state(key, state, info)
def change_state(
self, key: TaskInstanceKey, state: TaskInstanceState, info=None, remove_running=True
) -> None:
try:
super().change_state(key, state, info, remove_running=remove_running)
except AttributeError:
# Earlier versions of the BaseExecutor don't accept the remove_running parameter for this method
super().change_state(key, state, info)
self.tasks.pop(key, None)
def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None:
"""Update state of a single task."""
try:
if state == celery_states.SUCCESS:
self.success(key, info)
elif state in (celery_states.FAILURE, celery_states.REVOKED):
self.fail(key, info)
elif state in (celery_states.STARTED, celery_states.PENDING):
pass
else:
self.log.info("Unexpected state for %s: %s", key, state)
except Exception:
self.log.exception("Error syncing the Celery executor, ignoring it.")
def end(self, synchronous: bool = False) -> None:
if synchronous:
while any(task.state not in celery_states.READY_STATES for task in self.tasks.values()):
time.sleep(5)
self.sync()
def terminate(self):
pass
def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
# See which of the TIs are still alive (or have finished even!)
#
# Since Celery doesn't store "SENT" state for queued commands (if we create an AsyncResult with a made
# up id it just returns PENDING state for it), we have to store Celery's task_id against the TI row to
# look at in future.
#
# This process is not perfect -- we could have sent the task to celery, and crashed before we were
# able to record the AsyncResult.task_id in the TaskInstance table, in which case we won't adopt the
# task (it'll either run and update the TI state, or the scheduler will clear and re-queue it. Either
# way it won't get executed more than once)
#
# (If we swapped it around, and generated a task_id for Celery, stored that in TI and enqueued that
# there is also still a race condition where we could generate and store the task_id, but die before
# we managed to enqueue the command. Since neither way is perfect we always have to deal with this
# process not being perfect.)
from celery.result import AsyncResult
celery_tasks = {}
not_adopted_tis = []
for ti in tis:
if ti.external_executor_id is not None:
celery_tasks[ti.external_executor_id] = (AsyncResult(ti.external_executor_id), ti)
else:
not_adopted_tis.append(ti)
if not celery_tasks:
# Nothing to adopt
return tis
states_by_celery_task_id = self.bulk_state_fetcher.get_many(
list(map(operator.itemgetter(0), celery_tasks.values()))
)
adopted = []
cached_celery_backend = next(iter(celery_tasks.values()))[0].backend
for celery_task_id, (state, info) in states_by_celery_task_id.items():
result, ti = celery_tasks[celery_task_id]
result.backend = cached_celery_backend
# Set the correct elements of the state dicts, then update this
# like we just queried it.
self.tasks[ti.key] = result
self.running.add(ti.key)
self.update_task_state(ti.key, state, info)
adopted.append(f"{ti} in state {state}")
if adopted:
task_instance_str = "\n\t".join(adopted)
self.log.info(
"Adopted the following %d tasks from a dead executor\n\t%s", len(adopted), task_instance_str
)
return not_adopted_tis
def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]:
"""
Handle remnants of tasks that were failed because they were stuck in queued.
Tasks can get stuck in queued. If such a task is detected, it will be marked
as `UP_FOR_RETRY` if the task instance has remaining retries or marked as `FAILED`
if it doesn't.
:param tis: List of Task Instances to clean up
:return: List of readable task instances for a warning message
"""
readable_tis = []
from airflow.providers.celery.executors.celery_executor_utils import app
for ti in tis:
readable_tis.append(repr(ti))
task_instance_key = ti.key
self.fail(task_instance_key, None)
celery_async_result = self.tasks.pop(task_instance_key, None)
if celery_async_result:
try:
app.control.revoke(celery_async_result.task_id)
except Exception as ex:
self.log.error("Error revoking task instance %s from celery: %s", task_instance_key, ex)
return readable_tis
@staticmethod
def get_cli_commands() -> list[GroupCommand]:
return [
GroupCommand(
name="celery",
help="Celery components",
description=(
"Start celery components. Works only when using CeleryExecutor. For more information, "
"see https://airflow.apache.org/docs/apache-airflow/stable/executor/celery.html"
),
subcommands=CELERY_COMMANDS,
),
]
def _get_parser() -> argparse.ArgumentParser:
"""
Generate documentation; used by Sphinx.
:meta private:
"""
return CeleryExecutor._get_parser()