# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import asyncio
import logging
import os
import signal
import sys
import threading
import time
import warnings
from collections import deque
from contextlib import suppress
from copy import copy
from queue import SimpleQueue
from typing import TYPE_CHECKING, Deque

from sqlalchemy import func

from airflow.configuration import conf
from airflow.jobs.base_job_runner import BaseJobRunner
from airflow.jobs.job import Job, perform_heartbeat
from airflow.models.trigger import Trigger
from airflow.serialization.pydantic.job import JobPydantic
from airflow.stats import Stats
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.typing_compat import TypedDict
from airflow.utils.log.file_task_handler import FileTaskHandler
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.log.trigger_handler import (
    DropTriggerLogsFilter,
    LocalQueueHandler,
    TriggererHandlerWrapper,
    TriggerMetadataFilter,
    ctx_indiv_trigger,
    ctx_task_instance,
    ctx_trigger_end,
    ctx_trigger_id,
)
from airflow.utils.module_loading import import_string
from airflow.utils.session import provide_session

if TYPE_CHECKING:
    from airflow.models import TaskInstance

HANDLER_SUPPORTS_TRIGGERER = False
"""
If this value is true, the root handler is configured to log individual trigger messages,
which are then visible in task logs.

:meta private:
"""

SEND_TRIGGER_END_MARKER = True
"""
If the handler natively supports triggers, it may want to disable sending the trigger end marker.

:meta private:
"""

logger = logging.getLogger(__name__)

DISABLE_WRAPPER = conf.getboolean("logging", "disable_trigger_handler_wrapper", fallback=False)
DISABLE_LISTENER = conf.getboolean("logging", "disable_trigger_handler_queue_listener", fallback=False)
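
# The helpers inside configure_trigger_log_handler() below look for a small set of opt-in
# attributes on task log handlers. A handler advertising full trigger support might look
# roughly like this sketch (class name and attribute values are illustrative only):
#
#   class MyTaskHandler(FileTaskHandler):
#       trigger_should_wrap = True        # wrap me in TriggererHandlerWrapper
#       trigger_should_queue = True       # route my records through the queue listener
#       trigger_send_end_marker = True    # emit the "trigger end" marker when a trigger finishes
#
# Handlers that manage trigger logs natively can instead set `trigger_supported = True`
# and opt out of wrapping.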


def configure_trigger_log_handler():
    """
    Configure logging such that each trigger logs to its own file and
    can be exposed through the airflow webserver.

    Generally speaking, we take the log handler configured for logger ``airflow.task``,
    wrap it with TriggerHandlerWrapper, and set it as the handler for the root logger.

    If there already is a handler configured for the root logger
    and it supports triggers, we wrap it instead.

    :meta private:
    """
    global HANDLER_SUPPORTS_TRIGGERER

    def should_wrap(handler):
        return handler.__dict__.get("trigger_should_wrap", False) or handler.__class__.__dict__.get(
            "trigger_should_wrap", False
        )

    def should_queue(handler):
        return handler.__dict__.get("trigger_should_queue", True) or handler.__class__.__dict__.get(
            "trigger_should_queue", True
        )

    def send_trigger_end_marker(handler):
        val = handler.__dict__.get("trigger_send_end_marker", None)
        if val is not None:
            return val

        val = handler.__class__.__dict__.get("trigger_send_end_marker", None)
        if val is not None:
            return val
        return True

    def supports_triggerer(handler):
        return (
            should_wrap(handler)
            or handler.__dict__.get("trigger_supported", False)
            or handler.__class__.__dict__.get("trigger_supported", False)
        )

    def get_task_handler_from_logger(logger_):
        for h in logger_.handlers:
            if isinstance(h, FileTaskHandler) and not supports_triggerer(h):
                warnings.warn(
                    f"Handler {h.__class__.__name__} does not support "
                    "individual trigger logging. Please check the release notes "
                    "for your provider to see if a newer version supports "
                    "individual trigger logging."
                )
            if supports_triggerer(h):
                return h

    def find_suitable_task_handler():
        # check the root logger, then airflow.task, for a handler suitable for use with
        # TriggerHandlerWrapper (i.e. it has the trigger_should_wrap attribute and likely
        # inherits from FileTaskHandler)
        h = get_task_handler_from_logger(root_logger)
        if not h:
            # try to use the handler configured for airflow.task
            logger.debug("No task logger configured for root logger; trying `airflow.task`.")
            h = get_task_handler_from_logger(logging.getLogger("airflow.task"))
            if h:
                logger.debug("Using logging configuration from `airflow.task`")
        if not h:
            warnings.warn("Could not find log handler suitable for individual trigger logging.")
            return None
        return h

    def filter_trigger_logs_from_other_root_handlers(new_hdlr):
        # we add context vars to log records emitted for individual trigger logging
        # we want these records to be processed by our special trigger handler wrapper
        # but not by any other handlers, so we filter out these messages from
        # other handlers by adding DropTriggerLogsFilter
        # we could consider only adding this filter to the default console logger
        # so as to leave other custom handlers alone
        for h in root_logger.handlers:
            if h is not new_hdlr:
                h.addFilter(DropTriggerLogsFilter())

    def add_handler_wrapper_to_root(base_handler):
        # first make sure we remove it from the root logger if it happens to be there
        # it could have come from root or airflow.task, but we only need
        # to make sure we remove it from root, since messages will not flow
        # through airflow.task
        if base_handler in root_logger.handlers:
            root_logger.removeHandler(base_handler)

        logger.info("Setting up TriggererHandlerWrapper with handler %s", base_handler)
        h = TriggererHandlerWrapper(base_handler=base_handler, level=base_handler.level)
        # just extra cautious, checking if user manually configured it there
        if h not in root_logger.handlers:
            root_logger.addHandler(h)
        return h

    root_logger = logging.getLogger()
    task_handler = find_suitable_task_handler()
    if not task_handler:
        return None
    if TYPE_CHECKING:
        assert isinstance(task_handler, FileTaskHandler)
    if should_wrap(task_handler):
        trigger_handler = add_handler_wrapper_to_root(task_handler)
    else:
        trigger_handler = copy(task_handler)
        root_logger.addHandler(trigger_handler)
    filter_trigger_logs_from_other_root_handlers(trigger_handler)
    if send_trigger_end_marker(trigger_handler) is False:
        global SEND_TRIGGER_END_MARKER
        SEND_TRIGGER_END_MARKER = False
    HANDLER_SUPPORTS_TRIGGERER = True
    return should_queue(trigger_handler)
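
# Summary of configure_trigger_log_handler() return values, as consumed by
# TriggererJobRunner.__init__ below:
#   * None  -- no handler suitable for individual trigger logging was found
#   * True  -- a handler was configured and its records should go through the queue listener
#   * False -- a handler was configured but indicated that queueing should be skipped,
#              so the queue listener is not started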


def setup_queue_listener():
    """
    Route log messages to a queue and process them with QueueListener.

    Airflow task handlers make blocking I/O calls.
    We replace trigger log handlers with LocalQueueHandler,
    which sends log records to a queue.

    Then we start a QueueListener in a thread, which is configured
    to consume the queue and pass the records to the handlers as
    originally configured. This keeps the handler I/O out of the
    async event loop.

    :meta private:
    """
    queue = SimpleQueue()
    root_logger = logging.getLogger()

    handlers: list[logging.Handler] = []

    queue_handler = LocalQueueHandler(queue)
    queue_handler.addFilter(TriggerMetadataFilter())

    root_logger.addHandler(queue_handler)
    for h in root_logger.handlers[:]:
        if h is not queue_handler and "pytest" not in h.__module__:
            root_logger.removeHandler(h)
            handlers.append(h)

    this_logger = logging.getLogger(__name__)
    if handlers:
        this_logger.info("Setting up logging queue listener with handlers %s", handlers)
        listener = logging.handlers.QueueListener(queue, *handlers, respect_handler_level=True)
        listener.start()
        return listener
    else:
        this_logger.warning("Unable to set up individual trigger logging")
        return None
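
# With the queue listener in place, records emitted inside the async thread take
# (roughly) this path, keeping blocking handler I/O off the event loop:
#
#   trigger coroutine -> root logger -> LocalQueueHandler -> SimpleQueue
#       -> QueueListener thread -> original handlers (e.g. the wrapped FileTaskHandler)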


class TriggererJobRunner(BaseJobRunner["Job | JobPydantic"], LoggingMixin):
    """
    TriggererJobRunner continuously runs active triggers in asyncio, watching
    for them to fire off their events and then dispatching that information
    to their dependent tasks/DAGs.

    It runs as two threads:
     - The main thread does DB calls/checkins
     - A subthread runs all the async code
    """

    job_type = "TriggererJob"

    def __init__(
        self,
        job: Job | JobPydantic,
        capacity=None,
    ):
        super().__init__(job)
        if capacity is None:
            self.capacity = conf.getint("triggerer", "default_capacity", fallback=1000)
        elif isinstance(capacity, int) and capacity > 0:
            self.capacity = capacity
        else:
            raise ValueError(f"Capacity number {capacity} is invalid")

        should_queue = True
        if DISABLE_WRAPPER:
            self.log.warning(
                "Skipping trigger log configuration; disabled by param "
                "`disable_trigger_handler_wrapper=True`."
            )
        else:
            should_queue = configure_trigger_log_handler()
        self.listener = None
        if DISABLE_LISTENER:
            self.log.warning(
                "Skipping trigger logger queue listener; disabled by param "
                "`disable_trigger_handler_queue_listener=True`."
            )
        elif should_queue is False:
            self.log.warning("Skipping trigger logger queue listener; disabled by handler setting.")
        else:
            self.listener = setup_queue_listener()
        # Set up runner async thread
        self.trigger_runner = TriggerRunner()
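
    # Summary of the logging decisions made above (a restatement of the code, not new behavior):
    #   DISABLE_WRAPPER=True          -> no per-trigger handler wrapping; the listener is still considered
    #   DISABLE_LISTENER=True         -> handlers keep doing their own (blocking) I/O
    #   configure_...() returns False -> the handler opted out of queueing; the listener is skipped
    #   otherwise                     -> records are routed through setup_queue_listener()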

    def register_signals(self) -> None:
        """Register signals that stop child processes."""
        signal.signal(signal.SIGINT, self._exit_gracefully)
        signal.signal(signal.SIGTERM, self._exit_gracefully)

    @classmethod
    @provide_session
    def is_needed(cls, session) -> bool:
        """
        Tests if the triggerer job needs to be run (i.e., if there are triggers
        in the trigger table).

        This is used for the warning boxes in the UI.
        """
        return session.query(func.count(Trigger.id)).scalar() > 0

    def on_kill(self):
        """
        Called when there is an external kill command (via the heartbeat
        mechanism, for example).
        """
        self.trigger_runner.stop = True

    def _kill_listener(self):
        if self.listener:
            for h in self.listener.handlers:
                h.close()
            self.listener.stop()

    def _exit_gracefully(self, signum, frame) -> None:
        """Stop the trigger runner and listener on the first exit signal; force exit on the second."""
        # The first time, try to exit nicely
        if not self.trigger_runner.stop:
            self.log.info("Exiting gracefully upon receiving signal %s", signum)
            self.trigger_runner.stop = True
            self._kill_listener()
        else:
            self.log.warning("Forcing exit due to second exit signal %s", signum)
            sys.exit(os.EX_SOFTWARE)

    def _execute(self) -> int | None:
        self.log.info("Starting the triggerer")
        try:
            # set job_id so that it can be used in log file names
            self.trigger_runner.job_id = self.job.id

            # Kick off runner thread
            self.trigger_runner.start()
            # Start our own DB loop in the main thread
            self._run_trigger_loop()
        except Exception:
            self.log.exception("Exception when executing TriggererJobRunner._run_trigger_loop")
            raise
        finally:
            self.log.info("Waiting for triggers to clean up")
            # Tell the subthread to stop and then wait for it.
            # If the user interrupts/terms again, _exit_gracefully will allow them
            # to force-kill here.
            self.trigger_runner.stop = True
            self.trigger_runner.join(30)
            self.log.info("Exited trigger loop")
        return None

    def _run_trigger_loop(self) -> None:
        """
        The main-thread trigger loop.

        This runs synchronously and handles all database reads/writes.
        """
        while not self.trigger_runner.stop:
            # Clean out unused triggers
            Trigger.clean_unused()
            # Load/delete triggers
            self.load_triggers()
            # Handle events
            self.handle_events()
            # Handle failed triggers
            self.handle_failed_triggers()
            perform_heartbeat(self.job, heartbeat_callback=self.heartbeat_callback, only_if_necessary=True)
            # Collect stats
            self.emit_metrics()
            # Idle sleep
            time.sleep(1)

    def load_triggers(self):
        """
        Queries the database to get the triggers we're supposed to be running,
        adds them to our runner, and then removes from it the ones we no
        longer need.
        """
        Trigger.assign_unassigned(self.job.id, self.capacity)
        ids = Trigger.ids_for_triggerer(self.job.id)
        self.trigger_runner.update_triggers(set(ids))

    def handle_events(self):
        """
        Handles outbound events from triggers - dispatching them into the Trigger
        model where they are then pushed into the relevant task instances.
        """
        while self.trigger_runner.events:
            # Get the event and its trigger ID
            trigger_id, event = self.trigger_runner.events.popleft()
            # Tell the model to wake up its tasks
            Trigger.submit_event(trigger_id=trigger_id, event=event)
            # Emit stat event
            Stats.incr("triggers.succeeded")

    def handle_failed_triggers(self):
        """
        Handles "failed" triggers - ones that errored or exited before they
        sent an event. Task Instances that depend on them need failing.
        """
        while self.trigger_runner.failed_triggers:
            # Tell the model to fail this trigger's deps
            trigger_id, saved_exc = self.trigger_runner.failed_triggers.popleft()
            Trigger.submit_failure(trigger_id=trigger_id, exc=saved_exc)
            # Emit stat event
            Stats.incr("triggers.failed")

    def emit_metrics(self):
        Stats.gauge("triggers.running", len(self.trigger_runner.triggers))


class TriggerDetails(TypedDict):
    """Type class for the trigger details dictionary."""

    task: asyncio.Task
    name: str
    events: int
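
# Illustrative shape of one entry in TriggerRunner.triggers (the values shown are made up;
# the "name" format mirrors what create_triggers() builds below):
#   {
#       "task": <asyncio.Task running TriggerRunner.run_trigger(...)>,
#       "name": "my_dag/manual__2023-01-01T00:00:00/wait_task/-1/1 (ID 42)",
#       "events": 0,
#   }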


class TriggerRunner(threading.Thread, LoggingMixin):
    """
    Runtime environment for all triggers.

    Mainly runs inside its own thread, where it hands control off to an asyncio
    event loop, but is also sometimes interacted with from the main thread
    (where all the DB queries are done). All communication between threads is
    done via Deques.
    """

    # Maps trigger IDs to their running tasks and other info
    triggers: dict[int, TriggerDetails]

    # Cache for looking up triggers by classpath
    trigger_cache: dict[str, type[BaseTrigger]]

    # Inbound queue of new triggers
    to_create: Deque[tuple[int, BaseTrigger]]

    # Inbound queue of deleted triggers
    to_cancel: Deque[int]

    # Outbound queue of events
    events: Deque[tuple[int, TriggerEvent]]

    # Outbound queue of failed triggers
    failed_triggers: Deque[tuple[int, BaseException]]

    # Should-we-stop flag
    stop: bool = False

    def __init__(self):
        super().__init__()
        self.triggers = {}
        self.trigger_cache = {}
        self.to_create = deque()
        self.to_cancel = deque()
        self.events = deque()
        self.failed_triggers = deque()
        self.job_id = None
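
    # Note: the deques above are the only hand-off points between the main (DB) thread and
    # this async thread. Single append()/popleft() calls on a deque are thread-safe in
    # CPython, which is presumably why no explicit locking is used here.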

    def run(self):
        """Sync entrypoint - just runs arun in an async loop."""
        asyncio.run(self.arun())

    async def arun(self):
        """
        Main (asynchronous) logic loop.

        The loop in here runs trigger addition/deletion/cleanup. Actual
        triggers run in their own separate coroutines.
        """
        watchdog = asyncio.create_task(self.block_watchdog())
        last_status = time.time()
        while not self.stop:
            # Run core logic
            await self.create_triggers()
            await self.cancel_triggers()
            await self.cleanup_finished_triggers()
            # Sleep for a bit
            await asyncio.sleep(1)
            # Every minute, log status
            if time.time() - last_status >= 60:
                count = len(self.triggers)
                self.log.info("%i triggers currently running", count)
                last_status = time.time()
        # Wait for watchdog to complete
        await watchdog

    async def create_triggers(self):
        """
        Drain the to_create queue and create all triggers that have been
        requested in the DB that we don't yet have.
        """
        while self.to_create:
            trigger_id, trigger_instance = self.to_create.popleft()
            if trigger_id not in self.triggers:
                task_instance: TaskInstance = trigger_instance.task_instance
                dag_id = task_instance.dag_id
                run_id = task_instance.run_id
                task_id = task_instance.task_id
                map_index = task_instance.map_index
                try_number = task_instance.try_number
                self.triggers[trigger_id] = {
                    "task": asyncio.create_task(self.run_trigger(trigger_id, trigger_instance)),
                    "name": f"{dag_id}/{run_id}/{task_id}/{map_index}/{try_number} (ID {trigger_id})",
                    "events": 0,
                }
            else:
                self.log.warning("Trigger %s had insertion attempted twice", trigger_id)
            await asyncio.sleep(0)
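
    # The `await asyncio.sleep(0)` calls in the queue-draining coroutines above and below
    # yield control back to the event loop after each item, so a long backlog of
    # creations/cancellations cannot starve the triggers that are already running.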

    async def cancel_triggers(self):
        """
        Drain the to_cancel queue and ensure all triggers that are not in the
        DB are cancelled, so the cleanup job deletes them.
        """
        while self.to_cancel:
            trigger_id = self.to_cancel.popleft()
            if trigger_id in self.triggers:
                # We only delete if it did not exit already
                self.triggers[trigger_id]["task"].cancel()
            await asyncio.sleep(0)

    async def cleanup_finished_triggers(self):
        """
        Go through all trigger tasks (coroutines) and clean up entries for
        ones that have exited, optionally warning users if the exit was
        not normal.
        """
        for trigger_id, details in list(self.triggers.items()):
            if details["task"].done():
                # Check to see if it exited for good reasons
                saved_exc = None
                try:
                    result = details["task"].result()
                except (asyncio.CancelledError, SystemExit, KeyboardInterrupt):
                    # These are "expected" exceptions and we stop processing here
                    # If we don't, then the system requesting a trigger be removed -
                    # which turns into CancelledError - results in a failure.
                    del self.triggers[trigger_id]
                    continue
                except BaseException as e:
                    # This is potentially bad, so log it.
                    self.log.exception("Trigger %s exited with error %s", details["name"], e)
                    saved_exc = e
                else:
                    # See if they foolishly returned a TriggerEvent
                    if isinstance(result, TriggerEvent):
                        self.log.error(
                            "Trigger %s returned a TriggerEvent rather than yielding it", details["name"]
                        )
                # See if this exited without sending an event, in which case
                # any task instances depending on it need to be failed
                if details["events"] == 0:
                    self.log.error(
                        "Trigger %s exited without sending an event. Dependent tasks will be failed.",
                        details["name"],
                    )
                    self.failed_triggers.append((trigger_id, saved_exc))
                del self.triggers[trigger_id]
            await asyncio.sleep(0)

    async def block_watchdog(self):
        """
        Watchdog loop that detects blocking (badly-written) triggers.

        Triggers should be well-behaved async coroutines and await whenever
        they need to wait; this loop tries to run every 100ms to see if
        there are badly-written triggers taking longer than that and blocking
        the event loop.

        Unfortunately, we can't tell what trigger is blocking things, but
        we can at least detect the top-level problem.
        """
        while not self.stop:
            last_run = time.monotonic()
            await asyncio.sleep(0.1)
            # We allow a generous amount of buffer room for now, since it might
            # be a busy event loop.
            time_elapsed = time.monotonic() - last_run
            if time_elapsed > 0.2:
                self.log.error(
                    "Triggerer's async thread was blocked for %.2f seconds, "
                    "likely by a badly-written trigger. Set PYTHONASYNCIODEBUG=1 "
                    "to get more information on overrunning coroutines.",
                    time_elapsed,
                )
                Stats.incr("triggers.blocked_main_thread")

    @staticmethod
    def set_individual_trigger_logging(trigger):
        """
        Setting these context vars allows log messages for individual triggers
        to be routed to distinct files and filtered from triggerer stdout.
        """
        # set logging context vars for routing to the appropriate handler
        ctx_task_instance.set(trigger.task_instance)
        ctx_trigger_id.set(trigger.trigger_id)
        ctx_trigger_end.set(False)

        # mark that we're in the context of an individual trigger so log records can be filtered
        ctx_indiv_trigger.set(True)

    async def run_trigger(self, trigger_id, trigger):
        """
        Wrapper which runs an actual trigger (they are async generators)
        and pushes their events into our outbound event deque.
        """
        name = self.triggers[trigger_id]["name"]
        self.log.info("trigger %s starting", name)
        try:
            self.set_individual_trigger_logging(trigger)
            async for event in trigger.run():
                self.log.info("Trigger %s fired: %s", self.triggers[trigger_id]["name"], event)
                self.triggers[trigger_id]["events"] += 1
                self.events.append((trigger_id, event))
        finally:
            # CancelledError will get injected when we're stopped - which is
            # fine, the cleanup process will understand that, but we want to
            # give triggers a chance to clean up, either in that case or if
            # they exit cleanly. Exceptions from cleanup methods are ignored.
            with suppress(Exception):
                await trigger.cleanup()
            if SEND_TRIGGER_END_MARKER:
                self.mark_trigger_end(trigger)

            # unsetting the ctx_indiv_trigger var restores stdout logging
            ctx_indiv_trigger.set(None)
            self.log.info("trigger %s completed", name)

    @staticmethod
    def mark_trigger_end(trigger):
        if not HANDLER_SUPPORTS_TRIGGERER:
            return
        ctx_trigger_end.set(True)
        trigger.log.log(level=100, msg="trigger end")
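
    # The end marker is logged at level 100, above logging.CRITICAL (50), so it passes any
    # handler level filtering and reliably reaches the per-trigger handler, signalling that
    # this trigger's log stream is finished.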

    def update_triggers(self, requested_trigger_ids: set[int]):
        """
        Called from the main thread to request that we update what
        triggers we're running.

        Works out the differences - ones to add, and ones to remove - then
        adds them to the deques so the subthread can actually mutate the running
        trigger set.
        """
        # Note that `triggers` could be mutated by the other thread during this
        # line's execution, but we consider that safe, since there's a strict
        # add -> remove -> never again lifecycle this function is already
        # handling.
        running_trigger_ids = set(self.triggers.keys())
        known_trigger_ids = (
            running_trigger_ids.union(x[0] for x in self.events)
            .union(self.to_cancel)
            .union(x[0] for x in self.to_create)
            .union(trigger[0] for trigger in self.failed_triggers)
        )
        # Work out the two difference sets
        new_trigger_ids = requested_trigger_ids - known_trigger_ids
        cancel_trigger_ids = running_trigger_ids - requested_trigger_ids
        # Bulk-fetch new trigger records
        new_triggers = Trigger.bulk_fetch(new_trigger_ids)
        # Add in new triggers
        for new_id in new_trigger_ids:
            # Check it didn't vanish in the meantime
            if new_id not in new_triggers:
                self.log.warning("Trigger ID %s disappeared before we could start it", new_id)
                continue
            # Resolve trigger record into an actual class instance
            try:
                new_trigger_orm = new_triggers[new_id]
                trigger_class = self.get_trigger_by_classpath(new_trigger_orm.classpath)
            except BaseException as e:
                # Either the trigger code or the path to it is bad. Fail the trigger.
                self.failed_triggers.append((new_id, e))
                continue
            new_trigger_instance = trigger_class(**new_trigger_orm.kwargs)
            self.set_trigger_logging_metadata(new_trigger_orm.task_instance, new_id, new_trigger_instance)
            self.to_create.append((new_id, new_trigger_instance))
        # Enqueue orphaned triggers for cancellation
        for old_id in cancel_trigger_ids:
            self.to_cancel.append(old_id)
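
    # Worked example of the set arithmetic above (IDs are made up): if the DB requests
    # {1, 2, 3}, we are currently running {2, 4}, and trigger 3 is already sitting in
    # `to_create`, then known = {2, 3, 4}, new = {1}, and cancel = {4}.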

    def set_trigger_logging_metadata(self, ti: TaskInstance, trigger_id, trigger):
        """
        Set up logging for triggers.

        We want to ensure that each trigger logs to its own file and that the log messages are not
        propagated to parent loggers.

        :meta private:
        """
        if ti:  # can be None in tests
            ti.is_trigger_log_context = True
        trigger.task_instance = ti
        trigger.triggerer_job_id = self.job_id
        trigger.trigger_id = trigger_id

    def get_trigger_by_classpath(self, classpath: str) -> type[BaseTrigger]:
        """
        Gets a trigger class by its classpath ("path.to.module.classname").

        Uses a cache dictionary to speed up lookups after the first time.
        """
        if classpath not in self.trigger_cache:
            self.trigger_cache[classpath] = import_string(classpath)
        return self.trigger_cache[classpath]