blob: ddb911d05fc912f26563e1a4e0f57590941c2cf3 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from datetime import timedelta
from typing import TYPE_CHECKING
from unittest.mock import Mock
import pytest
import time_machine
from airflow.exceptions import (
AirflowException,
AirflowFailException,
AirflowRescheduleException,
AirflowSensorTimeout,
AirflowSkipException,
AirflowTaskTimeout,
)
from airflow.models.trigger import TriggerFailureReason
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk import TaskInstanceState, timezone
from airflow.sdk.bases.sensor import BaseSensorOperator, PokeReturnValue, poke_mode_only
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.execution_time.comms import RescheduleTask, TaskRescheduleStartDate
from airflow.sdk.timezone import datetime
if TYPE_CHECKING:
from airflow.sdk.definitions.context import Context
DEFAULT_DATE = datetime(2015, 1, 1)
TEST_DAG_ID = "unit_test_dag"
DUMMY_OP = "dummy_op"
SENSOR_OP = "sensor_op"
class DummySensor(BaseSensorOperator):
    """Minimal sensor whose ``poke`` outcome is fixed at construction time."""

    def __init__(self, return_value: bool | None = False, **kwargs):
        super().__init__(**kwargs)
        # The canned result every poke() call will report.
        self.return_value = return_value

    def poke(self, context: Context) -> bool | None:
        """Ignore the context and report the preconfigured result."""
        return self.return_value
class DummyAsyncSensor(BaseSensorOperator):
    """Sensor test double whose deferral-resume callback always raises."""

    def __init__(self, return_value=False, **kwargs):
        super().__init__(**kwargs)
        self.return_value = return_value

    def execute_complete(self, context, event=None):
        """Raise unconditionally; tests assert how the failure is translated."""
        raise AirflowException("Should be skipped")
class DummySensorWithXcomValue(BaseSensorOperator):
    """Sensor that wraps its canned result in a ``PokeReturnValue`` carrying an XCom payload."""

    def __init__(self, return_value=False, xcom_value=None, **kwargs):
        super().__init__(**kwargs)
        # Done-flag and the payload to be pushed to XCom on success.
        self.return_value = return_value
        self.xcom_value = xcom_value

    def poke(self, context: Context) -> PokeReturnValue:
        """Bundle the done-flag and XCom payload into a PokeReturnValue."""
        return PokeReturnValue(self.return_value, self.xcom_value)
