| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import unittest |
| from datetime import timedelta |
| from time import sleep |
| from unittest.mock import Mock, patch |
| |
| import pytest |
| from freezegun import freeze_time |
| |
| from airflow.exceptions import AirflowException, AirflowRescheduleException, AirflowSensorTimeout |
| from airflow.models import DagBag, TaskInstance, TaskReschedule |
| from airflow.models.dag import DAG |
| from airflow.operators.dummy import DummyOperator |
| from airflow.sensors.base import BaseSensorOperator, poke_mode_only |
| from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep |
| from airflow.utils import timezone |
| from airflow.utils.state import State |
| from airflow.utils.timezone import datetime |
| from airflow.utils.types import DagRunType |
| from tests.test_utils import db |
| |
# Fixed logical date used for every DAG run these tests create.
DEFAULT_DATE = datetime(2015, 1, 1)
TEST_DAG_ID = 'unit_test_dag'
DUMMY_OP = 'dummy_op'  # task_id of the downstream dummy task
SENSOR_OP = 'sensor_op'  # task_id of the sensor under test
# Dag folder deliberately pointing at an empty location.
# NOTE(review): relative path — '/dev/null' may have been intended; confirm.
DEV_NULL = 'dev/null'
| |
| |
class DummySensor(BaseSensorOperator):
    """Minimal test sensor whose ``poke`` reports a pre-configured result.

    :param return_value: the canned value every ``poke`` call returns.
    """

    def __init__(self, return_value=False, **kwargs):
        super().__init__(**kwargs)
        self.return_value = return_value

    def poke(self, context):
        """Return the canned result; the execution context is never inspected."""
        del context  # unused by this test double
        return self.return_value
