blob: dbf62883a50bb23006ff81b85e45191dd5e2fc16 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from datetime import timedelta
from typing import TYPE_CHECKING
from unittest.mock import Mock, patch
import pytest
import time_machine
from airflow.exceptions import (
AirflowException,
AirflowFailException,
AirflowRescheduleException,
AirflowSensorTimeout,
AirflowSkipException,
AirflowTaskTimeout,
)
from airflow.executors.debug_executor import DebugExecutor
from airflow.executors.executor_constants import (
CELERY_EXECUTOR,
CELERY_KUBERNETES_EXECUTOR,
DEBUG_EXECUTOR,
KUBERNETES_EXECUTOR,
LOCAL_EXECUTOR,
LOCAL_KUBERNETES_EXECUTOR,
SEQUENTIAL_EXECUTOR,
)
from airflow.executors.local_executor import LocalExecutor
from airflow.executors.sequential_executor import SequentialExecutor
from airflow.models import TaskInstance, TaskReschedule
from airflow.models.xcom import XCom
from airflow.operators.empty import EmptyOperator
from airflow.providers.celery.executors.celery_executor import CeleryExecutor
from airflow.providers.celery.executors.celery_kubernetes_executor import CeleryKubernetesExecutor
from airflow.providers.cncf.kubernetes.executors.kubernetes_executor import KubernetesExecutor
from airflow.providers.cncf.kubernetes.executors.local_kubernetes_executor import LocalKubernetesExecutor
from airflow.sensors.base import BaseSensorOperator, PokeReturnValue, poke_mode_only
from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.timezone import datetime
from tests.test_utils import db
# Every test in this module touches the Airflow metadata database.
pytestmark = pytest.mark.db_test

if TYPE_CHECKING:
    from airflow.utils.context import Context

# Shared constants for the DAG / task ids used by the fixtures below.
DEFAULT_DATE = datetime(2015, 1, 1)
TEST_DAG_ID = "unit_test_dag"
DUMMY_OP = "dummy_op"
SENSOR_OP = "sensor_op"
DEV_NULL = "dev/null"
@pytest.fixture
def task_reschedules_for_ti():
    """Fixture returning a callable that lists TaskReschedule rows for a TI.

    Rows come back in ascending reschedule-date order.
    """

    def _fetch(ti):
        stmt = TaskReschedule.stmt_for_task_instance(ti=ti, descending=False)
        with create_session() as session:
            return session.scalars(stmt).all()

    return _fetch
class DummySensor(BaseSensorOperator):
    """Test sensor whose poke() reports a canned, constant result."""

    def __init__(self, return_value=False, **kwargs):
        super().__init__(**kwargs)
        self.return_value = return_value

    def poke(self, context: Context):
        # Report the result configured at construction time.
        return self.return_value
class DummyAsyncSensor(BaseSensorOperator):
    """Test sensor whose deferred-resume callback always raises."""

    def __init__(self, return_value=False, **kwargs):
        super().__init__(**kwargs)
        self.return_value = return_value

    def execute_complete(self, context, event=None):
        # Simulates a failure when the sensor resumes from deferral.
        raise AirflowException("Should be skipped")
class DummySensorWithXcomValue(BaseSensorOperator):
    """Test sensor returning a PokeReturnValue that carries an XCom payload."""

    def __init__(self, return_value=False, xcom_value=None, **kwargs):
        super().__init__(**kwargs)
        self.xcom_value = xcom_value
        self.return_value = return_value

    def poke(self, context: Context):
        # Wrap both the done-flag and the XCom payload in one PokeReturnValue.
        return PokeReturnValue(self.return_value, self.xcom_value)
class TestBaseSensor:
@staticmethod
def clean_db():
    """Wipe DB state (dag runs, task reschedules, XComs) touched by these tests."""
    for cleaner in (db.clear_db_runs, db.clear_db_task_reschedule, db.clear_db_xcom):
        cleaner()
@pytest.fixture(autouse=True)
def _auto_clean(self, dag_maker):
    """Ensure a pristine database before and after every test (autouse)."""
    self.clean_db()
    yield
    self.clean_db()
@pytest.fixture
def make_sensor(self, dag_maker):
    """Return a factory that builds a sensor (plus a downstream EmptyOperator)
    inside a fresh DAG and returns the sensor with its DagRun."""

    def _make_sensor(return_value, task_id=SENSOR_OP, **kwargs):
        # Fast defaults so tests never wait unless they opt in explicitly.
        kwargs.setdefault("poke_interval", 0)
        kwargs.setdefault("timeout", 0)
        with dag_maker(TEST_DAG_ID):
            if "xcom_value" in kwargs:
                sensor = DummySensorWithXcomValue(task_id=task_id, return_value=return_value, **kwargs)
            else:
                sensor = DummySensor(task_id=task_id, return_value=return_value, **kwargs)
            dummy_op = EmptyOperator(task_id=DUMMY_OP)
            sensor >> dummy_op
        return sensor, dag_maker.create_dagrun()

    return _make_sensor
@classmethod
def _run(cls, task, **kwargs):
    """Run *task* for DEFAULT_DATE, ignoring any pre-existing TI state."""
    task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True, **kwargs)