class TestBaseSensor:
    """Tests for BaseSensorOperator: poke/reschedule modes, soft_fail/never_fail,
    timeouts, exponential backoff, XCom return values and retry interaction."""

    @pytest.fixture
    def make_sensor(self):
        """Create a DummySensor"""

        def _make_sensor(return_value, task_id=SENSOR_OP, **kwargs):
            # Default poke_interval and timeout to 0 so tests complete instantly
            # unless a test overrides them explicitly.
            poke_interval = "poke_interval"
            timeout = "timeout"
            if poke_interval not in kwargs:
                kwargs[poke_interval] = 0
            if timeout not in kwargs:
                kwargs[timeout] = 0
            with DAG(TEST_DAG_ID):
                # An "xcom_value" kwarg selects the PokeReturnValue-producing variant.
                if "xcom_value" in kwargs:
                    sensor = DummySensorWithXcomValue(task_id=task_id, return_value=return_value, **kwargs)
                else:
                    sensor = DummySensor(task_id=task_id, return_value=return_value, **kwargs)
                dummy_op = EmptyOperator(task_id=DUMMY_OP)
                sensor >> dummy_op
            return sensor

        return _make_sensor

    @classmethod
    def _run(cls, task, context=None):
        # Execute the task directly (poke mode); context defaults to an empty dict.
        if context is None:
            context = {}
        return task.execute(context)

    def test_ok(self, make_sensor):
        """A sensor whose poke returns True succeeds immediately."""
        sensor = make_sensor(True)
        self._run(sensor)

    def test_fail(self, make_sensor):
        """With timeout=0, a persistently-False poke raises AirflowSensorTimeout."""
        sensor = make_sensor(False)
        with pytest.raises(AirflowSensorTimeout):
            self._run(sensor)

    def test_soft_fail(self, make_sensor):
        """soft_fail=True converts the sensor timeout into a skip."""
        sensor = make_sensor(False, soft_fail=True)
        with pytest.raises(AirflowSkipException):
            self._run(sensor)

    @pytest.mark.parametrize(
        "exception_cls",
        (ValueError,),
    )
    def test_soft_fail_with_exception(self, make_sensor, exception_cls):
        """Non-Airflow exceptions from poke propagate even with soft_fail=True."""
        sensor = make_sensor(False, soft_fail=True)
        sensor.poke = Mock(side_effect=[exception_cls(None)])
        with pytest.raises(ValueError, match="None"):
            self._run(sensor)

    @pytest.mark.parametrize(
        "exception_cls",
        (
            AirflowSensorTimeout,
            AirflowTaskTimeout,
            AirflowFailException,
        ),
    )
    def test_soft_fail_with_skip_exception(self, make_sensor, exception_cls):
        """soft_fail=True turns these Airflow exceptions into AirflowSkipException."""
        sensor = make_sensor(False, soft_fail=True)
        sensor.poke = Mock(side_effect=[exception_cls(None)])
        with pytest.raises(AirflowSkipException):
            self._run(sensor)

    @pytest.mark.parametrize(
        "exception_cls",
        (AirflowSensorTimeout, AirflowTaskTimeout, AirflowFailException, Exception),
    )
    def test_never_fail_with_skip_exception(self, make_sensor, exception_cls):
        """never_fail=True converts any exception from poke into a skip."""
        sensor = make_sensor(False, never_fail=True)
        sensor.poke = Mock(side_effect=[exception_cls(None)])
        with pytest.raises(AirflowSkipException):
            self._run(sensor)

    def test_ok_with_reschedule(self, run_task, make_sensor, time_machine):
        """In reschedule mode, False pokes reschedule at poke_interval until a True poke succeeds."""
        sensor = make_sensor(return_value=None, poke_interval=10, timeout=25, mode="reschedule")
        sensor.poke = Mock(side_effect=[False, False, True])
        # first poke returns False and task is re-scheduled
        date1 = timezone.utcnow()
        time_machine.move_to(date1, tick=False)
        state, msg, _ = run_task(task=sensor)
        assert state == TaskInstanceState.UP_FOR_RESCHEDULE
        assert msg.reschedule_date == date1 + timedelta(seconds=sensor.poke_interval)
        # second poke returns False and task is re-scheduled
        time_machine.coordinates.shift(sensor.poke_interval)
        date2 = date1 + timedelta(seconds=sensor.poke_interval)
        state, msg, _ = run_task(task=sensor)
        assert state == TaskInstanceState.UP_FOR_RESCHEDULE
        assert msg.reschedule_date == date2 + timedelta(seconds=sensor.poke_interval)
        # third poke returns True and task succeeds
        time_machine.coordinates.shift(sensor.poke_interval)
        state, _, _ = run_task(task=sensor)
        assert state == TaskInstanceState.SUCCESS

    def test_fail_with_reschedule(self, run_task, make_sensor, time_machine, mock_supervisor_comms):
        """A rescheduled sensor fails with AirflowSensorTimeout once run duration exceeds timeout."""
        sensor = make_sensor(return_value=False, poke_interval=10, timeout=5, mode="reschedule")
        # first poke returns False and task is re-scheduled
        date1 = timezone.utcnow()
        time_machine.move_to(date1, tick=False)
        state, msg, _ = run_task(task=sensor)
        assert state == TaskInstanceState.UP_FOR_RESCHEDULE
        assert msg.reschedule_date == date1 + timedelta(seconds=sensor.poke_interval)
        # second poke returns False, timeout occurs
        time_machine.coordinates.shift(sensor.poke_interval)
        # Mocking values from DB/API-server
        mock_supervisor_comms.send.return_value = TaskRescheduleStartDate(start_date=date1)
        state, msg, error = run_task(task=sensor, context_update={"task_reschedule_count": 1})
        assert state == TaskInstanceState.FAILED
        assert isinstance(error, AirflowSensorTimeout)

    def test_soft_fail_with_reschedule(self, run_task, make_sensor, time_machine, mock_supervisor_comms):
        """soft_fail=True turns a reschedule-mode timeout into SKIPPED instead of FAILED."""
        sensor = make_sensor(
            return_value=False, poke_interval=10, timeout=5, soft_fail=True, mode="reschedule"
        )
        # first poke returns False and task is re-scheduled
        date1 = timezone.utcnow()
        time_machine.move_to(date1, tick=False)
        state, msg, _ = run_task(task=sensor)
        assert state == TaskInstanceState.UP_FOR_RESCHEDULE
        # second poke returns False, timeout occurs
        time_machine.coordinates.shift(sensor.poke_interval)
        # Mocking values from DB/API-server
        mock_supervisor_comms.send.return_value = TaskRescheduleStartDate(start_date=date1)
        state, msg, _ = run_task(task=sensor, context_update={"task_reschedule_count": 1})
        assert state == TaskInstanceState.SKIPPED

    def test_ok_with_reschedule_and_exponential_backoff(
        self, run_task, make_sensor, time_machine, mock_supervisor_comms
    ):
        """With exponential_backoff=True, each reschedule interval is strictly larger than the last."""
        sensor = make_sensor(
            return_value=None,
            poke_interval=10,
            timeout=36000,
            mode="reschedule",
            exponential_backoff=True,
        )
        false_count = 10
        sensor.poke = Mock(side_effect=[False] * false_count + [True])
        task_start_date = timezone.utcnow()
        time_machine.move_to(task_start_date, tick=False)
        curr_date = task_start_date

        def run_duration():
            # Seconds elapsed since the (frozen) task start.
            return (timezone.utcnow() - task_start_date).total_seconds()

        new_interval = 0
        mock_supervisor_comms.send.return_value = TaskRescheduleStartDate(start_date=task_start_date)
        # loop poke returns false
        for _poke_count in range(1, false_count + 1):
            curr_date = curr_date + timedelta(seconds=new_interval)
            time_machine.coordinates.shift(new_interval)
            state, msg, _ = run_task(sensor, context_update={"task_reschedule_count": _poke_count})
            assert state == TaskInstanceState.UP_FOR_RESCHEDULE
            old_interval = new_interval
            new_interval = sensor._get_next_poke_interval(task_start_date, run_duration, _poke_count)
            assert old_interval < new_interval  # actual test
            assert msg.reschedule_date == curr_date + timedelta(seconds=new_interval)
        # last poke returns True and task succeeds
        curr_date = curr_date + timedelta(seconds=new_interval)
        time_machine.coordinates.shift(new_interval)
        state, msg, _ = run_task(sensor, context_update={"task_reschedule_count": false_count + 1})
        assert state == TaskInstanceState.SUCCESS

    def test_invalid_mode(self):
        """An unknown mode string is rejected at construction time."""
        with pytest.raises(AirflowException):
            DummySensor(task_id="a", mode="foo")

    def test_ok_with_custom_reschedule_exception(self, make_sensor, run_task):
        """AirflowRescheduleException from poke reschedules to the exact date it carries."""
        sensor = make_sensor(return_value=None, mode="reschedule")
        date1 = timezone.utcnow()
        date2 = date1 + timedelta(seconds=60)
        date3 = date1 + timedelta(seconds=120)
        sensor.poke = Mock(
            side_effect=[AirflowRescheduleException(date2), AirflowRescheduleException(date3), True]
        )
        # first poke returns False and task is re-scheduled
        with time_machine.travel(date1, tick=False):
            state, msg, error = run_task(sensor)
        assert state == TaskInstanceState.UP_FOR_RESCHEDULE
        assert isinstance(msg, RescheduleTask)
        assert msg.reschedule_date == date2
        # second poke returns False and task is re-scheduled
        with time_machine.travel(date2, tick=False):
            state, msg, error = run_task(sensor)
        assert state == TaskInstanceState.UP_FOR_RESCHEDULE
        assert isinstance(msg, RescheduleTask)
        assert msg.reschedule_date == date3
        # third poke returns True and task succeeds
        with time_machine.travel(date3, tick=False):
            state, _, _ = run_task(sensor)
        assert state == TaskInstanceState.SUCCESS

    def test_sensor_with_invalid_poke_interval(self):
        """poke_interval must be a non-negative number; invalid values raise at construction."""
        negative_poke_interval = -10
        non_number_poke_interval = "abcd"
        positive_poke_interval = 10
        with pytest.raises(AirflowException):
            DummySensor(
                task_id="test_sensor_task_1",
                return_value=None,
                poke_interval=negative_poke_interval,
                timeout=25,
            )
        with pytest.raises(AirflowException):
            DummySensor(
                task_id="test_sensor_task_2",
                return_value=None,
                poke_interval=non_number_poke_interval,
                timeout=25,
            )
        DummySensor(
            task_id="test_sensor_task_3", return_value=None, poke_interval=positive_poke_interval, timeout=25
        )

    def test_sensor_with_invalid_timeout(self):
        """timeout must be a non-negative number; invalid values raise at construction."""
        negative_timeout = -25
        non_number_timeout = "abcd"
        positive_timeout = 25
        with pytest.raises(AirflowException):
            DummySensor(
                task_id="test_sensor_task_1", return_value=None, poke_interval=10, timeout=negative_timeout
            )
        with pytest.raises(AirflowException):
            DummySensor(
                task_id="test_sensor_task_2", return_value=None, poke_interval=10, timeout=non_number_timeout
            )
        DummySensor(
            task_id="test_sensor_task_3", return_value=None, poke_interval=10, timeout=positive_timeout
        )

    def test_sensor_with_exponential_backoff_off(self):
        """Without exponential backoff the next interval always equals poke_interval."""
        sensor = DummySensor(
            task_id=SENSOR_OP, return_value=None, poke_interval=5, timeout=60, exponential_backoff=False
        )
        started_at = timezone.utcnow() - timedelta(seconds=10)

        def run_duration():
            return (timezone.utcnow() - started_at).total_seconds()

        assert sensor._get_next_poke_interval(started_at, run_duration, 1) == sensor.poke_interval
        assert sensor._get_next_poke_interval(started_at, run_duration, 2) == sensor.poke_interval

    def test_sensor_with_exponential_backoff_on(self):
        """With exponential backoff the computed interval grows between retries."""
        sensor = DummySensor(
            task_id=SENSOR_OP, return_value=None, poke_interval=5, timeout=60, exponential_backoff=True
        )
        with time_machine.travel(DEFAULT_DATE):
            started_at = timezone.utcnow() - timedelta(seconds=10)

            def run_duration():
                return (timezone.utcnow() - started_at).total_seconds()

            interval1 = sensor._get_next_poke_interval(started_at, run_duration, 1)
            interval2 = sensor._get_next_poke_interval(started_at, run_duration, 2)
            assert interval1 >= 0
            assert interval1 <= sensor.poke_interval
            assert interval2 >= sensor.poke_interval
            assert interval2 > interval1

    @pytest.mark.parametrize("poke_interval", [0, 0.1, 0.9, 1, 2, 3])
    def test_sensor_with_exponential_backoff_on_and_small_poke_interval(self, poke_interval):
        """Test that sensor works correctly when poke_interval is small and exponential_backoff is on"""
        sensor = DummySensor(
            task_id=SENSOR_OP,
            return_value=None,
            poke_interval=poke_interval,
            timeout=60,
            exponential_backoff=True,
        )
        with time_machine.travel(DEFAULT_DATE):
            started_at = timezone.utcnow() - timedelta(seconds=10)

            def run_duration():
                return (timezone.utcnow() - started_at).total_seconds()

            intervals = [
                sensor._get_next_poke_interval(started_at, run_duration, retry_number)
                for retry_number in range(1, 10)
            ]
            for interval1, interval2 in zip(intervals, intervals[1:]):
                # intervals should be increasing or equals
                assert interval1 <= interval2
            if poke_interval > 0:
                # check if the intervals are increasing after some retries when poke_interval > 0
                assert intervals[0] < intervals[-1]
            else:
                # check if the intervals are equal after some retries when poke_interval == 0
                assert intervals[0] == intervals[-1]

    def test_sensor_with_exponential_backoff_on_and_max_wait(self):
        """max_wait caps the exponential backoff interval (at 30s here)."""
        sensor = DummySensor(
            task_id=SENSOR_OP,
            return_value=None,
            poke_interval=10,
            timeout=60,
            exponential_backoff=True,
            max_wait=timedelta(seconds=30),
        )
        with time_machine.travel(DEFAULT_DATE):
            started_at = timezone.utcnow() - timedelta(seconds=10)

            def run_duration():
                return (timezone.utcnow() - started_at).total_seconds()

            # Expected intervals grow then plateau at the 30s max_wait cap.
            for idx, expected in enumerate([2, 6, 13, 30, 30, 30, 30, 30]):
                assert sensor._get_next_poke_interval(started_at, run_duration, idx) == expected

    def test_reschedule_and_retry_timeout(self, mock_supervisor_comms, make_sensor, time_machine, run_task):
        """
        Test mode="reschedule", retries and timeout configurations interact correctly.
        Given a sensor configured like this:
        - poke_interval=5
        - timeout=10
        - retries=2
        - retry_delay=timedelta(seconds=3)
        The test verifies two phases:
        Phase 1: Initial execution until failure
        00:00 Returns False try_number=1, max_tries=2, state=up_for_reschedule
        00:05 Raises RuntimeError try_number=2, max_tries=2, state=up_for_retry
        00:08 Returns False try_number=2, max_tries=2, state=up_for_reschedule
        00:13 Raises AirflowSensorTimeout try_number=3, max_tries=2, state=failed
        Phase 2: After clearing the failed sensor
        00:19 Returns False try_number=3, max_tries=4, state=up_for_reschedule
        00:24 Returns False try_number=3, max_tries=4, state=up_for_reschedule
        00:26 Returns False try_number=3, max_tries=4, state=up_for_reschedule
        00:31 Raises AirflowSensorTimeout try_number=4, max_tries=4, state=failed
        """
        # Setup sensor with test configuration
        sensor = make_sensor(
            return_value=None,
            poke_interval=5,
            timeout=10,
            retries=2,
            retry_delay=timedelta(seconds=3),
            mode="reschedule",
        )
        # Configure poke behavior for both phases
        sensor.poke = Mock(
            side_effect=[
                # Phase 1
                False,  # Initial poke
                RuntimeError,  # Second poke raises error
                False,  # Third poke after retry
                False,  # Fourth poke times out
                # Phase 2 (after clearing)
                False,  # First poke after clear
                False,  # Second poke after clear
                False,  # Third poke after clear
                False,  # Final poke times out
            ]
        )
        # To store the state across runs
        test_state = {
            "task_reschedule_count": 0,
            "current_time": timezone.datetime(2025, 1, 1),
            "try_number": 1,
            "max_tries": sensor.retries,  # Initial max_tries
            "first_reschedule_date": None,  # Track the first reschedule date
        }

        def _run_task():
            """
            Helper function to run the sensor task with consistent state management.
            This function:
            1. Freezes current time using timemachine
            2. Configures the supervisor comms mock to return the appropriate TR start date
            3. Runs the task with the current state (try_number, max_tries, task_reschedule_count etc)
            4. Updates the state dictionary
            We use this helper to ensure consistent state management across all task runs
            and to avoid duplicating the setup/teardown code for each run.
            """
            time_machine.move_to(test_state["current_time"], tick=False)
            # For timeout calculation, we need to use the first reschedule date
            # This ensures the timeout is calculated from the start of the task
            if test_state["first_reschedule_date"] is None:
                mock_supervisor_comms.send.return_value = TaskRescheduleStartDate(start_date=None)
            else:
                mock_supervisor_comms.send.return_value = TaskRescheduleStartDate(
                    start_date=test_state["first_reschedule_date"]
                )
            state, msg, error = run_task(
                task=sensor,
                try_number=test_state["try_number"],
                max_tries=test_state["max_tries"],
                context_update={"task_reschedule_count": test_state["task_reschedule_count"]},
            )
            if state == TaskInstanceState.UP_FOR_RESCHEDULE:
                test_state["task_reschedule_count"] += 1
                # Only set first_reschedule_date on the first successful reschedule
                if test_state["first_reschedule_date"] is None:
                    test_state["first_reschedule_date"] = test_state["current_time"]
            elif state == TaskInstanceState.UP_FOR_RETRY:
                test_state["try_number"] += 1
            return state, msg, error

        # Phase 1: Initial execution until failure
        # First poke - should reschedule
        state, _, _ = _run_task()
        assert state == TaskInstanceState.UP_FOR_RESCHEDULE
        # Second poke - should raise RuntimeError and retry
        test_state["current_time"] += timedelta(seconds=sensor.poke_interval)
        state, _, error = _run_task()
        assert state == TaskInstanceState.UP_FOR_RETRY
        assert isinstance(error, RuntimeError)
        # Third poke - should reschedule again
        test_state["current_time"] += sensor.retry_delay + timedelta(seconds=1)
        state, _, _ = _run_task()
        assert state == TaskInstanceState.UP_FOR_RESCHEDULE
        # Fourth poke - should timeout
        test_state["current_time"] += timedelta(seconds=sensor.poke_interval)
        state, _, error = _run_task()
        assert isinstance(error, AirflowSensorTimeout)
        assert state == TaskInstanceState.FAILED
        # Phase 2: After clearing the failed sensor
        # Reset supervisor comms to return None, simulating a fresh start after clearing
        test_state["first_reschedule_date"] = None
        test_state["max_tries"] = 4  # Original max_tries (2) + retries (2)
        test_state["current_time"] += timedelta(seconds=20)
        # Test three reschedules after clearing
        for _ in range(3):
            test_state["current_time"] += timedelta(seconds=sensor.poke_interval)
            state, _, _ = _run_task()
            assert state == TaskInstanceState.UP_FOR_RESCHEDULE
        # Final poke - should timeout
        test_state["current_time"] += timedelta(seconds=sensor.poke_interval)
        state, _, error = _run_task()
        assert isinstance(error, AirflowSensorTimeout)
        assert state == TaskInstanceState.FAILED

    def test_sensor_with_xcom(self, make_sensor):
        """A succeeding sensor built with xcom_value returns that value from execute()."""
        xcom_value = "TestValue"
        sensor = make_sensor(True, xcom_value=xcom_value)
        assert self._run(sensor) == xcom_value

    def test_sensor_with_xcom_fails(self, make_sensor):
        """A failing sensor with an xcom_value still times out; no value is returned."""
        xcom_value = "TestValue"
        sensor = make_sensor(False, xcom_value=xcom_value)
        with pytest.raises(AirflowSensorTimeout):
            # NOTE(review): the chained comparison is never evaluated to completion
            # because execute() raises first; the raises-check is the real assertion.
            assert self._run(sensor) == xcom_value is None

    def test_resume_execution(self):
        """Resuming with the __fail__ marker and a trigger-timeout reason raises AirflowSensorTimeout."""
        op = BaseSensorOperator(task_id="hi")
        with pytest.raises(AirflowSensorTimeout):
            op.resume_execution(
                next_method="__fail__",
                next_kwargs={"error": TriggerFailureReason.TRIGGER_TIMEOUT},
                context={},
            )

    @pytest.mark.parametrize("mode", ["poke", "reschedule"])
    @pytest.mark.parametrize("retries", [0, 1])
    def test_sensor_timeout(self, mode, retries, run_task):
        """
        Test that AirflowSensorTimeout does not cause sensor to retry.
        """
        from airflow.providers.standard.sensors.python import PythonSensor

        def timeout():
            raise AirflowSensorTimeout

        task = PythonSensor(
            task_id="test_raise_sensor_timeout",
            python_callable=timeout,
            retries=retries,
            mode=mode,
        )
        state, _, error = run_task(task=task, dag_id=f"test_sensor_timeout_{mode}_{retries}")
        assert isinstance(error, AirflowSensorTimeout)
        assert state == TaskInstanceState.FAILED
