#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import itertools
import logging
import os
import re
import tempfile
import zipfile
from datetime import time, timedelta
from unittest import mock
import pytest
from airflow import settings
from airflow.decorators import task as task_deco
from airflow.exceptions import AirflowException, AirflowSensorTimeout, AirflowSkipException, TaskDeferred
from airflow.models import DagBag, DagRun, TaskInstance
from airflow.models.dag import DAG
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.xcom_arg import XComArg
from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.providers.standard.sensors.external_task import (
ExternalTaskMarker,
ExternalTaskSensor,
)
from airflow.providers.standard.sensors.time import TimeSensor
from airflow.providers.standard.triggers.external_task import WorkflowTrigger
from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.timetables.base import DataInterval
from airflow.utils.hashlib_wrapper import md5
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.task_group import TaskGroup
from airflow.utils.timezone import coerce_datetime, datetime
from airflow.utils.types import DagRunType
from tests.models import TEST_DAGS_FOLDER
from tests_common.test_utils.db import clear_db_runs
from tests_common.test_utils.mock_operators import MockOperator
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
if AIRFLOW_V_3_0_PLUS:
from airflow.utils.types import DagRunTriggeredByType
pytestmark = pytest.mark.db_test
DEFAULT_DATE = datetime(2015, 1, 1)
TEST_DAG_ID = "unit_test_dag"
TEST_TASK_ID = "time_sensor_check"
TEST_TASK_ID_ALTERNATE = "time_sensor_check_alternate"
TEST_TASK_GROUP_ID = "time_sensor_group_id"
DEV_NULL = "/dev/null"
TASK_ID = "external_task_sensor_check"
EXTERNAL_DAG_ID = "child_dag" # DAG the external task sensor is waiting on
EXTERNAL_TASK_ID = "child_task" # Task the external task sensor is waiting on
@pytest.fixture(autouse=True)
def clean_db():
clear_db_runs()
@pytest.fixture
def dag_zip_maker(testing_dag_bundle):
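    """Return a callable that zips the given DAG files; entering it as a context manager yields a DagBag synced to the DB."""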
class DagZipMaker:
def __call__(self, *dag_files):
            self.__dag_files = [os.sep.join([str(TEST_DAGS_FOLDER), dag_file]) for dag_file in dag_files]
            dag_files_hash = md5("".join(self.__dag_files).encode()).hexdigest()
            self.__tmp_dir = os.sep.join([tempfile.gettempdir(), dag_files_hash])
self.__zip_file_name = os.sep.join([self.__tmp_dir, f"{dag_files_hash}.zip"])
if not os.path.exists(self.__tmp_dir):
os.mkdir(self.__tmp_dir)
return self
def __enter__(self):
with zipfile.ZipFile(self.__zip_file_name, "x") as zf:
for dag_file in self.__dag_files:
zf.write(dag_file, os.path.basename(dag_file))
dagbag = DagBag(dag_folder=self.__tmp_dir, include_examples=False)
dagbag.sync_to_db("testing", None)
return dagbag
def __exit__(self, exc_type, exc_val, exc_tb):
os.unlink(self.__zip_file_name)
os.rmdir(self.__tmp_dir)
return DagZipMaker()
@pytest.mark.usefixtures("testing_dag_bundle")
class TestExternalTaskSensor:
def setup_method(self):
self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True)
self.args = {"owner": "airflow", "start_date": DEFAULT_DATE}
self.dag = DAG(TEST_DAG_ID, schedule=None, default_args=self.args)
self.dag_run_id = DagRunType.MANUAL.generate_run_id(suffix=DEFAULT_DATE.isoformat())
def add_time_sensor(self, task_id=TEST_TASK_ID):
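        """Run a TimeSensor targeting midnight (succeeds immediately) so TEST_DAG_ID has a successful TI to poke."""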
op = TimeSensor(task_id=task_id, target_time=time(0), dag=self.dag)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def add_fake_task_group(self, target_states=None):
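        """Create a task group in self.dag with one EmptyOperator per target state, then move each TI to that state."""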
target_states = [State.SUCCESS] * 2 if target_states is None else target_states
with self.dag as dag:
with TaskGroup(group_id=TEST_TASK_GROUP_ID) as task_group:
_ = [EmptyOperator(task_id=f"task{i}") for i in range(len(target_states))]
dag.sync_to_db()
SerializedDagModel.write_dag(dag, bundle_name="test_bundle")
for idx, task in enumerate(task_group):
ti = TaskInstance(task=task, run_id=self.dag_run_id)
ti.run(ignore_ti_state=True, mark_success=True)
ti.set_state(target_states[idx])
def add_fake_task_group_with_dynamic_tasks(self, target_state=State.SUCCESS):
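        """Create a task group with a plain task plus a mapped task (5 map indexes), all driven to target_state."""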
map_indexes = range(5)
with self.dag as dag:
with TaskGroup(group_id=TEST_TASK_GROUP_ID) as task_group:
@task_deco
def fake_task():
pass
@task_deco
def fake_mapped_task(x: int):
return x
fake_task()
fake_mapped_task.expand(x=list(map_indexes))
dag.sync_to_db()
SerializedDagModel.write_dag(dag, bundle_name="test_bundle")
for task in task_group:
if task.task_id == "fake_mapped_task":
for map_index in map_indexes:
ti = TaskInstance(task=task, run_id=self.dag_run_id, map_index=map_index)
ti.run(ignore_ti_state=True, mark_success=True)
ti.set_state(target_state)
else:
ti = TaskInstance(task=task, run_id=self.dag_run_id)
ti.run(ignore_ti_state=True, mark_success=True)
ti.set_state(target_state)
def test_external_task_sensor(self):
self.add_time_sensor()
op = ExternalTaskSensor(
task_id="test_external_task_sensor_check",
external_dag_id=TEST_DAG_ID,
external_task_id=TEST_TASK_ID,
dag=self.dag,
)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_sensor_multiple_task_ids(self):
self.add_time_sensor(task_id=TEST_TASK_ID)
self.add_time_sensor(task_id=TEST_TASK_ID_ALTERNATE)
op = ExternalTaskSensor(
task_id="test_external_task_sensor_check_task_ids",
external_dag_id=TEST_DAG_ID,
external_task_ids=[TEST_TASK_ID, TEST_TASK_ID_ALTERNATE],
dag=self.dag,
)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_sensor_with_task_group(self):
self.add_time_sensor()
self.add_fake_task_group()
op = ExternalTaskSensor(
task_id="test_external_task_sensor_task_group",
external_dag_id=TEST_DAG_ID,
external_task_group_id=TEST_TASK_GROUP_ID,
dag=self.dag,
)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_raise_with_external_task_sensor_task_id_and_task_ids(self):
with pytest.raises(ValueError) as ctx:
ExternalTaskSensor(
task_id="test_external_task_sensor_task_id_with_task_ids_failed_status",
external_dag_id=TEST_DAG_ID,
external_task_id=TEST_TASK_ID,
external_task_ids=TEST_TASK_ID,
dag=self.dag,
)
assert (
str(ctx.value) == "Only one of `external_task_id` or `external_task_ids` may "
"be provided to ExternalTaskSensor; "
"use external_task_id or external_task_ids or external_task_group_id."
)
def test_raise_with_external_task_sensor_task_group_and_task_id(self):
with pytest.raises(ValueError) as ctx:
ExternalTaskSensor(
task_id="test_external_task_sensor_task_group_with_task_id_failed_status",
external_dag_id=TEST_DAG_ID,
external_task_id=TEST_TASK_ID,
external_task_group_id=TEST_TASK_GROUP_ID,
dag=self.dag,
)
assert (
str(ctx.value) == "Only one of `external_task_group_id` or `external_task_ids` may "
"be provided to ExternalTaskSensor; "
"use external_task_id or external_task_ids or external_task_group_id."
)
def test_raise_with_external_task_sensor_task_group_and_task_ids(self):
with pytest.raises(ValueError) as ctx:
ExternalTaskSensor(
task_id="test_external_task_sensor_task_group_with_task_ids_failed_status",
external_dag_id=TEST_DAG_ID,
external_task_ids=TEST_TASK_ID,
external_task_group_id=TEST_TASK_GROUP_ID,
dag=self.dag,
)
assert (
str(ctx.value) == "Only one of `external_task_group_id` or `external_task_ids` may "
"be provided to ExternalTaskSensor; "
"use external_task_id or external_task_ids or external_task_group_id."
)
    # By default (check_existence=False), if the task group doesn't exist the sensor keeps poking until
    # it times out; this matches the behaviour when external_task_id doesn't exist.
def test_external_task_group_not_exists_without_check_existence(self):
self.add_time_sensor()
self.add_fake_task_group()
op = ExternalTaskSensor(
task_id="test_external_task_sensor_check",
external_dag_id=TEST_DAG_ID,
external_task_group_id="fake-task-group",
timeout=0.001,
dag=self.dag,
poke_interval=0.1,
)
with pytest.raises(AirflowException, match="Sensor has timed out"):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_group_sensor_success(self):
self.add_time_sensor()
self.add_fake_task_group()
op = ExternalTaskSensor(
task_id="test_external_task_sensor_check",
external_dag_id=TEST_DAG_ID,
external_task_group_id=TEST_TASK_GROUP_ID,
failed_states=[State.FAILED],
dag=self.dag,
)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_group_sensor_failed_states(self):
ti_states = [State.FAILED, State.FAILED]
self.add_time_sensor()
self.add_fake_task_group(ti_states)
op = ExternalTaskSensor(
task_id="test_external_task_sensor_check",
external_dag_id=TEST_DAG_ID,
external_task_group_id=TEST_TASK_GROUP_ID,
failed_states=[State.FAILED],
dag=self.dag,
)
with pytest.raises(
AirflowException,
match=f"The external task_group '{TEST_TASK_GROUP_ID}' in DAG '{TEST_DAG_ID}' failed.",
):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_catch_overlap_allowed_failed_state(self):
with pytest.raises(AirflowException):
ExternalTaskSensor(
task_id="test_external_task_sensor_check",
external_dag_id=TEST_DAG_ID,
external_task_id=TEST_TASK_ID,
allowed_states=[State.SUCCESS],
failed_states=[State.SUCCESS],
dag=self.dag,
)
def test_external_task_sensor_wrong_failed_states(self):
with pytest.raises(ValueError):
ExternalTaskSensor(
task_id="test_external_task_sensor_check",
external_dag_id=TEST_DAG_ID,
external_task_id=TEST_TASK_ID,
failed_states=["invalid_state"],
dag=self.dag,
)
def test_external_task_sensor_failed_states(self):
self.add_time_sensor()
op = ExternalTaskSensor(
task_id="test_external_task_sensor_check",
external_dag_id=TEST_DAG_ID,
external_task_id=TEST_TASK_ID,
failed_states=["failed"],
dag=self.dag,
)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_sensor_failed_states_as_success(self, caplog):
self.add_time_sensor()
op = ExternalTaskSensor(
task_id="test_external_task_sensor_check",
external_dag_id=TEST_DAG_ID,
external_task_id=TEST_TASK_ID,
allowed_states=["failed"],
failed_states=["success"],
dag=self.dag,
)
error_message = rf"Some of the external tasks \['{TEST_TASK_ID}'\] in DAG {TEST_DAG_ID} failed\."
with caplog.at_level(logging.INFO, logger=op.log.name):
caplog.clear()
with pytest.raises(AirflowException, match=error_message):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
assert (
f"Poking for tasks ['{TEST_TASK_ID}'] in dag {TEST_DAG_ID} on {DEFAULT_DATE.isoformat()} ... "
) in caplog.messages
def test_external_task_sensor_soft_fail_failed_states_as_skipped(self):
self.add_time_sensor()
op = ExternalTaskSensor(
task_id="test_external_task_sensor_check",
external_dag_id=TEST_DAG_ID,
external_task_id=TEST_TASK_ID,
allowed_states=[State.FAILED],
failed_states=[State.SUCCESS],
soft_fail=True,
dag=self.dag,
)
# when
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
# then
session = settings.Session()
TI = TaskInstance
task_instances: list[TI] = session.query(TI).filter(TI.task_id == op.task_id).all()
assert len(task_instances) == 1, "Unexpected number of task instances"
assert task_instances[0].state == State.SKIPPED, "Unexpected external task state"
def test_external_task_sensor_skipped_states_as_skipped(self, session):
self.add_time_sensor()
op = ExternalTaskSensor(
task_id="test_external_task_sensor_check",
external_dag_id=TEST_DAG_ID,
external_task_id=TEST_TASK_ID,
allowed_states=[State.FAILED],
skipped_states=[State.SUCCESS],
dag=self.dag,
)
# when
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
# then
TI = TaskInstance
task_instances: list[TI] = session.query(TI).filter(TI.task_id == op.task_id).all()
assert len(task_instances) == 1, "Unexpected number of task instances"
assert task_instances[0].state == State.SKIPPED, "Unexpected external task state"
def test_external_task_sensor_external_task_id_param(self, caplog):
"""Test external_task_ids is set properly when external_task_id is passed as a template"""
self.add_time_sensor()
op = ExternalTaskSensor(
task_id="test_external_task_sensor_check",
external_dag_id="{{ params.dag_id }}",
external_task_id="{{ params.task_id }}",
params={"dag_id": TEST_DAG_ID, "task_id": TEST_TASK_ID},
dag=self.dag,
)
with caplog.at_level(logging.INFO, logger=op.log.name):
caplog.clear()
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
assert (
f"Poking for tasks ['{TEST_TASK_ID}'] "
f"in dag {TEST_DAG_ID} on {DEFAULT_DATE.isoformat()} ... "
) in caplog.messages
def test_external_task_sensor_external_task_ids_param(self, caplog):
"""Test external_task_ids rendering when a template is passed."""
self.add_time_sensor()
op = ExternalTaskSensor(
task_id="test_external_task_sensor_check",
external_dag_id="{{ params.dag_id }}",
external_task_ids=["{{ params.task_id }}"],
params={"dag_id": TEST_DAG_ID, "task_id": TEST_TASK_ID},
dag=self.dag,
)
with caplog.at_level(logging.INFO, logger=op.log.name):
caplog.clear()
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
assert (
f"Poking for tasks ['{TEST_TASK_ID}'] "
f"in dag {TEST_DAG_ID} on {DEFAULT_DATE.isoformat()} ... "
) in caplog.messages
    def test_external_task_sensor_failed_states_as_success_multiple_task_ids(self, caplog):
self.add_time_sensor(task_id=TEST_TASK_ID)
self.add_time_sensor(task_id=TEST_TASK_ID_ALTERNATE)
op = ExternalTaskSensor(
task_id="test_external_task_sensor_check_task_ids",
external_dag_id=TEST_DAG_ID,
external_task_ids=[TEST_TASK_ID, TEST_TASK_ID_ALTERNATE],
allowed_states=["failed"],
failed_states=["success"],
dag=self.dag,
)
error_message = (
rf"Some of the external tasks \['{TEST_TASK_ID}'\, \'{TEST_TASK_ID_ALTERNATE}\'] "
rf"in DAG {TEST_DAG_ID} failed\."
)
with caplog.at_level(logging.INFO, logger=op.log.name):
caplog.clear()
with pytest.raises(AirflowException, match=error_message):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
assert (
f"Poking for tasks ['{TEST_TASK_ID}', '{TEST_TASK_ID_ALTERNATE}'] "
f"in dag unit_test_dag on {DEFAULT_DATE.isoformat()} ... "
) in caplog.messages
def test_external_dag_sensor(self, dag_maker):
with dag_maker("other_dag", default_args=self.args, end_date=DEFAULT_DATE, schedule="@once"):
pass
dag_maker.create_dagrun(state=DagRunState.SUCCESS)
op = ExternalTaskSensor(
task_id="test_external_dag_sensor_check",
external_dag_id="other_dag",
external_task_id=None,
dag=self.dag,
)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_dag_sensor_log(self, caplog, dag_maker):
with dag_maker("other_dag", default_args=self.args, end_date=DEFAULT_DATE, schedule="@once"):
pass
dag_maker.create_dagrun(state=DagRunState.SUCCESS)
op = ExternalTaskSensor(
task_id="test_external_dag_sensor_check",
external_dag_id="other_dag",
dag=self.dag,
)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
assert (f"Poking for DAG 'other_dag' on {DEFAULT_DATE.isoformat()} ... ") in caplog.messages
def test_external_dag_sensor_soft_fail_as_skipped(self, dag_maker, session):
with dag_maker("other_dag", default_args=self.args, end_date=DEFAULT_DATE, schedule="@once"):
pass
dag_maker.create_dagrun(state=DagRunState.SUCCESS)
op = ExternalTaskSensor(
task_id="test_external_dag_sensor_check",
external_dag_id="other_dag",
external_task_id=None,
allowed_states=[State.FAILED],
failed_states=[State.SUCCESS],
soft_fail=True,
dag=self.dag,
)
# when
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
# then
TI = TaskInstance
task_instances: list[TI] = session.query(TI).filter(TI.task_id == op.task_id).all()
assert len(task_instances) == 1, "Unexpected number of task instances"
assert task_instances[0].state == State.SKIPPED, "Unexpected external task state"
def test_external_task_sensor_fn_multiple_logical_dates(self):
bash_command_code = """
{% set s=logical_date.time().second %}
echo "second is {{ s }}"
if [[ $(( {{ s }} % 60 )) == 1 ]]
then
exit 1
fi
exit 0
"""
dag_external_id = TEST_DAG_ID + "_external"
dag_external = DAG(dag_external_id, default_args=self.args, schedule=timedelta(seconds=1))
task_external_with_failure = BashOperator(
task_id="task_external_with_failure", bash_command=bash_command_code, retries=0, dag=dag_external
)
task_external_without_failure = EmptyOperator(
task_id="task_external_without_failure", retries=0, dag=dag_external
)
task_external_without_failure.run(
start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + timedelta(seconds=1), ignore_ti_state=True
)
session = settings.Session()
TI = TaskInstance
try:
task_external_with_failure.run(
start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + timedelta(seconds=1), ignore_ti_state=True
)
            # The task_external_with_failure task is expected to fail
            # once per minute (the run on the first second of
            # each minute).
except Exception as e:
failed_tis = (
session.query(TI)
.filter(
TI.dag_id == dag_external_id,
TI.state == State.FAILED,
TI.logical_date == DEFAULT_DATE + timedelta(seconds=1),
)
.all()
)
if len(failed_tis) == 1 and failed_tis[0].task_id == "task_external_with_failure":
pass
else:
raise e
dag_id = TEST_DAG_ID
dag = DAG(dag_id, default_args=self.args, schedule=timedelta(minutes=1))
task_without_failure = ExternalTaskSensor(
task_id="task_without_failure",
external_dag_id=dag_external_id,
external_task_id="task_external_without_failure",
execution_date_fn=lambda dt: [dt + timedelta(seconds=i) for i in range(2)],
allowed_states=["success"],
retries=0,
timeout=1,
poke_interval=1,
dag=dag,
)
task_with_failure = ExternalTaskSensor(
task_id="task_with_failure",
external_dag_id=dag_external_id,
external_task_id="task_external_with_failure",
execution_date_fn=lambda dt: [dt + timedelta(seconds=i) for i in range(2)],
allowed_states=["success"],
retries=0,
timeout=1,
poke_interval=1,
dag=dag,
)
task_without_failure.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
with pytest.raises(AirflowSensorTimeout):
task_with_failure.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
# Test to ensure that if one task in a chain of tasks fails, the
# ExternalTaskSensor will also report a failure and return without
# waiting for a timeout.
task_chain_with_failure = ExternalTaskSensor(
task_id="task_chain_with_failure",
external_dag_id=dag_external_id,
external_task_id="task_external_with_failure",
execution_date_fn=lambda dt: [dt + timedelta(seconds=i) for i in range(3)],
allowed_states=["success"],
failed_states=["failed"],
retries=0,
timeout=5,
poke_interval=1,
dag=dag,
)
# We need to test for an AirflowException explicitly since
# AirflowSensorTimeout is a subclass that will be raised if this does
# not execute properly.
with pytest.raises(AirflowException) as ex_ctx:
task_chain_with_failure.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
assert type(ex_ctx.value) is AirflowException
def test_external_task_sensor_delta(self):
self.add_time_sensor()
op = ExternalTaskSensor(
task_id="test_external_task_sensor_check_delta",
external_dag_id=TEST_DAG_ID,
external_task_id=TEST_TASK_ID,
execution_delta=timedelta(0),
allowed_states=["success"],
dag=self.dag,
)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_sensor_fn(self):
self.add_time_sensor()
# check that the execution_fn works
op1 = ExternalTaskSensor(
task_id="test_external_task_sensor_check_delta_1",
external_dag_id=TEST_DAG_ID,
external_task_id=TEST_TASK_ID,
execution_date_fn=lambda dt: dt + timedelta(0),
allowed_states=["success"],
dag=self.dag,
)
op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
        # double-check that execution_date_fn is actually applied by shifting the date so the sensor times out
op2 = ExternalTaskSensor(
task_id="test_external_task_sensor_check_delta_2",
external_dag_id=TEST_DAG_ID,
external_task_id=TEST_TASK_ID,
execution_date_fn=lambda dt: dt + timedelta(days=1),
allowed_states=["success"],
timeout=1,
poke_interval=1,
dag=self.dag,
)
with pytest.raises(AirflowSensorTimeout):
op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_sensor_fn_multiple_args(self):
"""Check this task sensor passes multiple args with full context. If no failure, means clean run."""
self.add_time_sensor()
def my_func(dt, context):
assert context["logical_date"] == dt
return dt + timedelta(0)
op1 = ExternalTaskSensor(
task_id="test_external_task_sensor_multiple_arg_fn",
external_dag_id=TEST_DAG_ID,
external_task_id=TEST_TASK_ID,
execution_date_fn=my_func,
allowed_states=["success"],
dag=self.dag,
)
op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_sensor_fn_kwargs(self):
"""Check this task sensor passes multiple args with full context. If no failure, means clean run."""
self.add_time_sensor()
def my_func(dt, ds_nodash):
assert ds_nodash == dt.strftime("%Y%m%d")
return dt + timedelta(0)
op1 = ExternalTaskSensor(
task_id="test_external_task_sensor_fn_kwargs",
external_dag_id=TEST_DAG_ID,
external_task_id=TEST_TASK_ID,
execution_date_fn=my_func,
allowed_states=["success"],
dag=self.dag,
)
op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_sensor_error_delta_and_fn(self):
self.add_time_sensor()
# Test that providing execution_delta and a function raises an error
with pytest.raises(ValueError):
ExternalTaskSensor(
task_id="test_external_task_sensor_check_delta",
external_dag_id=TEST_DAG_ID,
external_task_id=TEST_TASK_ID,
execution_delta=timedelta(0),
execution_date_fn=lambda dt: dt,
allowed_states=["success"],
dag=self.dag,
)
def test_external_task_sensor_error_task_id_and_task_ids(self):
self.add_time_sensor()
        # Test that providing both external_task_id and external_task_ids raises an error
with pytest.raises(ValueError):
ExternalTaskSensor(
task_id="test_external_task_sensor_task_id_and_task_ids",
external_dag_id=TEST_DAG_ID,
external_task_id=TEST_TASK_ID,
external_task_ids=[TEST_TASK_ID],
allowed_states=["success"],
dag=self.dag,
)
def test_external_task_sensor_with_xcom_arg_does_not_fail_on_init(self):
self.add_time_sensor()
op1 = MockOperator(task_id="op1", dag=self.dag)
op2 = ExternalTaskSensor(
task_id="test_external_task_sensor_with_xcom_arg_does_not_fail_on_init",
external_dag_id=TEST_DAG_ID,
external_task_ids=XComArg(op1),
allowed_states=["success"],
dag=self.dag,
)
assert isinstance(op2.external_task_ids, XComArg)
def test_catch_duplicate_task_ids(self):
self.add_time_sensor()
op1 = ExternalTaskSensor(
task_id="test_external_task_duplicate_task_ids",
external_dag_id=TEST_DAG_ID,
external_task_ids=[TEST_TASK_ID, TEST_TASK_ID],
allowed_states=["success"],
dag=self.dag,
)
with pytest.raises(ValueError):
op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
def test_catch_duplicate_task_ids_with_xcom_arg(self):
self.add_time_sensor()
op1 = PythonOperator(
python_callable=lambda: ["dupe_value", "dupe_value"],
task_id="op1",
do_xcom_push=True,
dag=self.dag,
)
op2 = ExternalTaskSensor(
task_id="test_external_task_duplicate_task_ids_with_xcom_arg",
external_dag_id=TEST_DAG_ID,
external_task_ids=XComArg(op1),
allowed_states=["success"],
dag=self.dag,
)
op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
with pytest.raises(ValueError):
op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
def test_catch_duplicate_task_ids_with_multiple_xcom_args(self):
self.add_time_sensor()
op1 = PythonOperator(
python_callable=lambda: "value",
task_id="op1",
do_xcom_push=True,
dag=self.dag,
)
op2 = ExternalTaskSensor(
task_id="test_external_task_duplicate_task_ids_with_xcom_arg",
external_dag_id=TEST_DAG_ID,
external_task_ids=[XComArg(op1), XComArg(op1)],
allowed_states=["success"],
dag=self.dag,
)
op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
with pytest.raises(ValueError):
op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
def test_catch_invalid_allowed_states(self):
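        # Invalid states must be rejected both when waiting on a specific task and when waiting on a whole DAG.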
with pytest.raises(ValueError):
ExternalTaskSensor(
task_id="test_external_task_sensor_check_1",
external_dag_id=TEST_DAG_ID,
external_task_id=TEST_TASK_ID,
allowed_states=["invalid_state"],
dag=self.dag,
)
with pytest.raises(ValueError):
ExternalTaskSensor(
task_id="test_external_task_sensor_check_2",
external_dag_id=TEST_DAG_ID,
external_task_id=None,
allowed_states=["invalid_state"],
dag=self.dag,
)
def test_external_task_sensor_waits_for_task_check_existence(self):
op = ExternalTaskSensor(
task_id="test_external_task_sensor_check",
external_dag_id="example_bash_operator",
external_task_id="non-existing-task",
check_existence=True,
dag=self.dag,
)
with pytest.raises(AirflowException):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_sensor_waits_for_dag_check_existence(self):
op = ExternalTaskSensor(
task_id="test_external_task_sensor_check",
external_dag_id="non-existing-dag",
external_task_id=None,
check_existence=True,
dag=self.dag,
)
with pytest.raises(AirflowException):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_group_with_mapped_tasks_sensor_success(self):
self.add_time_sensor()
self.add_fake_task_group_with_dynamic_tasks()
op = ExternalTaskSensor(
task_id="test_external_task_sensor_check",
external_dag_id=TEST_DAG_ID,
external_task_group_id=TEST_TASK_GROUP_ID,
failed_states=[State.FAILED],
dag=self.dag,
)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_group_with_mapped_tasks_failed_states(self):
self.add_time_sensor()
self.add_fake_task_group_with_dynamic_tasks(State.FAILED)
op = ExternalTaskSensor(
task_id="test_external_task_sensor_check",
external_dag_id=TEST_DAG_ID,
external_task_group_id=TEST_TASK_GROUP_ID,
failed_states=[State.FAILED],
dag=self.dag,
)
with pytest.raises(
AirflowException,
match=f"The external task_group '{TEST_TASK_GROUP_ID}' in DAG '{TEST_DAG_ID}' failed.",
):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
    def test_external_task_group_when_there_are_no_tis(self):
        """Test that the sensor keeps poking (and eventually times out) rather than erroring when there are no TIs to check."""
self.add_time_sensor()
self.add_fake_task_group_with_dynamic_tasks(State.FAILED)
op = ExternalTaskSensor(
task_id="test_external_task_sensor_check",
external_dag_id=TEST_DAG_ID,
external_task_group_id=TEST_TASK_GROUP_ID,
failed_states=[State.FAILED],
dag=self.dag,
poke_interval=1,
timeout=3,
)
with pytest.raises(AirflowSensorTimeout):
op.run(
start_date=DEFAULT_DATE + timedelta(hours=1),
end_date=DEFAULT_DATE + timedelta(hours=1),
ignore_ti_state=True,
)
@pytest.mark.parametrize(
"kwargs, expected_message",
(
(
{
"external_task_ids": [TEST_TASK_ID, TEST_TASK_ID_ALTERNATE],
"failed_states": [State.FAILED],
},
f"Some of the external tasks {re.escape(str([TEST_TASK_ID, TEST_TASK_ID_ALTERNATE]))}"
f" in DAG {TEST_DAG_ID} failed.",
),
(
{
"external_task_group_id": [TEST_TASK_ID, TEST_TASK_ID_ALTERNATE],
"failed_states": [State.FAILED],
},
f"The external task_group '{re.escape(str([TEST_TASK_ID, TEST_TASK_ID_ALTERNATE]))}'"
f" in DAG '{TEST_DAG_ID}' failed.",
),
(
{"failed_states": [State.FAILED]},
f"The external DAG {TEST_DAG_ID} failed.",
),
),
)
@pytest.mark.parametrize(
"soft_fail, expected_exception",
(
(
False,
AirflowException,
),
(
True,
AirflowSkipException,
),
),
)
@mock.patch("airflow.providers.standard.sensors.external_task.ExternalTaskSensor.get_count")
@mock.patch("airflow.providers.standard.sensors.external_task.ExternalTaskSensor._get_dttm_filter")
def test_fail_poke(
self, _get_dttm_filter, get_count, soft_fail, expected_exception, kwargs, expected_message
):
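        """Each target (task ids, task group, whole DAG) should raise its specific failure message; soft_fail turns the failure into a skip."""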
_get_dttm_filter.return_value = []
get_count.return_value = 1
op = ExternalTaskSensor(
task_id="test_external_task_duplicate_task_ids",
external_dag_id=TEST_DAG_ID,
allowed_states=["success"],
dag=self.dag,
soft_fail=soft_fail,
deferrable=False,
**kwargs,
)
with pytest.raises(expected_exception, match=expected_message):
op.execute(context={})
@pytest.mark.parametrize(
"response_get_current, response_exists, kwargs, expected_message",
(
(None, None, {}, f"The external DAG {TEST_DAG_ID} does not exist."),
(
DAG(dag_id="test", schedule=None),
False,
{},
f"The external DAG {TEST_DAG_ID} was deleted.",
),
(
DAG(dag_id="test", schedule=None),
True,
{"external_task_ids": [TEST_TASK_ID, TEST_TASK_ID_ALTERNATE]},
f"The external task {TEST_TASK_ID} in DAG {TEST_DAG_ID} does not exist.",
),
(
DAG(dag_id="test", schedule=None),
True,
{"external_task_group_id": [TEST_TASK_ID, TEST_TASK_ID_ALTERNATE]},
f"The external task group '{re.escape(str([TEST_TASK_ID, TEST_TASK_ID_ALTERNATE]))}'"
f" in DAG '{TEST_DAG_ID}' does not exist.",
),
),
)
@pytest.mark.parametrize(
"soft_fail, expected_exception",
(
(
False,
AirflowException,
),
(
True,
AirflowException,
),
),
)
@mock.patch("airflow.providers.standard.sensors.external_task.ExternalTaskSensor._get_dttm_filter")
@mock.patch("airflow.models.dagbag.DagBag.get_dag")
@mock.patch("os.path.exists")
@mock.patch("airflow.models.dag.DagModel.get_current")
def test_fail__check_for_existence(
self,
get_current,
exists,
get_dag,
_get_dttm_filter,
soft_fail,
expected_exception,
response_get_current,
response_exists,
kwargs,
expected_message,
):
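        """check_existence=True should fail fast with a precise message when the external DAG, task, or task group is missing."""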
_get_dttm_filter.return_value = []
get_current.return_value = response_get_current
exists.return_value = response_exists
get_dag_response = mock.MagicMock()
get_dag.return_value = get_dag_response
get_dag_response.has_task.return_value = False
get_dag_response.has_task_group.return_value = False
op = ExternalTaskSensor(
task_id="test_external_task_duplicate_task_ids",
external_dag_id=TEST_DAG_ID,
allowed_states=["success"],
dag=self.dag,
soft_fail=soft_fail,
check_existence=True,
**kwargs,
)
with pytest.raises(expected_exception, match=expected_message):
op.execute(context={})
class TestExternalTaskAsyncSensor:
TASK_ID = "external_task_sensor_check"
EXTERNAL_DAG_ID = "child_dag" # DAG the external task sensor is waiting on
EXTERNAL_TASK_ID = "child_task" # Task the external task sensor is waiting on
def test_defer_and_fire_task_state_trigger(self):
"""
Asserts that a task is deferred and TaskStateTrigger will be fired
when the ExternalTaskAsyncSensor is provided with all required arguments
(i.e. including the external_task_id).
"""
sensor = ExternalTaskSensor(
task_id=TASK_ID,
external_task_id=EXTERNAL_TASK_ID,
external_dag_id=EXTERNAL_DAG_ID,
deferrable=True,
)
with pytest.raises(TaskDeferred) as exc:
sensor.execute(context=mock.MagicMock())
assert isinstance(exc.value.trigger, WorkflowTrigger), "Trigger is not a WorkflowTrigger"
def test_defer_and_fire_failed_state_trigger(self):
"""Tests that an AirflowException is raised in case of error event"""
sensor = ExternalTaskSensor(
task_id=TASK_ID,
external_task_id=EXTERNAL_TASK_ID,
external_dag_id=EXTERNAL_DAG_ID,
deferrable=True,
)
with pytest.raises(AirflowException):
sensor.execute_complete(
context=mock.MagicMock(), event={"status": "error", "message": "test failure message"}
)
def test_defer_and_fire_timeout_state_trigger(self):
"""Tests that an AirflowException is raised in case of timeout event"""
sensor = ExternalTaskSensor(
task_id=TASK_ID,
external_task_id=EXTERNAL_TASK_ID,
external_dag_id=EXTERNAL_DAG_ID,
deferrable=True,
)
with pytest.raises(AirflowException):
sensor.execute_complete(
context=mock.MagicMock(),
event={"status": "timeout", "message": "Dag was not started within 1 minute, assuming fail."},
)
def test_defer_execute_check_correct_logging(self):
"""Asserts that logging occurs as expected"""
sensor = ExternalTaskSensor(
task_id=TASK_ID,
external_task_id=EXTERNAL_TASK_ID,
external_dag_id=EXTERNAL_DAG_ID,
deferrable=True,
)
with mock.patch.object(sensor.log, "info") as mock_log_info:
sensor.execute_complete(
context=mock.MagicMock(),
event={"status": "success"},
)
mock_log_info.assert_called_with("External tasks %s has executed successfully.", [EXTERNAL_TASK_ID])
def test_external_task_sensor_check_zipped_dag_existence(dag_zip_maker):
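    # check_existence must also work when the external DAG lives inside a zipped DAG package.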
with dag_zip_maker("test_external_task_sensor_check_existense.py") as dagbag:
with create_session() as session:
dag = dagbag.dags["test_external_task_sensor_check_existence"]
op = dag.tasks[0]
op._check_for_existence(session)
@pytest.mark.parametrize(
argnames=["external_dag_id", "external_task_id", "expected_external_dag_id", "expected_external_task_id"],
argvalues=[
("dag_test", "task_test", "dag_test", "task_test"),
("dag_{{ ds }}", "task_{{ ds }}", f"dag_{DEFAULT_DATE.date()}", f"task_{DEFAULT_DATE.date()}"),
],
ids=["not_templated", "templated"],
)
def test_external_task_sensor_extra_link(
external_dag_id,
external_task_id,
expected_external_dag_id,
expected_external_task_id,
create_task_instance_of_operator,
):
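    """Both literal and templated external ids should render into external_task_ids and the extra-link URL."""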
ti = create_task_instance_of_operator(
ExternalTaskSensor,
dag_id="external_task_sensor_extra_links_dag",
logical_date=DEFAULT_DATE,
task_id="external_task_sensor_extra_links_task",
external_dag_id=external_dag_id,
external_task_id=external_task_id,
)
ti.render_templates()
assert ti.task.external_dag_id == expected_external_dag_id
assert ti.task.external_task_id == expected_external_task_id
assert ti.task.external_task_ids == [expected_external_task_id]
url = ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key)
assert f"/dags/{expected_external_dag_id}/runs" in url
class TestExternalTaskMarker:
def test_serialized_fields(self):
assert {"recursion_depth"}.issubset(ExternalTaskMarker.get_serialized_fields())
def test_serialized_external_task_marker(self):
dag = DAG("test_serialized_external_task_marker", schedule=None, start_date=DEFAULT_DATE)
task = ExternalTaskMarker(
task_id="parent_task",
external_dag_id="external_task_marker_child",
external_task_id="child_task1",
dag=dag,
)
serialized_op = SerializedBaseOperator.serialize_operator(task)
deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op)
assert deserialized_op.task_type == "ExternalTaskMarker"
assert getattr(deserialized_op, "external_dag_id") == "external_task_marker_child"
assert getattr(deserialized_op, "external_task_id") == "child_task1"
@pytest.fixture
def dag_bag_ext():
"""
Create a DagBag with DAGs looking like this. The dotted lines represent external dependencies
set up using ExternalTaskMarker and ExternalTaskSensor.
dag_0: task_a_0 >> task_b_0
|
|
dag_1: ---> task_a_1 >> task_b_1
|
|
dag_2: ---> task_a_2 >> task_b_2
|
|
dag_3: ---> task_a_3 >> task_b_3
"""
clear_db_runs()
dag_bag = DagBag(dag_folder=DEV_NULL, include_examples=False)
dag_0 = DAG("dag_0", start_date=DEFAULT_DATE, schedule=None)
task_a_0 = EmptyOperator(task_id="task_a_0", dag=dag_0)
task_b_0 = ExternalTaskMarker(
task_id="task_b_0", external_dag_id="dag_1", external_task_id="task_a_1", recursion_depth=3, dag=dag_0
)
task_a_0 >> task_b_0
dag_1 = DAG("dag_1", start_date=DEFAULT_DATE, schedule=None)
task_a_1 = ExternalTaskSensor(
task_id="task_a_1", external_dag_id=dag_0.dag_id, external_task_id=task_b_0.task_id, dag=dag_1
)
task_b_1 = ExternalTaskMarker(
task_id="task_b_1", external_dag_id="dag_2", external_task_id="task_a_2", recursion_depth=2, dag=dag_1
)
task_a_1 >> task_b_1
dag_2 = DAG("dag_2", start_date=DEFAULT_DATE, schedule=None)
task_a_2 = ExternalTaskSensor(
task_id="task_a_2", external_dag_id=dag_1.dag_id, external_task_id=task_b_1.task_id, dag=dag_2
)
task_b_2 = ExternalTaskMarker(
task_id="task_b_2", external_dag_id="dag_3", external_task_id="task_a_3", recursion_depth=1, dag=dag_2
)
task_a_2 >> task_b_2
dag_3 = DAG("dag_3", start_date=DEFAULT_DATE, schedule=None)
task_a_3 = ExternalTaskSensor(
task_id="task_a_3", external_dag_id=dag_2.dag_id, external_task_id=task_b_2.task_id, dag=dag_3
)
task_b_3 = EmptyOperator(task_id="task_b_3", dag=dag_3)
task_a_3 >> task_b_3
for dag in [dag_0, dag_1, dag_2, dag_3]:
dag_bag.bag_dag(dag=dag)
yield dag_bag
clear_db_runs()
@pytest.fixture
def dag_bag_parent_child():
"""
Create a DagBag with two DAGs looking like this. task_1 of child_dag_1 on day 1 depends on
task_0 of parent_dag_0 on day 1. Therefore, when task_0 of parent_dag_0 on day 1 and day 2
are cleared, parent_dag_0 DagRuns need to be set to running on both days, but child_dag_1
only needs to be set to running on day 1.
day 1 day 2
parent_dag_0 task_0 task_0
|
|
v
child_dag_1 task_1 task_1
"""
clear_db_runs()
dag_bag = DagBag(dag_folder=DEV_NULL, include_examples=False)
day_1 = DEFAULT_DATE
with DAG("parent_dag_0", start_date=day_1, schedule=None) as dag_0:
task_0 = ExternalTaskMarker(
task_id="task_0",
external_dag_id="child_dag_1",
external_task_id="task_1",
logical_date=day_1.isoformat(),
recursion_depth=3,
)
with DAG("child_dag_1", start_date=day_1, schedule=None) as dag_1:
ExternalTaskSensor(
task_id="task_1",
external_dag_id=dag_0.dag_id,
external_task_id=task_0.task_id,
execution_date_fn=lambda logical_date: day_1 if logical_date == day_1 else [],
mode="reschedule",
)
for dag in [dag_0, dag_1]:
dag_bag.bag_dag(dag=dag)
yield dag_bag
clear_db_runs()
@provide_session
def run_tasks(
dag_bag: DagBag,
logical_date=DEFAULT_DATE,
session=NEW_SESSION,
) -> tuple[dict[str, DagRun], dict[str, TaskInstance]]:
"""
Run all tasks in the DAGs in the given dag_bag. Return the TaskInstance objects as a dict
keyed by task_id.
"""
runs: dict[str, DagRun] = {}
tis: dict[str, TaskInstance] = {}
for dag in dag_bag.dags.values():
data_interval = DataInterval(coerce_datetime(logical_date), coerce_datetime(logical_date))
runs[dag.dag_id] = dagrun = dag.create_dagrun(
run_id=dag.timetable.generate_run_id(
run_type=DagRunType.MANUAL,
run_after=logical_date,
data_interval=data_interval,
),
logical_date=logical_date,
data_interval=data_interval,
run_after=logical_date,
run_type=DagRunType.MANUAL,
triggered_by=DagRunTriggeredByType.TEST,
dag_version=None,
state=DagRunState.RUNNING,
start_date=logical_date,
session=session,
)
        # We sort by task_id here because, for our test DAG structure, this is
        # equivalent to a topological sort. That would not hold in the general
        # case, but the test DAGs are deliberately constructed so that the two
        # orderings coincide.
tasks = sorted(dagrun.task_instances, key=lambda ti: ti.task_id)
for ti in tasks:
ti.refresh_from_task(dag.get_task(ti.task_id))
tis[ti.task_id] = ti
ti.run(session=session)
session.flush()
session.merge(ti)
assert_ti_state_equal(ti, State.SUCCESS)
return runs, tis
def assert_ti_state_equal(task_instance, state):
"""
Assert state of task_instances equals the given state.
"""
task_instance.refresh_from_db()
assert task_instance.state == state
@provide_session
def clear_tasks(
dag_bag,
dag,
task,
session,
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE,
dry_run=False,
):
"""
Clear the task and its downstream tasks recursively for the dag in the given dagbag.
"""
partial: DAG = dag.partial_subset(task_ids=[task.task_id], include_downstream=True)
return partial.clear(
start_date=start_date,
end_date=end_date,
dag_bag=dag_bag,
dry_run=dry_run,
session=session,
)
def test_external_task_marker_transitive(dag_bag_ext):
"""
Test clearing tasks across DAGs.
"""
_, tis = run_tasks(dag_bag_ext)
dag_0 = dag_bag_ext.get_dag("dag_0")
task_a_0 = dag_0.get_task("task_a_0")
clear_tasks(dag_bag_ext, dag_0, task_a_0)
ti_a_0 = tis["task_a_0"]
ti_b_3 = tis["task_b_3"]
assert_ti_state_equal(ti_a_0, State.NONE)
assert_ti_state_equal(ti_b_3, State.NONE)
@provide_session
def test_external_task_marker_clear_activate(dag_bag_parent_child, session):
"""
Test clearing tasks across DAGs and make sure the right DagRuns are activated.
"""
dag_bag = dag_bag_parent_child
day_1 = DEFAULT_DATE
day_2 = DEFAULT_DATE + timedelta(days=1)
run_tasks(dag_bag, logical_date=day_1)
run_tasks(dag_bag, logical_date=day_2)
from sqlalchemy import select
run_ids = []
# Assert that dagruns of all the affected dags are set to SUCCESS before tasks are cleared.
for dag, logical_date in itertools.product(dag_bag.dags.values(), [day_1, day_2]):
        run_id = session.scalar(
            select(DagRun.run_id)
            .where(DagRun.dag_id == dag.dag_id, DagRun.logical_date == logical_date)
            .order_by(DagRun.id.desc())
            .limit(1)
        )
run_ids.append(run_id)
dagrun = dag.get_dagrun(
run_id=run_id,
session=session,
)
dagrun.set_state(State.SUCCESS)
session.flush()
dag_0 = dag_bag.get_dag("parent_dag_0")
task_0 = dag_0.get_task("task_0")
clear_tasks(dag_bag, dag_0, task_0, start_date=day_1, end_date=day_2, session=session)
# Assert that dagruns of all the affected dags are set to QUEUED after tasks are cleared.
# Unaffected dagruns should be left as SUCCESS.
dagrun_0_1 = dag_bag.get_dag("parent_dag_0").get_dagrun(run_id=run_ids[0], session=session)
dagrun_0_2 = dag_bag.get_dag("parent_dag_0").get_dagrun(run_id=run_ids[1], session=session)
dagrun_1_1 = dag_bag.get_dag("child_dag_1").get_dagrun(run_id=run_ids[2], session=session)
dagrun_1_2 = dag_bag.get_dag("child_dag_1").get_dagrun(run_id=run_ids[3], session=session)
assert dagrun_0_1.state == State.QUEUED
assert dagrun_0_2.state == State.QUEUED
assert dagrun_1_1.state == State.QUEUED
assert dagrun_1_2.state == State.SUCCESS
def test_external_task_marker_future(dag_bag_ext):
"""
Test clearing tasks with no end_date. This is the case when users clear tasks with
Future, Downstream and Recursive selected.
"""
date_0 = DEFAULT_DATE
date_1 = DEFAULT_DATE + timedelta(days=1)
_, tis_date_0 = run_tasks(dag_bag_ext, logical_date=date_0)
_, tis_date_1 = run_tasks(dag_bag_ext, logical_date=date_1)
dag_0 = dag_bag_ext.get_dag("dag_0")
task_a_0 = dag_0.get_task("task_a_0")
# This should clear all tasks on dag_0 to dag_3 on both date_0 and date_1
clear_tasks(dag_bag_ext, dag_0, task_a_0, end_date=None)
ti_a_0_date_0 = tis_date_0["task_a_0"]
ti_b_3_date_0 = tis_date_0["task_b_3"]
ti_b_3_date_1 = tis_date_1["task_b_3"]
assert_ti_state_equal(ti_a_0_date_0, State.NONE)
assert_ti_state_equal(ti_b_3_date_0, State.NONE)
assert_ti_state_equal(ti_b_3_date_1, State.NONE)
def test_external_task_marker_exception(dag_bag_ext):
"""
Clearing across multiple DAGs should raise AirflowException if more levels are being cleared
than allowed by the recursion_depth of the first ExternalTaskMarker being cleared.
"""
run_tasks(dag_bag_ext)
dag_0 = dag_bag_ext.get_dag("dag_0")
task_a_0 = dag_0.get_task("task_a_0")
task_b_0 = dag_0.get_task("task_b_0")
task_b_0.recursion_depth = 2
with pytest.raises(AirflowException, match="Maximum recursion depth 2"):
clear_tasks(dag_bag_ext, dag_0, task_a_0)
@pytest.fixture
def dag_bag_cyclic():
"""
Create a DagBag with DAGs having cyclic dependencies set up by ExternalTaskMarker and
ExternalTaskSensor.
dag_0: task_a_0 >> task_b_0
^ |
| |
dag_1: | ---> task_a_1 >> task_b_1
| ^
| |
dag_n: | ---> task_a_n >> task_b_n
| |
-----------------------------------------------------
"""
def _factory(depth: int) -> DagBag:
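        # Build dag_0 .. dag_{depth}, where the final ExternalTaskMarker points back at dag_0 to close the cycle.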
dag_bag = DagBag(dag_folder=DEV_NULL, include_examples=False)
dags = []
with DAG("dag_0", start_date=DEFAULT_DATE, schedule=None) as dag:
dags.append(dag)
task_a_0 = EmptyOperator(task_id="task_a_0")
task_b_0 = ExternalTaskMarker(
task_id="task_b_0", external_dag_id="dag_1", external_task_id="task_a_1", recursion_depth=3
)
task_a_0 >> task_b_0
for n in range(1, depth):
with DAG(f"dag_{n}", start_date=DEFAULT_DATE, schedule=None) as dag:
dags.append(dag)
task_a = ExternalTaskSensor(
task_id=f"task_a_{n}",
external_dag_id=f"dag_{n-1}",
external_task_id=f"task_b_{n-1}",
)
task_b = ExternalTaskMarker(
task_id=f"task_b_{n}",
external_dag_id=f"dag_{n+1}",
external_task_id=f"task_a_{n+1}",
recursion_depth=3,
)
task_a >> task_b
# Create the last dag which loops back
with DAG(f"dag_{depth}", start_date=DEFAULT_DATE, schedule=None) as dag:
dags.append(dag)
task_a = ExternalTaskSensor(
task_id=f"task_a_{depth}",
external_dag_id=f"dag_{depth-1}",
external_task_id=f"task_b_{depth-1}",
)
task_b = ExternalTaskMarker(
task_id=f"task_b_{depth}",
external_dag_id="dag_0",
external_task_id="task_a_0",
recursion_depth=2,
)
task_a >> task_b
for dag in dags:
dag_bag.bag_dag(dag=dag)
return dag_bag
return _factory
def test_external_task_marker_cyclic_deep(dag_bag_cyclic):
"""
Tests clearing across multiple DAGs that have cyclic dependencies. AirflowException should be
raised.
"""
dag_bag = dag_bag_cyclic(10)
run_tasks(dag_bag)
dag_0 = dag_bag.get_dag("dag_0")
task_a_0 = dag_0.get_task("task_a_0")
with pytest.raises(AirflowException, match="Maximum recursion depth 3"):
clear_tasks(dag_bag, dag_0, task_a_0)
def test_external_task_marker_cyclic_shallow(dag_bag_cyclic):
"""
Tests clearing across multiple DAGs that have cyclic dependencies shallower
than recursion_depth
"""
dag_bag = dag_bag_cyclic(2)
run_tasks(dag_bag)
dag_0 = dag_bag.get_dag("dag_0")
task_a_0 = dag_0.get_task("task_a_0")
tis = clear_tasks(dag_bag, dag_0, task_a_0, dry_run=True)
assert sorted((ti.dag_id, ti.task_id) for ti in tis) == [
("dag_0", "task_a_0"),
("dag_0", "task_b_0"),
("dag_1", "task_a_1"),
("dag_1", "task_b_1"),
("dag_2", "task_a_2"),
("dag_2", "task_b_2"),
]
@pytest.fixture
def dag_bag_multiple():
"""
    Create a DagBag containing two DAGs, linked by multiple ExternalTaskMarkers.
"""
dag_bag = DagBag(dag_folder=DEV_NULL, include_examples=False)
daily_dag = DAG("daily_dag", start_date=DEFAULT_DATE, schedule="@daily")
agg_dag = DAG("agg_dag", start_date=DEFAULT_DATE, schedule="@daily")
dag_bag.bag_dag(dag=daily_dag)
dag_bag.bag_dag(dag=agg_dag)
daily_task = EmptyOperator(task_id="daily_tas", dag=daily_dag)
begin = EmptyOperator(task_id="begin", dag=agg_dag)
for i in range(8):
task = ExternalTaskMarker(
task_id=f"{daily_task.task_id}_{i}",
external_dag_id=daily_dag.dag_id,
external_task_id=daily_task.task_id,
logical_date=f"{{{{ macros.ds_add(ds, -1 * {i}) }}}}",
dag=agg_dag,
)
begin >> task
return dag_bag
def test_clear_multiple_external_task_marker(dag_bag_multiple):
"""
    Test clearing a dag that has multiple ExternalTaskMarkers.
"""
agg_dag = dag_bag_multiple.get_dag("agg_dag")
_, tis = run_tasks(dag_bag_multiple, logical_date=DEFAULT_DATE)
session = settings.Session()
try:
qry = session.query(TaskInstance).filter(
TaskInstance.state.is_(None), TaskInstance.dag_id.in_(dag_bag_multiple.dag_ids)
)
assert agg_dag.clear(dag_bag=dag_bag_multiple) == len(tis) == qry.count() == 10
finally:
session.close()
@pytest.fixture
def dag_bag_head_tail():
"""
Create a DagBag containing one DAG, with task "head" depending on task "tail" of the
previous logical_date.
20200501 20200502 20200510
+------+ +------+ +------+
| head | -->head | --> -->head |
| | | / | | | / / | | |
| v | / | v | / / | v |
| body | / | body | / ... / | body |
| | |/ | | |/ / | | |
| v / | v / / | v |
| tail/| | tail/| / | tail |
+------+ +------+ +------+
"""
dag_bag = DagBag(dag_folder=DEV_NULL, include_examples=False)
with DAG("head_tail", start_date=DEFAULT_DATE, schedule="@daily") as dag:
head = ExternalTaskSensor(
task_id="head",
external_dag_id=dag.dag_id,
external_task_id="tail",
execution_delta=timedelta(days=1),
mode="reschedule",
)
body = EmptyOperator(task_id="body")
tail = ExternalTaskMarker(
task_id="tail",
external_dag_id=dag.dag_id,
external_task_id=head.task_id,
logical_date="{{ macros.ds_add(ds, 1) }}",
)
head >> body >> tail
dag_bag.bag_dag(dag=dag)
return dag_bag
@provide_session
def test_clear_overlapping_external_task_marker(dag_bag_head_tail, session):
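    # Clearing from the start should cascade through the tail markers: 10 runs x 3 tasks = 30 TIs.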
dag: DAG = dag_bag_head_tail.get_dag("head_tail")
# "Run" 10 times.
for delta in range(10):
logical_date = DEFAULT_DATE + timedelta(days=delta)
dagrun = DagRun(
dag_id=dag.dag_id,
start_date=logical_date,
state=DagRunState.SUCCESS,
logical_date=logical_date,
run_type=DagRunType.MANUAL,
run_id=f"test_{delta}",
)
session.add(dagrun)
for task in dag.tasks:
ti = TaskInstance(task=task)
dagrun.task_instances.append(ti)
ti.state = TaskInstanceState.SUCCESS
session.flush()
assert dag.clear(start_date=DEFAULT_DATE, dag_bag=dag_bag_head_tail, session=session) == 30
@provide_session
def test_clear_overlapping_external_task_marker_with_end_date(dag_bag_head_tail, session):
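    # Same as above, but with an explicit end_date; the cascade should still clear all 10 runs x 3 tasks = 30 TIs.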
dag: DAG = dag_bag_head_tail.get_dag("head_tail")
# "Run" 10 times.
for delta in range(10):
logical_date = DEFAULT_DATE + timedelta(days=delta)
dagrun = DagRun(
dag_id=dag.dag_id,
start_date=logical_date,
state=DagRunState.SUCCESS,
logical_date=logical_date,
run_type=DagRunType.MANUAL,
run_id=f"test_{delta}",
)
session.add(dagrun)
for task in dag.tasks:
ti = TaskInstance(task=task)
dagrun.task_instances.append(ti)
ti.state = TaskInstanceState.SUCCESS
session.flush()
assert (
dag.clear(
start_date=DEFAULT_DATE,
end_date=logical_date,
dag_bag=dag_bag_head_tail,
session=session,
)
== 30
)
@pytest.fixture
def dag_bag_head_tail_mapped_tasks():
"""
    Create a DagBag containing one DAG, with task "head" depending on task "tail" of the
    previous logical_date and the "body" task expanded into mapped tasks.
20200501 20200502 20200510
+------+ +------+ +------+
| head | -->head | --> -->head |
| | | / | | | / / | | |
| v | / | v | / / | v |
| body | / | body | / ... / | body |
| | |/ | | |/ / | | |
| v / | v / / | v |
| tail/| | tail/| / | tail |
+------+ +------+ +------+
"""
dag_bag = DagBag(dag_folder=DEV_NULL, include_examples=False)
with DAG("head_tail", start_date=DEFAULT_DATE, schedule="@daily") as dag:
@task_deco
def dummy_task(x: int):
return x
head = ExternalTaskSensor(
task_id="head",
external_dag_id=dag.dag_id,
external_task_id="tail",
execution_delta=timedelta(days=1),
mode="reschedule",
)
body = dummy_task.expand(x=range(5))
tail = ExternalTaskMarker(
task_id="tail",
external_dag_id=dag.dag_id,
external_task_id=head.task_id,
logical_date="{{ macros.ds_add(ds, 1) }}",
)
head >> body >> tail
dag_bag.bag_dag(dag=dag)
return dag_bag
@provide_session
def test_clear_overlapping_external_task_marker_mapped_tasks(dag_bag_head_tail_mapped_tasks, session):
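    # With the body mapped over 5 indexes, each run has head + 5 body TIs + tail = 7 TIs; 10 runs -> 70 cleared.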
dag: DAG = dag_bag_head_tail_mapped_tasks.get_dag("head_tail")
# "Run" 10 times.
for delta in range(10):
logical_date = DEFAULT_DATE + timedelta(days=delta)
dagrun = DagRun(
dag_id=dag.dag_id,
start_date=logical_date,
state=DagRunState.SUCCESS,
logical_date=logical_date,
run_type=DagRunType.MANUAL,
run_id=f"test_{delta}",
)
session.add(dagrun)
for task in dag.tasks:
if task.task_id == "dummy_task":
for map_index in range(5):
ti = TaskInstance(task=task, run_id=dagrun.run_id, map_index=map_index)
ti.state = TaskInstanceState.SUCCESS
dagrun.task_instances.append(ti)
else:
ti = TaskInstance(task=task, run_id=dagrun.run_id)
ti.state = TaskInstanceState.SUCCESS
dagrun.task_instances.append(ti)
session.flush()
dag = dag.partial_subset(
task_ids=["head"],
include_downstream=True,
include_upstream=False,
)
task_ids = list(dag.task_dict)
assert (
dag.clear(
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE,
dag_bag=dag_bag_head_tail_mapped_tasks,
session=session,
task_ids=task_ids,
)
== 70
)