| |
| |
class TestBaseSensor(unittest.TestCase):
    """Tests for ``BaseSensorOperator``: poke and reschedule modes, soft-fail,
    retries, timeouts, custom reschedule exceptions and exponential backoff."""

    @staticmethod
    def clean_db():
        """Wipe DAG runs, task reschedules and XComs so each test starts clean."""
        db.clear_db_runs()
        db.clear_db_task_reschedule()
        db.clear_db_xcom()

    def setUp(self):
        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
        self.dag = DAG(TEST_DAG_ID, default_args=args)
        self.clean_db()

    def tearDown(self) -> None:
        self.clean_db()

    def _make_dag_run(self):
        """Create a manual DAG run for ``DEFAULT_DATE`` in the RUNNING state."""
        return self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

    def _make_sensor(self, return_value, task_id=SENSOR_OP, **kwargs):
        """Build a ``DummySensor`` with a downstream ``DummyOperator`` in self.dag.

        ``poke_interval`` and ``timeout`` default to 0 unless passed explicitly,
        so by default a failing poke times out immediately.
        """
        poke_interval = 'poke_interval'
        timeout = 'timeout'

        if poke_interval not in kwargs:
            kwargs[poke_interval] = 0
        if timeout not in kwargs:
            kwargs[timeout] = 0

        sensor = DummySensor(task_id=task_id, return_value=return_value, dag=self.dag, **kwargs)

        dummy_op = DummyOperator(task_id=DUMMY_OP, dag=self.dag)
        dummy_op.set_upstream(sensor)
        return sensor

    @classmethod
    def _run(cls, task):
        """Run the task once for ``DEFAULT_DATE``, ignoring prior TI state."""
        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    def test_ok(self):
        """A sensor whose poke succeeds ends up SUCCESS; downstream untouched."""
        sensor = self._make_sensor(True)
        dr = self._make_dag_run()

        self._run(sensor)
        tis = dr.get_task_instances()
        assert len(tis) == 2
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                assert ti.state == State.SUCCESS
            if ti.task_id == DUMMY_OP:
                assert ti.state == State.NONE

    def test_fail(self):
        """A sensor that never succeeds times out and is marked FAILED."""
        sensor = self._make_sensor(False)
        dr = self._make_dag_run()

        with pytest.raises(AirflowSensorTimeout):
            self._run(sensor)
        tis = dr.get_task_instances()
        assert len(tis) == 2
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                assert ti.state == State.FAILED
            if ti.task_id == DUMMY_OP:
                assert ti.state == State.NONE

    def test_soft_fail(self):
        """With soft_fail a timed-out sensor is SKIPPED instead of FAILED."""
        sensor = self._make_sensor(False, soft_fail=True)
        dr = self._make_dag_run()

        self._run(sensor)
        tis = dr.get_task_instances()
        assert len(tis) == 2
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                assert ti.state == State.SKIPPED
            if ti.task_id == DUMMY_OP:
                assert ti.state == State.NONE

    def test_soft_fail_with_retries(self):
        """soft_fail only skips once retries are exhausted."""
        sensor = self._make_sensor(
            return_value=False, soft_fail=True, retries=1, retry_delay=timedelta(milliseconds=1)
        )
        dr = self._make_dag_run()

        # first run fails and task instance is marked up to retry
        with pytest.raises(AirflowSensorTimeout):
            self._run(sensor)
        tis = dr.get_task_instances()
        assert len(tis) == 2
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                assert ti.state == State.UP_FOR_RETRY
            if ti.task_id == DUMMY_OP:
                assert ti.state == State.NONE

        sleep(0.001)
        # after retry DAG run is skipped
        self._run(sensor)
        tis = dr.get_task_instances()
        assert len(tis) == 2
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                assert ti.state == State.SKIPPED
            if ti.task_id == DUMMY_OP:
                assert ti.state == State.NONE

    def test_ok_with_reschedule(self):
        """In reschedule mode each failed poke records a TaskReschedule row."""
        sensor = self._make_sensor(return_value=None, poke_interval=10, timeout=25, mode='reschedule')
        sensor.poke = Mock(side_effect=[False, False, True])
        dr = self._make_dag_run()

        # first poke returns False and task is re-scheduled
        date1 = timezone.utcnow()
        with freeze_time(date1):
            self._run(sensor)
        tis = dr.get_task_instances()
        assert len(tis) == 2
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                # verify task is re-scheduled, i.e. state set to UP_FOR_RESCHEDULE
                assert ti.state == State.UP_FOR_RESCHEDULE
                # verify task start date is the initial one
                assert ti.start_date == date1
                # verify one row in task_reschedule table
                task_reschedules = TaskReschedule.find_for_task_instance(ti)
                assert len(task_reschedules) == 1
                assert task_reschedules[0].start_date == date1
                assert task_reschedules[0].reschedule_date == date1 + timedelta(seconds=sensor.poke_interval)
            if ti.task_id == DUMMY_OP:
                assert ti.state == State.NONE

        # second poke returns False and task is re-scheduled
        date2 = date1 + timedelta(seconds=sensor.poke_interval)
        with freeze_time(date2):
            self._run(sensor)
        tis = dr.get_task_instances()
        assert len(tis) == 2
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                # verify task is re-scheduled, i.e. state set to UP_FOR_RESCHEDULE
                assert ti.state == State.UP_FOR_RESCHEDULE
                # verify task start date is the initial one
                assert ti.start_date == date1
                # verify two rows in task_reschedule table
                task_reschedules = TaskReschedule.find_for_task_instance(ti)
                assert len(task_reschedules) == 2
                assert task_reschedules[1].start_date == date2
                assert task_reschedules[1].reschedule_date == date2 + timedelta(seconds=sensor.poke_interval)
            if ti.task_id == DUMMY_OP:
                assert ti.state == State.NONE

        # third poke returns True and task succeeds
        date3 = date2 + timedelta(seconds=sensor.poke_interval)
        with freeze_time(date3):
            self._run(sensor)
        tis = dr.get_task_instances()
        assert len(tis) == 2
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                assert ti.state == State.SUCCESS
                # verify task start date is the initial one
                assert ti.start_date == date1
            if ti.task_id == DUMMY_OP:
                assert ti.state == State.NONE

    def test_fail_with_reschedule(self):
        """Reschedule-mode sensor that exceeds its timeout is FAILED."""
        sensor = self._make_sensor(return_value=False, poke_interval=10, timeout=5, mode='reschedule')
        dr = self._make_dag_run()

        # first poke returns False and task is re-scheduled
        date1 = timezone.utcnow()
        with freeze_time(date1):
            self._run(sensor)
        tis = dr.get_task_instances()
        assert len(tis) == 2
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                assert ti.state == State.UP_FOR_RESCHEDULE
            if ti.task_id == DUMMY_OP:
                assert ti.state == State.NONE

        # second poke returns False, timeout occurs
        date2 = date1 + timedelta(seconds=sensor.poke_interval)
        with freeze_time(date2):
            with pytest.raises(AirflowSensorTimeout):
                self._run(sensor)
        tis = dr.get_task_instances()
        assert len(tis) == 2
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                assert ti.state == State.FAILED
            if ti.task_id == DUMMY_OP:
                assert ti.state == State.NONE

    def test_soft_fail_with_reschedule(self):
        """Reschedule-mode sensor with soft_fail is SKIPPED on timeout."""
        sensor = self._make_sensor(
            return_value=False, poke_interval=10, timeout=5, soft_fail=True, mode='reschedule'
        )
        dr = self._make_dag_run()

        # first poke returns False and task is re-scheduled
        date1 = timezone.utcnow()
        with freeze_time(date1):
            self._run(sensor)
        tis = dr.get_task_instances()
        assert len(tis) == 2
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                assert ti.state == State.UP_FOR_RESCHEDULE
            if ti.task_id == DUMMY_OP:
                assert ti.state == State.NONE

        # second poke returns False, timeout occurs
        date2 = date1 + timedelta(seconds=sensor.poke_interval)
        with freeze_time(date2):
            self._run(sensor)
        tis = dr.get_task_instances()
        assert len(tis) == 2
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                assert ti.state == State.SKIPPED
            if ti.task_id == DUMMY_OP:
                assert ti.state == State.NONE

    def test_ok_with_reschedule_and_retry(self):
        """A timed-out reschedule sensor retries; reschedules reset per try."""
        sensor = self._make_sensor(
            return_value=None,
            poke_interval=10,
            timeout=5,
            retries=1,
            retry_delay=timedelta(seconds=10),
            mode='reschedule',
        )
        sensor.poke = Mock(side_effect=[False, False, False, True])
        dr = self._make_dag_run()

        # first poke returns False and task is re-scheduled
        date1 = timezone.utcnow()
        with freeze_time(date1):
            self._run(sensor)
        tis = dr.get_task_instances()
        assert len(tis) == 2
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                assert ti.state == State.UP_FOR_RESCHEDULE
                # verify one row in task_reschedule table
                task_reschedules = TaskReschedule.find_for_task_instance(ti)
                assert len(task_reschedules) == 1
                assert task_reschedules[0].start_date == date1
                assert task_reschedules[0].reschedule_date == date1 + timedelta(seconds=sensor.poke_interval)
                assert task_reschedules[0].try_number == 1
            if ti.task_id == DUMMY_OP:
                assert ti.state == State.NONE

        # second poke fails and task instance is marked up to retry
        date2 = date1 + timedelta(seconds=sensor.poke_interval)
        with freeze_time(date2):
            with pytest.raises(AirflowSensorTimeout):
                self._run(sensor)
        tis = dr.get_task_instances()
        assert len(tis) == 2
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                assert ti.state == State.UP_FOR_RETRY
            if ti.task_id == DUMMY_OP:
                assert ti.state == State.NONE

        # third poke returns False and task is rescheduled again
        date3 = date2 + timedelta(seconds=sensor.poke_interval) + sensor.retry_delay
        with freeze_time(date3):
            self._run(sensor)
        tis = dr.get_task_instances()
        assert len(tis) == 2
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                assert ti.state == State.UP_FOR_RESCHEDULE
                # verify one row in task_reschedule table (reset on the new try)
                task_reschedules = TaskReschedule.find_for_task_instance(ti)
                assert len(task_reschedules) == 1
                assert task_reschedules[0].start_date == date3
                assert task_reschedules[0].reschedule_date == date3 + timedelta(seconds=sensor.poke_interval)
                assert task_reschedules[0].try_number == 2
            if ti.task_id == DUMMY_OP:
                assert ti.state == State.NONE

        # fourth poke returns True and task succeeds
        date4 = date3 + timedelta(seconds=sensor.poke_interval)
        with freeze_time(date4):
            self._run(sensor)
        tis = dr.get_task_instances()
        assert len(tis) == 2
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                assert ti.state == State.SUCCESS
            if ti.task_id == DUMMY_OP:
                assert ti.state == State.NONE

    def test_should_include_ready_to_reschedule_dep_in_reschedule_mode(self):
        """Reschedule mode adds ReadyToRescheduleDep to the sensor's deps."""
        sensor = self._make_sensor(True, mode='reschedule')
        deps = sensor.deps
        assert ReadyToRescheduleDep() in deps

    def test_should_not_include_ready_to_reschedule_dep_in_poke_mode(self):
        """Poke mode (the default) does not add ReadyToRescheduleDep."""
        sensor = self._make_sensor(True)
        deps = sensor.deps
        assert ReadyToRescheduleDep() not in deps

    def test_invalid_mode(self):
        """Any mode other than 'poke'/'reschedule' is rejected at build time."""
        with pytest.raises(AirflowException):
            self._make_sensor(return_value=True, mode='foo')

    def test_ok_with_custom_reschedule_exception(self):
        """AirflowRescheduleException lets poke pick the reschedule date."""
        sensor = self._make_sensor(return_value=None, mode='reschedule')
        date1 = timezone.utcnow()
        date2 = date1 + timedelta(seconds=60)
        date3 = date1 + timedelta(seconds=120)
        sensor.poke = Mock(
            side_effect=[
                AirflowRescheduleException(date2),
                AirflowRescheduleException(date3),
                True,
            ]
        )
        dr = self._make_dag_run()

        # first poke raises and task is re-scheduled
        with freeze_time(date1):
            self._run(sensor)
        tis = dr.get_task_instances()
        assert len(tis) == 2
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                # verify task is re-scheduled, i.e. state set to UP_FOR_RESCHEDULE
                assert ti.state == State.UP_FOR_RESCHEDULE
                # verify one row in task_reschedule table
                task_reschedules = TaskReschedule.find_for_task_instance(ti)
                assert len(task_reschedules) == 1
                assert task_reschedules[0].start_date == date1
                assert task_reschedules[0].reschedule_date == date2
            if ti.task_id == DUMMY_OP:
                assert ti.state == State.NONE

        # second poke raises and task is re-scheduled
        with freeze_time(date2):
            self._run(sensor)
        tis = dr.get_task_instances()
        assert len(tis) == 2
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                # verify task is re-scheduled, i.e. state set to UP_FOR_RESCHEDULE
                assert ti.state == State.UP_FOR_RESCHEDULE
                # verify two rows in task_reschedule table
                task_reschedules = TaskReschedule.find_for_task_instance(ti)
                assert len(task_reschedules) == 2
                assert task_reschedules[1].start_date == date2
                assert task_reschedules[1].reschedule_date == date3
            if ti.task_id == DUMMY_OP:
                assert ti.state == State.NONE

        # third poke returns True and task succeeds
        with freeze_time(date3):
            self._run(sensor)
        tis = dr.get_task_instances()
        assert len(tis) == 2
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                assert ti.state == State.SUCCESS
            if ti.task_id == DUMMY_OP:
                assert ti.state == State.NONE

    def test_reschedule_with_test_mode(self):
        """In test mode a reschedule request changes no state and logs no row."""
        sensor = self._make_sensor(return_value=None, poke_interval=10, timeout=25, mode='reschedule')
        sensor.poke = Mock(side_effect=[False])
        dr = self._make_dag_run()

        # poke returns False and AirflowRescheduleException is raised
        date1 = timezone.utcnow()
        with freeze_time(date1):
            for date in self.dag.date_range(DEFAULT_DATE, end_date=DEFAULT_DATE):
                TaskInstance(sensor, date).run(ignore_ti_state=True, test_mode=True)
        tis = dr.get_task_instances()
        assert len(tis) == 2
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                # in test mode state is not modified
                assert ti.state == State.NONE
                # in test mode no reschedule request is recorded
                task_reschedules = TaskReschedule.find_for_task_instance(ti)
                assert len(task_reschedules) == 0
            if ti.task_id == DUMMY_OP:
                assert ti.state == State.NONE

    def test_sensor_with_invalid_poke_interval(self):
        """poke_interval must be a non-negative number."""
        negative_poke_interval = -10
        non_number_poke_interval = "abcd"
        positive_poke_interval = 10
        with pytest.raises(AirflowException):
            self._make_sensor(
                task_id='test_sensor_task_1',
                return_value=None,
                poke_interval=negative_poke_interval,
                timeout=25,
            )

        with pytest.raises(AirflowException):
            self._make_sensor(
                task_id='test_sensor_task_2',
                return_value=None,
                poke_interval=non_number_poke_interval,
                timeout=25,
            )

        self._make_sensor(
            task_id='test_sensor_task_3', return_value=None, poke_interval=positive_poke_interval, timeout=25
        )

    def test_sensor_with_invalid_timeout(self):
        """timeout must be a non-negative number."""
        negative_timeout = -25
        non_number_timeout = "abcd"
        positive_timeout = 25
        with pytest.raises(AirflowException):
            self._make_sensor(
                task_id='test_sensor_task_1', return_value=None, poke_interval=10, timeout=negative_timeout
            )

        with pytest.raises(AirflowException):
            self._make_sensor(
                task_id='test_sensor_task_2', return_value=None, poke_interval=10, timeout=positive_timeout if False else non_number_timeout
            )

        self._make_sensor(
            task_id='test_sensor_task_3', return_value=None, poke_interval=10, timeout=positive_timeout
        )

    def test_sensor_with_exponential_backoff_off(self):
        """Without exponential backoff the next poke interval is constant."""
        sensor = self._make_sensor(return_value=None, poke_interval=5, timeout=60, exponential_backoff=False)

        started_at = timezone.utcnow() - timedelta(seconds=10)

        def run_duration():
            # utcnow must be *called* — without the parentheses this would try
            # to subtract a datetime from the function object (TypeError).
            return (timezone.utcnow() - started_at).total_seconds()

        assert sensor._get_next_poke_interval(started_at, run_duration, 1) == sensor.poke_interval
        assert sensor._get_next_poke_interval(started_at, run_duration, 2) == sensor.poke_interval

    def test_sensor_with_exponential_backoff_on(self):
        """With exponential backoff the interval grows with the try number."""
        sensor = self._make_sensor(return_value=None, poke_interval=5, timeout=60, exponential_backoff=True)

        with patch('airflow.utils.timezone.utcnow') as mock_utctime:
            mock_utctime.return_value = DEFAULT_DATE

            started_at = timezone.utcnow() - timedelta(seconds=10)

            def run_duration():
                # utcnow must be *called*; see test_sensor_with_exponential_backoff_off.
                return (timezone.utcnow() - started_at).total_seconds()

            interval1 = sensor._get_next_poke_interval(started_at, run_duration, 1)
            interval2 = sensor._get_next_poke_interval(started_at, run_duration, 2)

            assert interval1 >= 0
            assert interval1 <= sensor.poke_interval
            assert interval2 >= sensor.poke_interval
            assert interval2 > interval1