@poke_mode_only
class DummyPokeOnlySensor(BaseSensorOperator):
    """Sensor guarded by @poke_mode_only; used to verify mode is locked to 'poke'."""

    def __init__(self, poke_changes_mode=False, **kwargs):
        # NOTE(review): mode is assigned *before* super().__init__() — presumably so the
        # poke_mode_only guard can validate it during construction; order looks deliberate.
        self.mode = kwargs["mode"]
        super().__init__(**kwargs)
        # When True, poke() attempts an (illegal) switch to reschedule mode.
        self.poke_changes_mode = poke_changes_mode
        self.return_value = True

    def poke(self, context: Context):
        # Optionally trigger the forbidden mode change, then report the canned result.
        if self.poke_changes_mode:
            self.change_mode("reschedule")
        return self.return_value

    def change_mode(self, mode):
        # Plain setter; the @poke_mode_only guard is expected to intercept bad values.
        self.mode = mode
class TestPokeModeOnly:
    """Tests for the ``poke_mode_only`` decorator, which pins a sensor's mode to 'poke'."""

    def test_poke_mode_only_allows_poke_mode(self):
        """Construction, poking, and re-setting mode to 'poke' must all be permitted."""
        try:
            sensor = DummyPokeOnlySensor(task_id="foo", mode="poke", poke_changes_mode=False)
        except ValueError:
            # Fix: use pytest.fail, not self.fail — this is a plain pytest class, not a
            # unittest.TestCase, so self.fail would raise AttributeError and mask the failure.
            pytest.fail("__init__ failed with mode='poke'.")
        try:
            sensor.poke({})
        except ValueError:
            pytest.fail("poke failed without changing mode from 'poke'.")
        try:
            sensor.change_mode("poke")
        except ValueError:
            pytest.fail("class method failed without changing mode from 'poke'.")

    def test_poke_mode_only_bad_class_method(self):
        """Switching mode away from 'poke' after construction must raise ValueError."""
        sensor = DummyPokeOnlySensor(task_id="foo", mode="poke", poke_changes_mode=False)
        with pytest.raises(ValueError, match="Cannot set mode to 'reschedule'. Only 'poke' is acceptable"):
            sensor.change_mode("reschedule")

    def test_poke_mode_only_bad_init(self):
        """Constructing with mode='reschedule' must be rejected outright."""
        with pytest.raises(ValueError, match="Cannot set mode to 'reschedule'. Only 'poke' is acceptable"):
            DummyPokeOnlySensor(task_id="foo", mode="reschedule", poke_changes_mode=False)

    def test_poke_mode_only_bad_poke(self):
        """A poke that tries to flip the mode must raise ValueError."""
        sensor = DummyPokeOnlySensor(task_id="foo", mode="poke", poke_changes_mode=True)
        with pytest.raises(ValueError, match="Cannot set mode to 'reschedule'. Only 'poke' is acceptable"):
            sensor.poke({})
class TestAsyncSensor:
    """Behaviour of sensors resumed from deferral when the resume callback raises."""

    @pytest.mark.parametrize(
        ("soft_fail", "expected_exception"),
        [
            (True, AirflowSkipException),
            (False, AirflowException),
        ],
    )
    def test_fail_after_resuming_deferred_sensor(self, soft_fail, expected_exception):
        """An exception in execute_complete becomes a skip (soft_fail=True) or a failure."""
        sensor = DummyAsyncSensor(task_id="dummy_async_sensor", soft_fail=soft_fail)
        with pytest.raises(expected_exception):
            sensor.resume_execution("execute_complete", None, {})