def test_ok(self, make_sensor):
    """A poke that immediately returns True marks the sensor SUCCESS."""
    sensor, dr = make_sensor(True)
    self._run(sensor)
    tis = dr.get_task_instances()
    assert len(tis) == 2
    states = {ti.task_id: ti.state for ti in tis}
    assert states[SENSOR_OP] == State.SUCCESS
    assert states[DUMMY_OP] == State.NONE
def test_fail(self, make_sensor):
    """A never-True poke with timeout=0 raises AirflowSensorTimeout and fails."""
    sensor, dr = make_sensor(False)
    with pytest.raises(AirflowSensorTimeout):
        self._run(sensor)
    tis = dr.get_task_instances()
    assert len(tis) == 2
    states = {ti.task_id: ti.state for ti in tis}
    assert states[SENSOR_OP] == State.FAILED
    assert states[DUMMY_OP] == State.NONE
def test_soft_fail(self, make_sensor):
    """With soft_fail=True a timeout skips the sensor instead of failing it."""
    sensor, dr = make_sensor(False, soft_fail=True)
    self._run(sensor)
    tis = dr.get_task_instances()
    assert len(tis) == 2
    states = {ti.task_id: ti.state for ti in tis}
    assert states[SENSOR_OP] == State.SKIPPED
    assert states[DUMMY_OP] == State.NONE
@pytest.mark.parametrize(
    "exception_cls",
    (
        AirflowSensorTimeout,
        AirflowTaskTimeout,
        AirflowFailException,
        Exception,
    ),
)
def test_soft_fail_with_non_skip_exception(self, make_sensor, exception_cls):
    """soft_fail skips the sensor even when poke raises a non-skip exception."""
    sensor, dr = make_sensor(False, soft_fail=True)
    sensor.poke = Mock(side_effect=[exception_cls(None)])
    self._run(sensor)
    tis = dr.get_task_instances()
    assert len(tis) == 2
    states = {ti.task_id: ti.state for ti in tis}
    assert states[SENSOR_OP] == State.SKIPPED
    assert states[DUMMY_OP] == State.NONE
def test_soft_fail_with_retries(self, make_sensor):
    """soft_fail skips on timeout even when retries are still available."""
    sensor, dr = make_sensor(
        return_value=False, soft_fail=True, retries=1, retry_delay=timedelta(milliseconds=1)
    )
    # first run times out and task instance is skipped
    self._run(sensor)
    tis = dr.get_task_instances()
    assert len(tis) == 2
    states = {ti.task_id: ti.state for ti in tis}
    assert states[SENSOR_OP] == State.SKIPPED
    assert states[DUMMY_OP] == State.NONE
def test_ok_with_reschedule(self, make_sensor, time_machine, task_reschedules_for_ti):
    """Reschedule mode: every False poke records a TaskReschedule row and sets
    the TI to UP_FOR_RESCHEDULE; a True poke finally succeeds while keeping
    the original start_date."""
    sensor, dr = make_sensor(return_value=None, poke_interval=10, timeout=25, mode="reschedule")
    sensor.poke = Mock(side_effect=[False, False, True])
    # first poke returns False and task is re-scheduled
    date1 = timezone.utcnow()
    time_machine.move_to(date1, tick=False)
    self._run(sensor)
    tis = dr.get_task_instances()
    assert len(tis) == 2
    for ti in tis:
        if ti.task_id == SENSOR_OP:
            # verify task is re-scheduled, i.e. state set to UP_FOR_RESCHEDULE
            assert ti.state == State.UP_FOR_RESCHEDULE
            # verify task start date is the initial one
            assert ti.start_date == date1
            # verify one row in task_reschedule table
            task_reschedules = task_reschedules_for_ti(ti)
            assert len(task_reschedules) == 1
            assert task_reschedules[0].start_date == date1
            assert task_reschedules[0].reschedule_date == date1 + timedelta(seconds=sensor.poke_interval)
        if ti.task_id == DUMMY_OP:
            assert ti.state == State.NONE
    # second poke returns False and task is re-scheduled
    time_machine.coordinates.shift(sensor.poke_interval)
    date2 = date1 + timedelta(seconds=sensor.poke_interval)
    self._run(sensor)
    tis = dr.get_task_instances()
    assert len(tis) == 2
    for ti in tis:
        if ti.task_id == SENSOR_OP:
            # verify task is re-scheduled, i.e. state set to UP_FOR_RESCHEDULE
            assert ti.state == State.UP_FOR_RESCHEDULE
            # verify task start date is the initial one
            assert ti.start_date == date1
            # verify two rows in task_reschedule table
            task_reschedules = task_reschedules_for_ti(ti)
            assert len(task_reschedules) == 2
            assert task_reschedules[1].start_date == date2
            assert task_reschedules[1].reschedule_date == date2 + timedelta(seconds=sensor.poke_interval)
        if ti.task_id == DUMMY_OP:
            assert ti.state == State.NONE
    # third poke returns True and task succeeds
    time_machine.coordinates.shift(sensor.poke_interval)
    self._run(sensor)
    tis = dr.get_task_instances()
    assert len(tis) == 2
    for ti in tis:
        if ti.task_id == SENSOR_OP:
            assert ti.state == State.SUCCESS
            # verify task start date is the initial one
            assert ti.start_date == date1
        if ti.task_id == DUMMY_OP:
            assert ti.state == State.NONE