| |
| |
@poke_mode_only
class DummyPokeOnlySensor(BaseSensorOperator):
    """Sensor used to exercise the ``poke_mode_only`` decorator.

    ``poke_changes_mode=True`` makes ``poke`` attempt to switch the sensor
    into reschedule mode, which the decorator is expected to forbid.
    """

    def __init__(self, poke_changes_mode=False, **kwargs):
        # ``mode`` is assigned before the base initializer runs so the
        # decorated class sees the attribute during construction.
        self.mode = kwargs['mode']
        super().__init__(**kwargs)
        self.poke_changes_mode = poke_changes_mode
        self.return_value = True

    def poke(self, context):
        # Optionally try to flip into reschedule mode (should raise under
        # the decorator), then report the canned result.
        if self.poke_changes_mode:
            self.change_mode('reschedule')
        return self.return_value

    def change_mode(self, mode):
        # Plain setter; the poke_mode_only guard intercepts non-'poke' values.
        self.mode = mode
| |
| |
class TestPokeModeOnly(unittest.TestCase):
    """Verify that ``poke_mode_only`` pins a sensor class to poke mode."""

    def setUp(self):
        self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True)
        self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
        self.dag = DAG(TEST_DAG_ID, default_args=self.args)

    def test_poke_mode_only_allows_poke_mode(self):
        """Construction, poking and re-setting mode='poke' all succeed."""
        try:
            task = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=False, dag=self.dag)
        except ValueError:
            self.fail("__init__ failed with mode='poke'.")
        try:
            task.poke({})
        except ValueError:
            self.fail("poke failed without changing mode from 'poke'.")
        try:
            task.change_mode('poke')
        except ValueError:
            self.fail("class method failed without changing mode from 'poke'.")

    def test_poke_mode_only_bad_class_method(self):
        """Switching mode away from 'poke' via a method raises ValueError."""
        task = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=False, dag=self.dag)
        with pytest.raises(ValueError):
            task.change_mode('reschedule')

    def test_poke_mode_only_bad_init(self):
        """Constructing the sensor in reschedule mode raises ValueError."""
        with pytest.raises(ValueError):
            DummyPokeOnlySensor(task_id='foo', mode='reschedule', poke_changes_mode=False, dag=self.dag)

    def test_poke_mode_only_bad_poke(self):
        """A poke that tries to change the mode raises ValueError."""
        task = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=True, dag=self.dag)
        with pytest.raises(ValueError):
            task.poke({})