#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import asyncio
import datetime
import importlib
import time
from threading import Thread
from unittest.mock import MagicMock, patch
import aiofiles
import pendulum
import pytest
from airflow.config_templates import airflow_local_settings
from airflow.jobs.job import Job
from airflow.jobs.triggerer_job_runner import TriggererJobRunner, TriggerRunner, setup_queue_listener
from airflow.logging_config import configure_logging
from airflow.models import DagModel, DagRun, TaskInstance, Trigger
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
from airflow.triggers.base import TriggerEvent
from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
from airflow.triggers.testing import FailureTrigger, SuccessTrigger
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import RedirectStdHandler
from airflow.utils.log.trigger_handler import LocalQueueHandler
from airflow.utils.session import create_session
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.types import DagRunType
from tests.core.test_logging_config import reset_logging
from tests.test_utils.db import clear_db_dags, clear_db_runs
pytestmark = pytest.mark.db_test
class TimeDeltaTrigger_(TimeDeltaTrigger):
def __init__(self, delta, filename):
super().__init__(delta=delta)
self.filename = filename
self.delta = delta
async def run(self):
async with aiofiles.open(self.filename, mode="a") as f:
await f.write("hi\n")
async for event in super().run():
yield event
def serialize(self):
return (
"tests.jobs.test_triggerer_job.TimeDeltaTrigger_",
{"delta": self.delta, "filename": self.filename},
)
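# For orientation: a minimal sketch (not used by any test below) of how a
# trigger like the one above is normally reached in production code. A
# deferrable operator defers itself with a trigger, and the triggerer resumes
# the task at ``method_name`` once the trigger yields an event. The class name
# and return value here are illustrative only.
class _ExampleDeferringOperator(BaseOperator):
    def execute(self, context):
        self.defer(
            trigger=TimeDeltaTrigger(datetime.timedelta(seconds=5)),
            method_name="execute_complete",
        )
    def execute_complete(self, context, event=None):
        # Runs in a fresh worker slot after the trigger has fired.
        return event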
@pytest.fixture(autouse=True)
def clean_database():
"""Fixture that cleans the database before and after every test."""
clear_db_runs()
clear_db_dags()
yield # Test runs here
clear_db_dags()
clear_db_runs()
@pytest.fixture
def session():
"""Fixture that provides a SQLAlchemy session"""
with create_session() as session:
yield session
def create_trigger_in_db(session, trigger, operator=None):
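    """
    Persist a DagModel, DagRun, Trigger, and TaskInstance wired together
    (the task instance points at the trigger), and return all four ORM objects.
    """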
dag_model = DagModel(dag_id="test_dag")
dag = DAG(dag_id=dag_model.dag_id, start_date=pendulum.datetime(2023, 1, 1))
run = DagRun(
dag_id=dag_model.dag_id,
run_id="test_run",
execution_date=pendulum.datetime(2023, 1, 1),
run_type=DagRunType.MANUAL,
)
trigger_orm = Trigger.from_object(trigger)
trigger_orm.id = 1
if operator:
operator.dag = dag
else:
operator = BaseOperator(task_id="test_ti", dag=dag)
task_instance = TaskInstance(operator, execution_date=run.execution_date, run_id=run.run_id)
task_instance.trigger_id = trigger_orm.id
session.add(dag_model)
session.add(run)
session.add(trigger_orm)
session.add(task_instance)
session.commit()
return dag_model, run, trigger_orm, task_instance
def test_trigger_logging_sensitive_info(session, caplog):
"""
Checks that when a trigger fires, it doesn't log any sensitive
information from arguments
"""
class SensitiveArgOperator(BaseOperator):
def __init__(self, password, **kwargs):
self.password = password
super().__init__(**kwargs)
# Use a trigger that will immediately succeed
trigger = SuccessTrigger()
op = SensitiveArgOperator(task_id="sensitive_arg_task", password="some_password")
create_trigger_in_db(session, trigger, operator=op)
triggerer_job = Job()
triggerer_job_runner = TriggererJobRunner(triggerer_job)
triggerer_job_runner.load_triggers()
# Now, start TriggerRunner up (and set it as a daemon thread during tests)
    triggerer_job_runner.trigger_runner.daemon = True
triggerer_job_runner.trigger_runner.start()
try:
# Wait for up to 3 seconds for it to fire and appear in the event queue
for _ in range(30):
if triggerer_job_runner.trigger_runner.events:
assert list(triggerer_job_runner.trigger_runner.events) == [(1, TriggerEvent(True))]
break
time.sleep(0.1)
else:
pytest.fail("TriggerRunner never sent the trigger event out")
finally:
# We always have to stop the runner
triggerer_job_runner.trigger_runner.stop = True
triggerer_job_runner.trigger_runner.join(30)
# Since we have now an in-memory process of forwarding the logs to stdout,
# give it more time for the trigger event to write the log.
time.sleep(0.5)
assert "test_dag/test_run/sensitive_arg_task/-1/0 (ID 1) starting" in caplog.text
assert "some_password" not in caplog.text
def test_is_alive():
"""Checks the heartbeat logic"""
# Current time
triggerer_job = Job(heartrate=10, state=State.RUNNING)
assert triggerer_job.is_alive()
# Slightly old, but still fresh
triggerer_job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=20)
assert triggerer_job.is_alive()
# Old enough to fail
triggerer_job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=31)
assert not triggerer_job.is_alive()
# Completed state should not be alive
triggerer_job.state = State.SUCCESS
triggerer_job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=10)
assert not triggerer_job.is_alive(), "Completed jobs even with recent heartbeat should not be alive"
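# A sketch of the liveness rule the assertions above pin down, assuming the
# default grace multiplier of 2.1 (this helper and its names are illustrative,
# not the Job internals): with heartrate=10, a running job counts as alive
# while its last heartbeat is under 21 seconds old, so 20s passes and 31s fails.
def _is_alive_sketch(state, latest_heartbeat, heartrate=10, grace_multiplier=2.1):
    age = (timezone.utcnow() - latest_heartbeat).total_seconds()
    return state == State.RUNNING and age < heartrate * grace_multiplier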
def test_is_needed(session):
"""Checks the triggerer-is-needed logic"""
# No triggers, no need
triggerer_job = Job(heartrate=10, state=State.RUNNING)
triggerer_job_runner = TriggererJobRunner(triggerer_job)
assert triggerer_job_runner.is_needed() is False
# Add a trigger, it's needed
trigger = TimeDeltaTrigger(datetime.timedelta(days=7))
trigger_orm = Trigger.from_object(trigger)
trigger_orm.id = 1
session.add(trigger_orm)
session.commit()
assert triggerer_job_runner.is_needed() is True
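# A sketch of the rule this test exercises (an illustrative helper, not the
# real implementation): the triggerer is needed exactly when any Trigger rows
# exist in the database.
def _is_needed_sketch(session):
    return session.query(Trigger).count() > 0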
def test_capacity_decode():
"""
    Tests that TriggererJobRunner correctly sets capacity to a valid value passed in as a CLI arg,
    handles invalid args, or sets it to the default value (1000) if no arg is passed.
"""
# Positive cases
variants = [
42,
None,
]
    for input_value in variants:
        job = Job()
        job_runner = TriggererJobRunner(job, capacity=input_value)
        assert job_runner.capacity == (input_value if input_value is not None else 1000)
# Negative cases
variants = [
"NAN",
0.5,
-42,
4 / 2, # Resolves to a float, in addition to being just plain weird
]
    for input_value in variants:
        job = Job()
        with pytest.raises(ValueError):
            TriggererJobRunner(job=job, capacity=input_value)
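# A sketch of the validation the cases above imply (an illustrative helper;
# the error message is not the real one): None falls back to the default of
# 1000, positive ints pass through, and everything else raises ValueError.
def _decode_capacity_sketch(capacity):
    if capacity is None:
        return 1000
    if isinstance(capacity, int) and capacity > 0:
        return capacity
    raise ValueError("capacity must be a positive integer")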
def test_trigger_lifecycle(session):
"""
Checks that the triggerer will correctly see a new Trigger in the database
and send it to the trigger runner, and then delete it when it vanishes.
"""
# Use a trigger that will not fire for the lifetime of the test
# (we want to avoid it firing and deleting itself)
trigger = TimeDeltaTrigger(datetime.timedelta(days=7))
dag_model, run, trigger_orm, task_instance = create_trigger_in_db(session, trigger)
# Make a TriggererJobRunner and have it retrieve DB tasks
job = Job()
job_runner = TriggererJobRunner(job)
job_runner.load_triggers()
# Make sure it turned up in TriggerRunner's queue
assert [x for x, y in job_runner.trigger_runner.to_create] == [1]
# Now, start TriggerRunner up (and set it as a daemon thread during tests)
    job_runner.trigger_runner.daemon = True
job_runner.trigger_runner.start()
try:
# Wait for up to 3 seconds for it to appear in the TriggerRunner's storage
for _ in range(30):
if job_runner.trigger_runner.triggers:
assert list(job_runner.trigger_runner.triggers.keys()) == [1]
break
time.sleep(0.1)
else:
pytest.fail("TriggerRunner never created trigger")
# OK, now remove it from the DB
session.delete(trigger_orm)
session.commit()
# Re-load the triggers
job_runner.load_triggers()
# Wait for up to 3 seconds for it to vanish from the TriggerRunner's storage
for _ in range(30):
if not job_runner.trigger_runner.triggers:
break
time.sleep(0.1)
else:
pytest.fail("TriggerRunner never deleted trigger")
finally:
# We always have to stop the runner
job_runner.trigger_runner.stop = True
job_runner.trigger_runner.join(30)
class TestTriggerRunner:
@pytest.mark.asyncio
@patch("airflow.jobs.triggerer_job_runner.TriggerRunner.set_individual_trigger_logging")
    async def test_run_inline_trigger_canceled(self, mock_set_individual_trigger_logging) -> None:
trigger_runner = TriggerRunner()
trigger_runner.triggers = {1: {"task": MagicMock(), "name": "mock_name", "events": 0}}
mock_trigger = MagicMock()
mock_trigger.task_instance.trigger_timeout = None
mock_trigger.run.side_effect = asyncio.CancelledError()
with pytest.raises(asyncio.CancelledError):
await trigger_runner.run_trigger(1, mock_trigger)
@pytest.mark.asyncio
@patch("airflow.jobs.triggerer_job_runner.TriggerRunner.set_individual_trigger_logging")
    async def test_run_inline_trigger_timeout(self, mock_set_individual_trigger_logging, caplog) -> None:
trigger_runner = TriggerRunner()
trigger_runner.triggers = {1: {"task": MagicMock(), "name": "mock_name", "events": 0}}
mock_trigger = MagicMock()
mock_trigger.task_instance.trigger_timeout = timezone.utcnow() - datetime.timedelta(hours=1)
mock_trigger.run.side_effect = asyncio.CancelledError()
with pytest.raises(asyncio.CancelledError):
await trigger_runner.run_trigger(1, mock_trigger)
assert "Trigger cancelled due to timeout" in caplog.text
@patch("airflow.models.trigger.Trigger.bulk_fetch")
@patch(
"airflow.jobs.triggerer_job_runner.TriggerRunner.get_trigger_by_classpath",
return_value=DateTimeTrigger,
)
def test_update_trigger_with_triggerer_argument_change(
        self, mock_get_trigger_by_classpath, mock_bulk_fetch, session, caplog
) -> None:
trigger_runner = TriggerRunner()
mock_trigger_orm = MagicMock()
mock_trigger_orm.kwargs = {"moment": ..., "not_exists_arg": ...}
        mock_bulk_fetch.return_value = {1: mock_trigger_orm}
trigger_runner.update_triggers({1})
assert "Trigger failed" in caplog.text
assert "got an unexpected keyword argument 'not_exists_arg'" in caplog.text
def test_trigger_create_race_condition_18392(session, tmp_path):
"""
    This verifies the resolution of the race condition documented in GitHub issue #18392.
    Triggers are queued for creation by TriggererJobRunner.load_triggers.
    There was a race condition where multiple triggers would be created unnecessarily.
    What happened was that the runner completed the trigger and purged it from the "running" set.
    Then job.load_triggers was called, and the trigger looked like it was not running but should be,
    so it was queued again.
    The scenario is as follows:
    1. job.load_triggers (trigger now queued)
    2. runner.create_triggers (trigger now running)
    3. job.handle_events (trigger still appears running so state not updated in DB)
    4. runner.cleanup_finished_triggers (trigger completed at this point; trigger removed from "running" set)
    5. job.load_triggers (trigger not running, but also not purged from DB, so it is queued again)
    6. runner.create_triggers (trigger created again)
    This test verifies that under this scenario only one trigger is created.
"""
path = tmp_path / "test_trigger_bad_respawn.txt"
class TriggerRunner_(TriggerRunner):
"""We do some waiting for main thread looping"""
async def wait_for_job_method_count(self, method, count):
for _ in range(30):
await asyncio.sleep(0.1)
if getattr(self, f"{method}_count", 0) >= count:
break
else:
pytest.fail(f"did not observe count {count} in job method {method}")
async def create_triggers(self):
"""
On first run, wait for job.load_triggers to make sure they are queued
"""
if getattr(self, "loop_count", 0) == 0:
await self.wait_for_job_method_count("load_triggers", 1)
await super().create_triggers()
self.loop_count = getattr(self, "loop_count", 0) + 1
async def cleanup_finished_triggers(self):
"""On loop 1, make sure that job.handle_events was already called"""
if self.loop_count == 1:
await self.wait_for_job_method_count("handle_events", 1)
await super().cleanup_finished_triggers()
class TriggererJob_(TriggererJobRunner):
"""We do some waiting for runner thread looping (and track calls in job thread)"""
        def wait_for_runner_loop(self, runner_loop_count):
            for _ in range(30):
                time.sleep(0.1)
                if getattr(self.trigger_runner, "loop_count", 0) >= runner_loop_count:
                    break
            else:
                pytest.fail(f"did not observe {runner_loop_count} loops in the runner thread")
def load_triggers(self):
"""On second run, make sure that runner has called create_triggers in its second loop"""
super().load_triggers()
self.trigger_runner.load_triggers_count = (
getattr(self.trigger_runner, "load_triggers_count", 0) + 1
)
if self.trigger_runner.load_triggers_count == 2:
self.wait_for_runner_loop(runner_loop_count=2)
def handle_events(self):
super().handle_events()
self.trigger_runner.handle_events_count = (
getattr(self.trigger_runner, "handle_events_count", 0) + 1
)
trigger = TimeDeltaTrigger_(delta=datetime.timedelta(microseconds=1), filename=path.as_posix())
trigger_orm = Trigger.from_object(trigger)
trigger_orm.id = 1
session.add(trigger_orm)
dag = DagModel(dag_id="test-dag")
dag_run = DagRun(dag.dag_id, run_id="abc", run_type="none")
ti = TaskInstance(PythonOperator(task_id="dummy-task", python_callable=print), run_id=dag_run.run_id)
ti.dag_id = dag.dag_id
ti.trigger_id = 1
session.add(dag)
session.add(dag_run)
session.add(ti)
session.commit()
job = Job()
job_runner = TriggererJob_(job)
job_runner.trigger_runner = TriggerRunner_()
thread = Thread(target=job_runner._execute)
thread.start()
try:
for _ in range(40):
time.sleep(0.1)
# ready to evaluate after 2 loops
if getattr(job_runner.trigger_runner, "loop_count", 0) >= 2:
break
else:
pytest.fail("did not observe 2 loops in the runner thread")
finally:
job_runner.trigger_runner.stop = True
job_runner.trigger_runner.join(30)
thread.join()
instances = path.read_text().splitlines()
assert len(instances) == 1
def test_trigger_from_dead_triggerer(session, create_task_instance):
"""
Checks that the triggerer will correctly claim a Trigger that is assigned to a
triggerer that does not exist.
"""
# Use a trigger that has an invalid triggerer_id
trigger = TimeDeltaTrigger(datetime.timedelta(days=7))
trigger_orm = Trigger.from_object(trigger)
trigger_orm.id = 1
trigger_orm.triggerer_id = 999 # Non-existent triggerer
session.add(trigger_orm)
ti_orm = create_task_instance(
task_id="ti_orm",
execution_date=timezone.utcnow(),
run_id="orm_run_id",
)
ti_orm.trigger_id = trigger_orm.id
    session.add(ti_orm)
session.commit()
# Make a TriggererJobRunner and have it retrieve DB tasks
job = Job()
job_runner = TriggererJobRunner(job)
job_runner.load_triggers()
# Make sure it turned up in TriggerRunner's queue
assert [x for x, y in job_runner.trigger_runner.to_create] == [1]
def test_trigger_from_expired_triggerer(session, create_task_instance):
"""
Checks that the triggerer will correctly claim a Trigger that is assigned to a
triggerer that has an expired heartbeat.
"""
# Use a trigger assigned to the expired triggerer
trigger = TimeDeltaTrigger(datetime.timedelta(days=7))
trigger_orm = Trigger.from_object(trigger)
trigger_orm.id = 1
trigger_orm.triggerer_id = 42
session.add(trigger_orm)
ti_orm = create_task_instance(
task_id="ti_orm",
execution_date=timezone.utcnow(),
run_id="orm_run_id",
)
ti_orm.trigger_id = trigger_orm.id
    session.add(ti_orm)
# Use a TriggererJobRunner with an expired heartbeat
triggerer_job_orm = Job(TriggererJobRunner.job_type)
triggerer_job_orm.id = 42
triggerer_job_orm.start_date = timezone.utcnow() - datetime.timedelta(hours=1)
triggerer_job_orm.end_date = None
triggerer_job_orm.latest_heartbeat = timezone.utcnow() - datetime.timedelta(hours=1)
session.add(triggerer_job_orm)
session.commit()
# Make a TriggererJobRunner and have it retrieve DB tasks
job = Job(TriggererJobRunner.job_type)
job_runner = TriggererJobRunner(job)
job_runner.load_triggers()
# Make sure it turned up in TriggerRunner's queue
assert [x for x, y in job_runner.trigger_runner.to_create] == [1]
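# Both "dead triggerer" tests above exercise the same adoption rule, summarized
# here as an illustrative helper (not the real query): a trigger is up for
# claiming when it has no triggerer assigned, or its assigned triggerer's job
# row is missing or has a stale heartbeat.
def _trigger_claimable_sketch(trigger_orm, alive_triggerer_ids):
    return trigger_orm.triggerer_id is None or trigger_orm.triggerer_id not in alive_triggerer_ids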
def test_trigger_runner_exception_stops_triggerer(session):
"""
    Checks that if an exception occurs when creating triggers, the triggerer
    process stops.
"""
class MockTriggerException(Exception):
pass
class TriggerRunner_(TriggerRunner):
async def create_triggers(self):
raise MockTriggerException("Trigger creation failed")
# Use a trigger that will immediately succeed
trigger = SuccessTrigger()
create_trigger_in_db(session, trigger)
# Make a TriggererJobRunner and have it retrieve DB tasks
job = Job()
job_runner = TriggererJobRunner(job)
job_runner.trigger_runner = TriggerRunner_()
thread = Thread(target=job_runner._execute)
thread.start()
# Wait 4 seconds for the triggerer to stop
try:
for _ in range(40):
time.sleep(0.1)
if not thread.is_alive():
break
else:
pytest.fail("TriggererJobRunner did not stop after exception in TriggerRunner")
if not job_runner.trigger_runner.stop:
pytest.fail("TriggerRunner not marked as stopped after exception in TriggerRunner")
finally:
job_runner.trigger_runner.stop = True
job_runner.trigger_runner.join(30)
thread.join()
def test_trigger_firing(session):
"""
Checks that when a trigger fires, it correctly makes it into the
event queue.
"""
# Use a trigger that will immediately succeed
trigger = SuccessTrigger()
create_trigger_in_db(session, trigger)
# Make a TriggererJobRunner and have it retrieve DB tasks
job = Job()
job_runner = TriggererJobRunner(job)
job_runner.load_triggers()
# Now, start TriggerRunner up (and set it as a daemon thread during tests)
    job_runner.trigger_runner.daemon = True
job_runner.trigger_runner.start()
try:
# Wait for up to 3 seconds for it to fire and appear in the event queue
for _ in range(30):
if job_runner.trigger_runner.events:
assert list(job_runner.trigger_runner.events) == [(1, TriggerEvent(True))]
break
time.sleep(0.1)
else:
pytest.fail("TriggerRunner never sent the trigger event out")
finally:
# We always have to stop the runner
job_runner.trigger_runner.stop = True
job_runner.trigger_runner.join(30)
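def test_trigger_event_equality():
    # Small companion sketch: the event-queue assertion in test_trigger_firing
    # relies on TriggerEvent instances comparing equal by payload.
    assert TriggerEvent(True) == TriggerEvent(True)
    assert TriggerEvent(True) != TriggerEvent(False)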
def test_trigger_failing(session):
"""
Checks that when a trigger fails, it correctly makes it into the
failure queue.
"""
# Use a trigger that will immediately fail
trigger = FailureTrigger()
create_trigger_in_db(session, trigger)
# Make a TriggererJobRunner and have it retrieve DB tasks
job = Job()
job_runner = TriggererJobRunner(job)
job_runner.load_triggers()
# Now, start TriggerRunner up (and set it as a daemon thread during tests)
    job_runner.trigger_runner.daemon = True
job_runner.trigger_runner.start()
try:
# Wait for up to 3 seconds for it to fire and appear in the event queue
for _ in range(30):
if job_runner.trigger_runner.failed_triggers:
assert len(job_runner.trigger_runner.failed_triggers) == 1
trigger_id, exc = next(iter(job_runner.trigger_runner.failed_triggers))
assert trigger_id == 1
assert isinstance(exc, ValueError)
assert exc.args[0] == "Deliberate trigger failure"
break
time.sleep(0.1)
else:
pytest.fail("TriggerRunner never marked the trigger as failed")
finally:
# We always have to stop the runner
job_runner.trigger_runner.stop = True
job_runner.trigger_runner.join(30)
def test_trigger_cleanup(session):
"""
Checks that the triggerer will correctly clean up triggers that do not
have any task instances depending on them.
"""
# Use a trigger that will not fire for the lifetime of the test
# (we want to avoid it firing and deleting itself)
trigger = TimeDeltaTrigger(datetime.timedelta(days=7))
trigger_orm = Trigger.from_object(trigger)
trigger_orm.id = 1
session.add(trigger_orm)
session.commit()
# Trigger the cleanup code
Trigger.clean_unused(session=session)
session.commit()
# Make sure it's gone
assert session.query(Trigger).count() == 0
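def test_trigger_cleanup_keeps_in_use_trigger(session):
    # Companion sketch to the cleanup test above, assuming clean_unused()
    # keeps triggers that a deferred task instance still depends on.
    trigger = TimeDeltaTrigger(datetime.timedelta(days=7))
    *_, task_instance = create_trigger_in_db(session, trigger)
    task_instance.state = TaskInstanceState.DEFERRED
    session.commit()
    Trigger.clean_unused(session=session)
    session.commit()
    assert session.query(Trigger).count() == 1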
def test_invalid_trigger(session, dag_maker):
"""
Checks that the triggerer will correctly fail task instances that depend on
triggers that can't even be loaded.
"""
# Create a totally invalid trigger
trigger_orm = Trigger(classpath="fake.classpath", kwargs={})
trigger_orm.id = 1
session.add(trigger_orm)
session.commit()
# Create the test DAG and task
with dag_maker(dag_id="test_invalid_trigger", session=session):
EmptyOperator(task_id="dummy1")
dr = dag_maker.create_dagrun()
task_instance = dr.task_instances[0]
# Make a task instance based on that and tie it to the trigger
task_instance.state = TaskInstanceState.DEFERRED
task_instance.trigger_id = 1
session.commit()
# Make a TriggererJobRunner and have it retrieve DB tasks
job = Job()
job_runner = TriggererJobRunner(job)
job_runner.load_triggers()
# Make sure it turned up in the failed queue
assert len(job_runner.trigger_runner.failed_triggers) == 1
# Run the failed trigger handler
job_runner.handle_failed_triggers()
# Make sure it marked the task instance as failed (which is actually the
# scheduled state with a payload to make it fail)
task_instance.refresh_from_db()
assert task_instance.state == TaskInstanceState.SCHEDULED
assert task_instance.next_method == "__fail__"
assert task_instance.next_kwargs["error"] == "Trigger failure"
assert task_instance.next_kwargs["traceback"][-1] == "ModuleNotFoundError: No module named 'fake'\n"
@pytest.mark.parametrize("should_wrap", (True, False))
@patch("airflow.jobs.triggerer_job_runner.configure_trigger_log_handler")
def test_handler_config_respects_donot_wrap(mock_configure, should_wrap):
from airflow.jobs import triggerer_job_runner
triggerer_job_runner.DISABLE_WRAPPER = not should_wrap
job = Job()
TriggererJobRunner(job=job)
if should_wrap:
mock_configure.assert_called()
else:
mock_configure.assert_not_called()
@patch("airflow.jobs.triggerer_job_runner.setup_queue_listener")
def test_triggerer_job_always_creates_listener(mock_setup):
mock_setup.assert_not_called()
job = Job()
TriggererJobRunner(job=job)
mock_setup.assert_called()
def test_queue_listener():
"""
    When the listener setup function is called, the root logger's handlers should be
    moved onto a queue listener and replaced with a queue handler.
"""
reset_logging()
importlib.reload(airflow_local_settings)
configure_logging()
def non_pytest_handlers(val):
return [h for h in val if "pytest" not in h.__module__]
import logging
log = logging.getLogger()
handlers = non_pytest_handlers(log.handlers)
assert len(handlers) == 1
handler = handlers[0]
assert handler.__class__ == RedirectStdHandler
listener = setup_queue_listener()
assert handler not in non_pytest_handlers(log.handlers)
qh = log.handlers[-1]
assert qh.__class__ == LocalQueueHandler
assert qh.queue == listener.queue
listener.stop()
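def _queue_listener_pattern_sketch():
    # For orientation only: the generic stdlib pattern that the behavior
    # asserted above corresponds to (a sketch, not Airflow's implementation).
    # The root logger's handlers move behind a QueueListener, and the root
    # logger keeps a single queue handler that forwards records to them.
    import logging
    import logging.handlers
    import queue
    q = queue.SimpleQueue()
    root = logging.getLogger()
    listener = logging.handlers.QueueListener(q, *root.handlers, respect_handler_level=True)
    root.handlers = [logging.handlers.QueueHandler(q)]
    listener.start()
    return listener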