def test_fail_with_reschedule(self, make_sensor, time_machine, session):
    """Reschedule-mode sensor that never succeeds fails once timeout elapses."""
    sensor, dr = make_sensor(return_value=False, poke_interval=10, timeout=5, mode="reschedule")

    def _get_tis():
        # Yield (sensor TI, dummy TI) freshly loaded via the test session.
        tis = dr.get_task_instances(session=session)
        assert len(tis) == 2
        yield next(x for x in tis if x.task_id == SENSOR_OP)
        yield next(x for x in tis if x.task_id == DUMMY_OP)

    # first poke returns False and task is re-scheduled
    date1 = timezone.utcnow()
    time_machine.move_to(date1, tick=False)
    sensor_ti, dummy_ti = _get_tis()
    assert dummy_ti.state == State.NONE
    assert sensor_ti.state == State.NONE
    # ordinarily the scheduler does this
    sensor_ti.state = State.SCHEDULED
    sensor_ti.try_number += 1  # first TI run
    session.commit()
    self._run(sensor, session=session)
    sensor_ti, dummy_ti = _get_tis()
    assert sensor_ti.state == State.UP_FOR_RESCHEDULE
    assert dummy_ti.state == State.NONE
    # second poke returns False, timeout occurs
    time_machine.coordinates.shift(sensor.poke_interval)
    with pytest.raises(AirflowSensorTimeout):
        self._run(sensor, session=session)
    sensor_ti, dummy_ti = _get_tis()
    assert sensor_ti.state == State.FAILED
    assert dummy_ti.state == State.NONE
def test_soft_fail_with_reschedule(self, make_sensor, time_machine, session):
    """soft_fail + reschedule mode: the sensor is SKIPPED (not FAILED) on timeout."""
    sensor, dr = make_sensor(
        return_value=False, poke_interval=10, timeout=5, soft_fail=True, mode="reschedule"
    )

    def _get_tis():
        # Yield (sensor TI, dummy TI) freshly loaded via the test session.
        tis = dr.get_task_instances(session=session)
        assert len(tis) == 2
        yield next(x for x in tis if x.task_id == SENSOR_OP)
        yield next(x for x in tis if x.task_id == DUMMY_OP)

    # first poke returns False and task is re-scheduled
    date1 = timezone.utcnow()
    time_machine.move_to(date1, tick=False)
    sensor_ti, dummy_ti = _get_tis()
    sensor_ti.try_number += 1  # first TI run of this test
    session.commit()
    self._run(sensor)
    sensor_ti, dummy_ti = _get_tis()
    assert sensor_ti.state == State.UP_FOR_RESCHEDULE
    assert dummy_ti.state == State.NONE
    # second poke returns False, timeout occurs
    time_machine.coordinates.shift(sensor.poke_interval)
    self._run(sensor)
    sensor_ti, dummy_ti = _get_tis()
    assert sensor_ti.state == State.SKIPPED
    assert dummy_ti.state == State.NONE
def test_ok_with_reschedule_and_retry(self, make_sensor, time_machine, task_reschedules_for_ti, session):
    """After a timeout failure, clearing the sensor starts a new try whose
    TaskReschedule rows are tracked under the new try_number."""
    sensor, dr = make_sensor(
        return_value=None,
        poke_interval=10,
        timeout=5,
        retries=1,
        retry_delay=timedelta(seconds=10),
        mode="reschedule",
    )

    def _get_tis():
        # Yield (sensor TI, dummy TI) freshly loaded via the test session.
        tis = dr.get_task_instances(session=session)
        assert len(tis) == 2
        yield next(x for x in tis if x.task_id == SENSOR_OP)
        yield next(x for x in tis if x.task_id == DUMMY_OP)

    sensor.poke = Mock(side_effect=[False, False, False, True])
    # first poke returns False and task is re-scheduled
    date1 = timezone.utcnow()
    time_machine.move_to(date1, tick=False)
    sensor_ti, dummy_ti = _get_tis()
    sensor_ti.try_number += 1  # first TI run
    session.commit()
    self._run(sensor)
    sensor_ti, dummy_ti = _get_tis()
    assert sensor_ti.state == State.UP_FOR_RESCHEDULE
    # verify one row in task_reschedule table
    task_reschedules = task_reschedules_for_ti(sensor_ti)
    assert len(task_reschedules) == 1
    assert task_reschedules[0].start_date == date1
    assert task_reschedules[0].reschedule_date == date1 + timedelta(seconds=sensor.poke_interval)
    assert task_reschedules[0].try_number == 1
    assert dummy_ti.state == State.NONE
    # second poke timesout and task instance is failed
    time_machine.coordinates.shift(sensor.poke_interval)
    with pytest.raises(AirflowSensorTimeout):
        self._run(sensor)
    sensor_ti, dummy_ti = _get_tis()
    assert sensor_ti.state == State.FAILED
    assert dummy_ti.state == State.NONE
    # Task is cleared
    sensor.clear()
    sensor_ti, dummy_ti = _get_tis()
    assert sensor_ti.try_number == 1
    assert sensor_ti.max_tries == 2
    sensor_ti, dummy_ti = _get_tis()
    sensor_ti.try_number += 1  # second TI run
    session.commit()
    # third poke returns False and task is rescheduled again
    date3 = date1 + timedelta(seconds=sensor.poke_interval) * 2 + sensor.retry_delay
    time_machine.coordinates.shift(sensor.poke_interval + sensor.retry_delay.total_seconds())
    self._run(sensor)
    sensor_ti, dummy_ti = _get_tis()
    assert sensor_ti.state == State.UP_FOR_RESCHEDULE
    # verify one row in task_reschedule table (reschedules of the old try are gone)
    task_reschedules = task_reschedules_for_ti(sensor_ti)
    assert len(task_reschedules) == 1
    assert task_reschedules[0].start_date == date3
    assert task_reschedules[0].reschedule_date == date3 + timedelta(seconds=sensor.poke_interval)
    assert task_reschedules[0].try_number == 2
    assert dummy_ti.state == State.NONE
    # fourth poke return True and task succeeds
    time_machine.coordinates.shift(sensor.poke_interval)
    self._run(sensor)
    sensor_ti, dummy_ti = _get_tis()
    assert sensor_ti.state == State.SUCCESS
    assert dummy_ti.state == State.NONE
@pytest.mark.parametrize("mode", ["poke", "reschedule"])
def test_should_include_ready_to_reschedule_dep(self, mode):
    """ReadyToRescheduleDep must be part of the sensor's deps in both modes."""
    sensor = DummySensor(task_id="a", return_value=True, mode=mode)
    assert ReadyToRescheduleDep() in sensor.deps
def test_invalid_mode(self):
    """An unknown mode value is rejected at construction time."""
    with pytest.raises(AirflowException):
        DummySensor(task_id="a", mode="foo")
def test_ok_with_custom_reschedule_exception(self, make_sensor, task_reschedules_for_ti):
    """AirflowRescheduleException lets poke() choose the exact reschedule date."""
    sensor, dr = make_sensor(return_value=None, mode="reschedule")
    date1 = timezone.utcnow()
    date2 = date1 + timedelta(seconds=60)
    date3 = date1 + timedelta(seconds=120)
    sensor.poke = Mock(
        side_effect=[AirflowRescheduleException(date2), AirflowRescheduleException(date3), True]
    )
    # first poke returns False and task is re-scheduled
    with time_machine.travel(date1, tick=False):
        self._run(sensor)
    tis = dr.get_task_instances()
    assert len(tis) == 2
    for ti in tis:
        if ti.task_id == SENSOR_OP:
            # verify task is re-scheduled, i.e. state set to UP_FOR_RESCHEDULE
            assert ti.state == State.UP_FOR_RESCHEDULE
            # verify one row in task_reschedule table
            task_reschedules = task_reschedules_for_ti(ti)
            assert len(task_reschedules) == 1
            assert task_reschedules[0].start_date == date1
            assert task_reschedules[0].reschedule_date == date2
        if ti.task_id == DUMMY_OP:
            assert ti.state == State.NONE
    # second poke returns False and task is re-scheduled
    with time_machine.travel(date2, tick=False):
        self._run(sensor)
    tis = dr.get_task_instances()
    assert len(tis) == 2
    for ti in tis:
        if ti.task_id == SENSOR_OP:
            # verify task is re-scheduled, i.e. state set to UP_FOR_RESCHEDULE
            assert ti.state == State.UP_FOR_RESCHEDULE
            # verify two rows in task_reschedule table
            task_reschedules = task_reschedules_for_ti(ti)
            assert len(task_reschedules) == 2
            assert task_reschedules[1].start_date == date2
            assert task_reschedules[1].reschedule_date == date3
        if ti.task_id == DUMMY_OP:
            assert ti.state == State.NONE
    # third poke returns True and task succeeds
    with time_machine.travel(date3, tick=False):
        self._run(sensor)
    tis = dr.get_task_instances()
    assert len(tis) == 2
    for ti in tis:
        if ti.task_id == SENSOR_OP:
            assert ti.state == State.SUCCESS
        if ti.task_id == DUMMY_OP:
            assert ti.state == State.NONE
def test_reschedule_with_test_mode(self, make_sensor, task_reschedules_for_ti):
    """In test mode a reschedule request neither changes state nor is recorded."""
    sensor, dr = make_sensor(return_value=None, poke_interval=10, timeout=25, mode="reschedule")
    sensor.poke = Mock(side_effect=[False])
    # poke returns False and AirflowRescheduleException is raised
    date1 = timezone.utcnow()
    with time_machine.travel(date1, tick=False):
        self._run(sensor, test_mode=True)
    tis = dr.get_task_instances()
    assert len(tis) == 2
    for ti in tis:
        if ti.task_id == SENSOR_OP:
            # in test mode state is not modified
            assert ti.state == State.NONE
            # in test mode no reschedule request is recorded
            task_reschedules = task_reschedules_for_ti(ti)
            assert len(task_reschedules) == 0
        if ti.task_id == DUMMY_OP:
            assert ti.state == State.NONE
def test_sensor_with_invalid_poke_interval(self):
    """poke_interval must be a non-negative number; anything else is rejected."""
    for task_id, bad_interval in [
        ("test_sensor_task_1", -10),  # negative
        ("test_sensor_task_2", "abcd"),  # not a number
    ]:
        with pytest.raises(AirflowException):
            DummySensor(
                task_id=task_id,
                return_value=None,
                poke_interval=bad_interval,
                timeout=25,
            )
    # A positive numeric interval is accepted without raising.
    DummySensor(
        task_id="test_sensor_task_3", return_value=None, poke_interval=10, timeout=25
    )
def test_sensor_with_invalid_timeout(self):
    """timeout must be a non-negative number; anything else is rejected."""
    for task_id, bad_timeout in [
        ("test_sensor_task_1", -25),  # negative
        ("test_sensor_task_2", "abcd"),  # not a number
    ]:
        with pytest.raises(AirflowException):
            DummySensor(task_id=task_id, return_value=None, poke_interval=10, timeout=bad_timeout)
    # A positive numeric timeout is accepted without raising.
    DummySensor(
        task_id="test_sensor_task_3", return_value=None, poke_interval=10, timeout=25
    )
def test_sensor_with_exponential_backoff_off(self):
    """With exponential_backoff disabled, the poke interval stays constant."""
    sensor = DummySensor(
        task_id=SENSOR_OP, return_value=None, poke_interval=5, timeout=60, exponential_backoff=False
    )
    started_at = timezone.utcnow() - timedelta(seconds=10)

    def run_duration():
        # Bug fix: utcnow must be *called* -- subtracting a datetime from the
        # function object itself raises TypeError whenever this callback runs.
        return (timezone.utcnow() - started_at).total_seconds()

    assert sensor._get_next_poke_interval(started_at, run_duration, 1) == sensor.poke_interval
    assert sensor._get_next_poke_interval(started_at, run_duration, 2) == sensor.poke_interval
def test_sensor_with_exponential_backoff_on(self):
    """With exponential_backoff enabled, successive intervals grow."""
    sensor = DummySensor(
        task_id=SENSOR_OP, return_value=None, poke_interval=5, timeout=60, exponential_backoff=True
    )
    with patch("airflow.utils.timezone.utcnow") as mock_utctime:
        mock_utctime.return_value = DEFAULT_DATE
        started_at = timezone.utcnow() - timedelta(seconds=10)

        def run_duration():
            # Bug fix: call utcnow() -- the original subtracted the datetime
            # from the function object, a TypeError when this callback runs.
            return (timezone.utcnow() - started_at).total_seconds()

        interval1 = sensor._get_next_poke_interval(started_at, run_duration, 1)
        interval2 = sensor._get_next_poke_interval(started_at, run_duration, 2)
        assert interval1 >= 0
        assert interval1 <= sensor.poke_interval
        assert interval2 >= sensor.poke_interval
        assert interval2 > interval1
@pytest.mark.parametrize("poke_interval", [0, 0.1, 0.9, 1, 2, 3])
def test_sensor_with_exponential_backoff_on_and_small_poke_interval(self, poke_interval):
    """Test that sensor works correctly when poke_interval is small and exponential_backoff is on"""
    sensor = DummySensor(
        task_id=SENSOR_OP,
        return_value=None,
        poke_interval=poke_interval,
        timeout=60,
        exponential_backoff=True,
    )
    with patch("airflow.utils.timezone.utcnow") as mock_utctime:
        mock_utctime.return_value = DEFAULT_DATE
        started_at = timezone.utcnow() - timedelta(seconds=10)

        def run_duration():
            # Bug fix: utcnow() must be called; subtracting a datetime from
            # the bare function object raises TypeError when invoked.
            return (timezone.utcnow() - started_at).total_seconds()

        intervals = [
            sensor._get_next_poke_interval(started_at, run_duration, retry_number)
            for retry_number in range(1, 10)
        ]
        for interval1, interval2 in zip(intervals, intervals[1:]):
            # intervals should be increasing or equals
            assert interval1 <= interval2
        if poke_interval > 0:
            # check if the intervals are increasing after some retries when poke_interval > 0
            assert intervals[0] < intervals[-1]
        else:
            # check if the intervals are equal after some retries when poke_interval == 0
            assert intervals[0] == intervals[-1]
def test_sensor_with_exponential_backoff_on_and_max_wait(self):
    """max_wait caps the exponentially growing poke interval."""
    sensor = DummySensor(
        task_id=SENSOR_OP,
        return_value=None,
        poke_interval=10,
        timeout=60,
        exponential_backoff=True,
        max_wait=timedelta(seconds=30),
    )
    with patch("airflow.utils.timezone.utcnow") as mock_utctime:
        mock_utctime.return_value = DEFAULT_DATE
        started_at = timezone.utcnow() - timedelta(seconds=10)

        def run_duration():
            # Bug fix: utcnow() must be called; subtracting a datetime from
            # the bare function object raises TypeError when invoked.
            return (timezone.utcnow() - started_at).total_seconds()

        # Intervals grow 2 -> 6 -> 13 then saturate at max_wait (30s).
        for idx, expected in enumerate([2, 6, 13, 30, 30, 30, 30, 30]):
            assert sensor._get_next_poke_interval(started_at, run_duration, idx) == expected
@pytest.mark.backend("mysql")
def test_reschedule_poke_interval_too_long_on_mysql(self, make_sensor):
    """An over-large poke_interval in reschedule mode is rejected on MySQL."""
    with pytest.raises(AirflowException) as ctx:
        make_sensor(poke_interval=863998946, mode="reschedule", return_value="irrelevant")
    expected = (
        "Cannot set poke_interval to 863998946.0 seconds in reschedule mode "
        "since it will take reschedule time over MySQL's TIMESTAMP limit."
    )
    assert str(ctx.value) == expected
@pytest.mark.backend("mysql")
def test_reschedule_date_too_late_on_mysql(self, make_sensor):
    """Rescheduling past MySQL's TIMESTAMP limit fails the sensor."""
    sensor, _ = make_sensor(poke_interval=60 * 60 * 24, mode="reschedule", return_value=False)
    # A few hours until TIMESTAMP's limit, the next poke will take us over.
    frozen_at = datetime(2038, 1, 19, tzinfo=timezone.utc)
    with time_machine.travel(frozen_at, tick=False), pytest.raises(AirflowSensorTimeout) as ctx:
        self._run(sensor)
    assert str(ctx.value) == (
        "Cannot reschedule DAG unit_test_dag to 2038-01-20T00:00:00+00:00 "
        "since it is over MySQL's TIMESTAMP storage limit."
    )
def test_reschedule_and_retry_timeout(self, make_sensor, time_machine, session):
    """
    Test mode="reschedule", retries and timeout configurations interact correctly.
    Given a sensor configured like this:
    poke_interval=5
    timeout=10
    retries=2
    retry_delay=timedelta(seconds=3)
    If the second poke raises RuntimeError, all other pokes return False, this is how it should
    behave:
    00:00 Returns False try_number=1, max_tries=2, state=up_for_reschedule
    00:05 Raises RuntimeError try_number=2, max_tries=2, state=up_for_retry
    00:08 Returns False try_number=2, max_tries=2, state=up_for_reschedule
    00:13 Raises AirflowSensorTimeout try_number=3, max_tries=2, state=failed
    And then the sensor is cleared at 00:19. It should behave like this:
    00:19 Returns False try_number=3, max_tries=4, state=up_for_reschedule
    00:24 Returns False try_number=3, max_tries=4, state=up_for_reschedule
    00:26 Returns False try_number=3, max_tries=4, state=up_for_reschedule
    00:31 Raises AirflowSensorTimeout, try_number=4, max_tries=4, state=failed
    """
    sensor, dr = make_sensor(
        return_value=None,
        poke_interval=5,
        timeout=10,
        retries=2,
        retry_delay=timedelta(seconds=3),
        mode="reschedule",
    )
    sensor.poke = Mock(side_effect=[False, RuntimeError, False, False, False, False, False, False])

    def _get_sensor_ti():
        # Re-load the sensor's TI so DB state changes become visible.
        tis = dr.get_task_instances(session=session)
        return next(x for x in tis if x.task_id == SENSOR_OP)

    def _increment_try_number():
        # Emulate what the scheduler does before each new TI run.
        sensor_ti = _get_sensor_ti()
        sensor_ti.try_number += 1
        session.commit()

    # first poke returns False and task is re-scheduled
    date1 = timezone.utcnow()
    time_machine.move_to(date1, tick=False)
    _increment_try_number()  # first TI run
    self._run(sensor)
    sensor_ti = _get_sensor_ti()
    assert sensor_ti.try_number == 1
    assert sensor_ti.max_tries == 2
    assert sensor_ti.state == State.UP_FOR_RESCHEDULE
    # second poke raises RuntimeError and task instance retries
    time_machine.coordinates.shift(sensor.poke_interval)
    with pytest.raises(RuntimeError):
        self._run(sensor)
    sensor_ti = _get_sensor_ti()
    assert sensor_ti.try_number == 1
    assert sensor_ti.max_tries == 2
    assert sensor_ti.state == State.UP_FOR_RETRY
    # third poke returns False and task is rescheduled again
    time_machine.coordinates.shift(sensor.retry_delay + timedelta(seconds=1))
    _increment_try_number()  # second TI run
    self._run(sensor)
    sensor_ti = _get_sensor_ti()
    assert sensor_ti.try_number == 2
    assert sensor_ti.max_tries == 2
    assert sensor_ti.state == State.UP_FOR_RESCHEDULE
    # fourth poke times out and raises AirflowSensorTimeout
    time_machine.coordinates.shift(sensor.poke_interval)
    with pytest.raises(AirflowSensorTimeout):
        self._run(sensor)
    sensor_ti = _get_sensor_ti()
    assert sensor_ti.try_number == 2
    assert sensor_ti.max_tries == 2
    assert sensor_ti.state == State.FAILED
    # Clear the failed sensor
    sensor.clear()
    sensor_ti = _get_sensor_ti()
    # clearing does not change the try_number
    assert sensor_ti.try_number == 2
    # but it does change the max_tries
    assert sensor_ti.max_tries == 4
    assert sensor_ti.state is None
    time_machine.coordinates.shift(20)
    _increment_try_number()  # third TI run
    for _ in range(3):
        time_machine.coordinates.shift(sensor.poke_interval)
        self._run(sensor)
        sensor_ti = _get_sensor_ti()
        assert sensor_ti.try_number == 3
        assert sensor_ti.max_tries == 4
        assert sensor_ti.state == State.UP_FOR_RESCHEDULE
    # Last poke times out and raises AirflowSensorTimeout
    time_machine.coordinates.shift(sensor.poke_interval)
    with pytest.raises(AirflowSensorTimeout):
        self._run(sensor)
    sensor_ti = _get_sensor_ti()
    assert sensor_ti.try_number == 3
    assert sensor_ti.max_tries == 4
    assert sensor_ti.state == State.FAILED
def test_reschedule_and_retry_timeout_and_silent_fail(self, make_sensor, time_machine, session):
    """
    Test mode="reschedule", silent_fail=True then retries and timeout configurations interact correctly.
    Given a sensor configured like this:
    poke_interval=5
    timeout=10
    retries=2
    retry_delay=timedelta(seconds=3)
    silent_fail=True
    If the second poke raises RuntimeError, all other pokes return False, this is how it should
    behave:
    00:00 Returns False try_number=1, max_tries=2, state=up_for_reschedule
    00:05 Raises RuntimeError try_number=1, max_tries=2, state=up_for_reschedule
    00:08 Returns False try_number=1, max_tries=2, state=up_for_reschedule
    00:13 Raises AirflowSensorTimeout try_number=2, max_tries=2, state=failed
    And then the sensor is cleared at 00:19. It should behave like this:
    00:19 Returns False try_number=2, max_tries=3, state=up_for_reschedule
    00:24 Returns False try_number=2, max_tries=3, state=up_for_reschedule
    00:26 Returns False try_number=2, max_tries=3, state=up_for_reschedule
    00:31 Raises AirflowSensorTimeout, try_number=3, max_tries=3, state=failed
    """
    sensor, dr = make_sensor(
        return_value=None,
        poke_interval=5,
        timeout=10,
        retries=2,
        retry_delay=timedelta(seconds=3),
        mode="reschedule",
        silent_fail=True,
    )

    def _get_sensor_ti():
        # Re-load the sensor's TI so DB state changes become visible.
        tis = dr.get_task_instances(session=session)
        return next(x for x in tis if x.task_id == SENSOR_OP)

    sensor.poke = Mock(side_effect=[False, RuntimeError, False, False, False, False, False, False])
    # first poke returns False and task is re-scheduled
    date1 = timezone.utcnow()
    time_machine.move_to(date1, tick=False)
    sensor_ti = _get_sensor_ti()
    sensor_ti.try_number += 1  # first TI run
    self._run(sensor)
    sensor_ti = _get_sensor_ti()
    assert sensor_ti.try_number == 1
    assert sensor_ti.max_tries == 2
    assert sensor_ti.state == State.UP_FOR_RESCHEDULE
    # second poke raises reschedule exception and task instance is re-scheduled again
    # (silent_fail turns the RuntimeError into another reschedule, not a retry)
    time_machine.coordinates.shift(sensor.poke_interval)
    self._run(sensor)
    sensor_ti = _get_sensor_ti()
    assert sensor_ti.try_number == 1
    assert sensor_ti.max_tries == 2
    assert sensor_ti.state == State.UP_FOR_RESCHEDULE
    # third poke returns False and task is rescheduled again
    time_machine.coordinates.shift(sensor.retry_delay + timedelta(seconds=1))
    self._run(sensor)
    sensor_ti = _get_sensor_ti()
    assert sensor_ti.try_number == 1
    assert sensor_ti.max_tries == 2
    assert sensor_ti.state == State.UP_FOR_RESCHEDULE
    # fourth poke times out and raises AirflowSensorTimeout
    time_machine.coordinates.shift(sensor.poke_interval)
    with pytest.raises(AirflowSensorTimeout):
        self._run(sensor)
    sensor_ti = _get_sensor_ti()
    assert sensor_ti.try_number == 1
    assert sensor_ti.max_tries == 2
    assert sensor_ti.state == State.FAILED
    # Clear the failed sensor
    sensor.clear()
    sensor_ti = _get_sensor_ti()
    assert sensor_ti.try_number == 1
    assert sensor_ti.max_tries == 3
    assert sensor_ti.state == State.NONE
    time_machine.coordinates.shift(20)
    sensor_ti.try_number += 1  # second TI run
    session.commit()
    for _ in range(3):
        time_machine.coordinates.shift(sensor.poke_interval)
        self._run(sensor)
        sensor_ti = _get_sensor_ti()
        assert sensor_ti.try_number == 2
        assert sensor_ti.max_tries == 3
        assert sensor_ti.state == State.UP_FOR_RESCHEDULE
    # Last poke times out and raises AirflowSensorTimeout
    time_machine.coordinates.shift(sensor.poke_interval)
    with pytest.raises(AirflowSensorTimeout):
        self._run(sensor)
    sensor_ti = _get_sensor_ti()
    assert sensor_ti.try_number == 2
    assert sensor_ti.max_tries == 3
    assert sensor_ti.state == State.FAILED
def test_sensor_with_xcom(self, make_sensor):
    """A PokeReturnValue(True, value) pushes the value as the return_value XCom."""
    xcom_value = "TestValue"
    sensor, dr = make_sensor(True, xcom_value=xcom_value)
    self._run(sensor)
    tis = dr.get_task_instances()
    assert len(tis) == 2
    states = {ti.task_id: ti.state for ti in tis}
    assert states[SENSOR_OP] == State.SUCCESS
    assert states[DUMMY_OP] == State.NONE
    actual_xcom_value = XCom.get_one(
        key="return_value", task_id=SENSOR_OP, dag_id=dr.dag_id, run_id=dr.run_id
    )
    assert actual_xcom_value == xcom_value
def test_sensor_with_xcom_fails(self, make_sensor):
    """No XCom is pushed when the sensor times out before succeeding."""
    xcom_value = "TestValue"
    sensor, dr = make_sensor(False, xcom_value=xcom_value)
    with pytest.raises(AirflowSensorTimeout):
        self._run(sensor)
    tis = dr.get_task_instances()
    assert len(tis) == 2
    states = {ti.task_id: ti.state for ti in tis}
    assert states[SENSOR_OP] == State.FAILED
    assert states[DUMMY_OP] == State.NONE
    actual_xcom_value = XCom.get_one(
        key="return_value", task_id=SENSOR_OP, dag_id=dr.dag_id, run_id=dr.run_id
    )
    assert actual_xcom_value is None
@pytest.mark.parametrize(
    "executor_cls_mode",
    [
        (CeleryExecutor, "poke"),
        (CeleryKubernetesExecutor, "poke"),
        (DebugExecutor, "reschedule"),
        (KubernetesExecutor, "poke"),
        (LocalExecutor, "poke"),
        (LocalKubernetesExecutor, "poke"),
        (SequentialExecutor, "poke"),
    ],
    ids=[
        CELERY_EXECUTOR,
        CELERY_KUBERNETES_EXECUTOR,
        DEBUG_EXECUTOR,
        KUBERNETES_EXECUTOR,
        LOCAL_EXECUTOR,
        LOCAL_KUBERNETES_EXECUTOR,
        SEQUENTIAL_EXECUTOR,
    ],
)
def test_prepare_for_execution(self, executor_cls_mode):
    """
    Should change mode of the task to reschedule if using DEBUG_EXECUTOR
    """
    executor_cls, expected_mode = executor_cls_mode
    sensor = DummySensor(
        task_id=SENSOR_OP,
        return_value=None,
        poke_interval=10,
        timeout=60,
        exponential_backoff=True,
        max_wait=timedelta(seconds=30),
    )
    loader_path = "airflow.executors.executor_loader.ExecutorLoader.import_default_executor_cls"
    with patch(loader_path) as mock_loader:
        mock_loader.return_value = (executor_cls, None)
        prepared_task = sensor.prepare_for_execution()
        assert prepared_task.mode == expected_mode
@poke_mode_only
class DummyPokeOnlySensor(BaseSensorOperator):
    """Sensor restricted to mode='poke'; optionally tries to flip its own mode mid-poke."""

    def __init__(self, poke_changes_mode=False, **kwargs):
        # NOTE: mode is assigned *before* super().__init__ so the
        # @poke_mode_only guard intercepts the attribute write first.
        self.mode = kwargs["mode"]
        super().__init__(**kwargs)
        self.poke_changes_mode = poke_changes_mode
        self.return_value = True

    def poke(self, context: Context):
        # Optionally trigger the mode guard from inside poke().
        if self.poke_changes_mode:
            self.change_mode("reschedule")
        return self.return_value

    def change_mode(self, mode):
        # Setting mode should be rejected by @poke_mode_only unless it is 'poke'.
        self.mode = mode
class TestPokeModeOnly:
    """Tests for the @poke_mode_only class decorator."""

    def test_poke_mode_only_allows_poke_mode(self):
        # Bug fix: this plain pytest class has no unittest-style self.fail();
        # pytest.fail() is the correct way to fail with a message here --
        # self.fail(...) would have raised AttributeError instead.
        try:
            sensor = DummyPokeOnlySensor(task_id="foo", mode="poke", poke_changes_mode=False)
        except ValueError:
            pytest.fail("__init__ failed with mode='poke'.")
        try:
            sensor.poke({})
        except ValueError:
            pytest.fail("poke failed without changing mode from 'poke'.")
        try:
            sensor.change_mode("poke")
        except ValueError:
            pytest.fail("class method failed without changing mode from 'poke'.")

    def test_poke_mode_only_bad_class_method(self):
        sensor = DummyPokeOnlySensor(task_id="foo", mode="poke", poke_changes_mode=False)
        with pytest.raises(ValueError, match="Cannot set mode to 'reschedule'. Only 'poke' is acceptable"):
            sensor.change_mode("reschedule")

    def test_poke_mode_only_bad_init(self):
        with pytest.raises(ValueError, match="Cannot set mode to 'reschedule'. Only 'poke' is acceptable"):
            DummyPokeOnlySensor(task_id="foo", mode="reschedule", poke_changes_mode=False)

    def test_poke_mode_only_bad_poke(self):
        sensor = DummyPokeOnlySensor(task_id="foo", mode="poke", poke_changes_mode=True)
        with pytest.raises(ValueError, match="Cannot set mode to 'reschedule'. Only 'poke' is acceptable"):
            sensor.poke({})
class TestAsyncSensor:
    """Behaviour of a sensor resuming from deferral whose callback raises."""

    @pytest.mark.parametrize(
        "soft_fail, expected_exception",
        [
            (True, AirflowSkipException),
            (False, AirflowException),
        ],
    )
    def test_fail_after_resuming_deferred_sensor(self, soft_fail, expected_exception):
        # Simulate a TI resuming from deferral: next_method points at
        # execute_complete, which always raises in DummyAsyncSensor; with
        # soft_fail the error must be converted into a skip.
        async_sensor = DummyAsyncSensor(task_id="dummy_async_sensor", soft_fail=soft_fail)
        ti = TaskInstance(task=async_sensor)
        ti.next_method = "execute_complete"
        with pytest.raises(expected_exception):
            ti._execute_task({}, None)