| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| from __future__ import annotations |
| |
| import contextlib |
| import datetime |
| import operator |
| import os |
| import pathlib |
| import pickle |
| import signal |
| import sys |
| import urllib |
| from traceback import format_exception |
| from typing import cast |
| from unittest import mock |
| from unittest.mock import MagicMock, call, mock_open, patch |
| from uuid import uuid4 |
| |
| import pendulum |
| import pytest |
| import time_machine |
| from sqlalchemy import select |
| |
| from airflow import settings |
| from airflow.decorators import task, task_group |
| from airflow.example_dags.plugins.workday import AfterWorkdayTimetable |
| from airflow.exceptions import ( |
| AirflowException, |
| AirflowFailException, |
| AirflowRescheduleException, |
| AirflowSensorTimeout, |
| AirflowSkipException, |
| AirflowTaskTerminated, |
| UnmappableXComLengthPushed, |
| UnmappableXComTypePushed, |
| XComForMappingNotPushed, |
| ) |
| from airflow.models.connection import Connection |
| from airflow.models.dag import DAG |
| from airflow.models.dagbag import DagBag |
| from airflow.models.dagrun import DagRun |
| from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel |
| from airflow.models.expandinput import EXPAND_INPUT_EMPTY, NotFullyPopulated |
| from airflow.models.param import process_params |
| from airflow.models.pool import Pool |
| from airflow.models.renderedtifields import RenderedTaskInstanceFields |
| from airflow.models.serialized_dag import SerializedDagModel |
| from airflow.models.taskfail import TaskFail |
| from airflow.models.taskinstance import ( |
| TaskInstance, |
| TaskInstance as TI, |
| TaskInstanceNote, |
| _get_private_try_number, |
| _get_try_number, |
| _run_finished_callback, |
| ) |
| from airflow.models.taskmap import TaskMap |
| from airflow.models.taskreschedule import TaskReschedule |
| from airflow.models.variable import Variable |
| from airflow.models.xcom import LazyXComSelectSequence, XCom |
| from airflow.operators.bash import BashOperator |
| from airflow.operators.empty import EmptyOperator |
| from airflow.operators.python import PythonOperator |
| from airflow.sensors.base import BaseSensorOperator |
| from airflow.sensors.python import PythonSensor |
| from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG |
| from airflow.settings import TIMEZONE |
| from airflow.stats import Stats |
| from airflow.ti_deps.dep_context import DepContext |
| from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS |
| from airflow.ti_deps.dependencies_states import RUNNABLE_STATES |
| from airflow.ti_deps.deps.base_ti_dep import TIDepStatus |
| from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep, _UpstreamTIStates |
| from airflow.utils import timezone |
| from airflow.utils.db import merge_conn |
| from airflow.utils.module_loading import qualname |
| from airflow.utils.session import create_session, provide_session |
| from airflow.utils.state import DagRunState, State, TaskInstanceState |
| from airflow.utils.task_group import TaskGroup |
| from airflow.utils.task_instance_session import set_current_task_instance_session |
| from airflow.utils.types import DagRunType |
| from airflow.utils.xcom import XCOM_RETURN_KEY |
| from tests.models import DEFAULT_DATE, TEST_DAGS_FOLDER |
| from tests.test_utils import db |
| from tests.test_utils.config import conf_vars |
| from tests.test_utils.db import clear_db_connections, clear_db_runs |
| from tests.test_utils.mock_operators import MockOperator |
| |
# Every test in this module reads/writes the Airflow metadata database.
pytestmark = pytest.mark.db_test
| |
| |
@pytest.fixture
def test_pool():
    """Provide a one-slot ``test_pool`` Pool row, undone after the test."""
    with create_session() as session:
        test_pool = Pool(pool="test_pool", slots=1, include_deferred=False)
        session.add(test_pool)
        # Flush (not commit) so the row is visible to the test but can be
        # rolled back below, leaving no trace in the shared test DB.
        session.flush()
        yield test_pool
        session.rollback()
| |
| |
@pytest.fixture
def task_reschedules_for_ti():
    """Return a callable fetching all TaskReschedule rows for a TI, oldest first."""

    def _fetch_reschedules(ti):
        stmt = TaskReschedule.stmt_for_task_instance(ti=ti, descending=False)
        with create_session() as session:
            return session.scalars(stmt).all()

    return _fetch_reschedules
| |
| |
| class CallbackWrapper: |
| task_id: str | None = None |
| dag_id: str | None = None |
| execution_date: datetime.datetime | None = None |
| task_state_in_callback: str | None = None |
| callback_ran = False |
| |
| def wrap_task_instance(self, ti): |
| self.task_id = ti.task_id |
| self.dag_id = ti.dag_id |
| self.execution_date = ti.execution_date |
| self.task_state_in_callback = "" |
| self.callback_ran = False |
| |
| def success_handler(self, context): |
| self.callback_ran = True |
| self.task_state_in_callback = context["ti"].state |
| |
| |
| class TestTaskInstance: |
| @staticmethod |
| def clean_db(): |
| db.clear_db_dags() |
| db.clear_db_pools() |
| db.clear_db_runs() |
| db.clear_db_task_fail() |
| db.clear_rendered_ti_fields() |
| db.clear_db_task_reschedule() |
| db.clear_db_datasets() |
| db.clear_db_xcom() |
| |
| def setup_method(self): |
| self.clean_db() |
| |
| # We don't want to store any code for (test) dags created in this file |
| with patch.object(settings, "STORE_DAG_CODE", False): |
| yield |
| |
    def teardown_method(self):
        # Leave the DB clean for whatever test module runs next.
        self.clean_db()
| |
    def test_set_task_dates(self, dag_maker):
        """
        Test that tasks properly take start/end dates from DAGs.

        Three cases: a task with no dates inherits the DAG's; a task whose
        dates are looser than the DAG's is clamped to the DAG's; a task whose
        dates are tighter keeps its own.
        """
        with dag_maker("dag", end_date=DEFAULT_DATE + datetime.timedelta(days=10)) as dag:
            pass

        op1 = EmptyOperator(task_id="op_1")

        assert op1.start_date is None
        assert op1.end_date is None

        # dag should assign its dates to op1 because op1 has no dates
        dag.add_task(op1)
        dag_maker.create_dagrun()
        assert op1.start_date == dag.start_date
        assert op1.end_date == dag.end_date

        op2 = EmptyOperator(
            task_id="op_2",
            start_date=DEFAULT_DATE - datetime.timedelta(days=1),
            end_date=DEFAULT_DATE + datetime.timedelta(days=11),
        )

        # dag should assign its dates to op2 because they are more restrictive
        dag.add_task(op2)
        assert op2.start_date == dag.start_date
        assert op2.end_date == dag.end_date

        op3 = EmptyOperator(
            task_id="op_3",
            start_date=DEFAULT_DATE + datetime.timedelta(days=1),
            end_date=DEFAULT_DATE + datetime.timedelta(days=9),
        )
        # op3 should keep its dates because they are more restrictive
        dag.add_task(op3)
        assert op3.start_date == DEFAULT_DATE + datetime.timedelta(days=1)
        assert op3.end_date == DEFAULT_DATE + datetime.timedelta(days=9)
| |
| def test_current_state(self, create_task_instance, session): |
| ti = create_task_instance(session=session) |
| assert ti.current_state(session=session) is None |
| ti.run() |
| assert ti.current_state(session=session) == State.SUCCESS |
| |
    def test_set_dag(self, dag_maker):
        """
        Test assigning Operators to Dags, including deferred assignment.

        Covers: access before assignment raises, assigning a non-DAG raises
        TypeError, reassignment to a different DAG raises, and re-assigning
        the same DAG is a no-op.
        """
        with dag_maker("dag") as dag:
            pass
        with dag_maker("dag2") as dag2:
            pass
        op = EmptyOperator(task_id="op_1")

        # no dag assigned
        assert not op.has_dag()
        with pytest.raises(AirflowException):
            getattr(op, "dag")

        # no improper assignment
        with pytest.raises(TypeError):
            op.dag = 1

        op.dag = dag

        # no reassignment
        with pytest.raises(AirflowException):
            op.dag = dag2

        # but assigning the same dag is ok
        op.dag = dag

        assert op.dag is dag
        assert op in dag.tasks
| |
    def test_infer_dag(self, create_dummy_dag):
        """Setting a dependency infers the DAG from the peer operator, never across DAGs."""
        op1 = EmptyOperator(task_id="test_op_1")
        op2 = EmptyOperator(task_id="test_op_2")

        dag, op3 = create_dummy_dag(task_id="test_op_3")

        _, op4 = create_dummy_dag("dag2", task_id="test_op_4")

        # double check dags
        assert [i.has_dag() for i in [op1, op2, op3, op4]] == [False, False, True, True]

        # can't combine operators with no dags
        with pytest.raises(AirflowException):
            op1.set_downstream(op2)

        # op2 should infer dag from op1
        op1.dag = dag
        op1.set_downstream(op2)
        assert op2.dag is dag

        # can't assign across multiple DAGs
        with pytest.raises(AirflowException):
            op1.set_downstream(op4)
        with pytest.raises(AirflowException):
            op1.set_downstream([op3, op4])
| |
| def test_bitshift_compose_operators(self, dag_maker): |
| with dag_maker("dag"): |
| op1 = EmptyOperator(task_id="test_op_1") |
| op2 = EmptyOperator(task_id="test_op_2") |
| op3 = EmptyOperator(task_id="test_op_3") |
| |
| op1 >> op2 << op3 |
| dag_maker.create_dagrun() |
| # op2 should be downstream of both |
| assert op2 in op1.downstream_list |
| assert op2 in op3.downstream_list |
| |
| def test_init_on_load(self, create_task_instance): |
| ti = create_task_instance() |
| # ensure log is correctly created for ORM ti |
| assert ti.log.name == "airflow.task" |
| assert not ti.test_mode |
| |
    @patch.object(DAG, "get_concurrency_reached")
    def test_requeue_over_dag_concurrency(self, mock_concurrency_reached, create_task_instance):
        """A TI whose DAG is at max concurrency is not executed."""
        mock_concurrency_reached.return_value = True

        ti = create_task_instance(
            dag_id="test_requeue_over_dag_concurrency",
            task_id="test_requeue_over_dag_concurrency_op",
            max_active_runs=1,
            max_active_tasks=2,
            dagrun_state=State.QUEUED,
        )
        ti.run()
        # run() bails out on the dep check, leaving the TI stateless.
        assert ti.state == State.NONE
| |
    def test_requeue_over_max_active_tis_per_dag(self, create_task_instance):
        """A TI whose task has no per-DAG active-TI budget (0) is not executed."""
        ti = create_task_instance(
            dag_id="test_requeue_over_max_active_tis_per_dag",
            task_id="test_requeue_over_max_active_tis_per_dag_op",
            max_active_tis_per_dag=0,
            max_active_runs=1,
            max_active_tasks=2,
            dagrun_state=State.QUEUED,
        )

        ti.run()
        # Dep check fails, so the TI stays stateless rather than running.
        assert ti.state == State.NONE
| |
    def test_requeue_over_max_active_tis_per_dagrun(self, create_task_instance):
        """A TI whose task has no per-DAG-run active-TI budget (0) is not executed."""
        ti = create_task_instance(
            dag_id="test_requeue_over_max_active_tis_per_dagrun",
            task_id="test_requeue_over_max_active_tis_per_dagrun_op",
            max_active_tis_per_dagrun=0,
            max_active_runs=1,
            max_active_tasks=2,
            dagrun_state=State.QUEUED,
        )

        ti.run()
        # Dep check fails, so the TI stays stateless rather than running.
        assert ti.state == State.NONE
| |
    def test_requeue_over_pool_concurrency(self, create_task_instance, test_pool):
        """A TI whose pool has zero open slots is not executed."""
        ti = create_task_instance(
            dag_id="test_requeue_over_pool_concurrency",
            task_id="test_requeue_over_pool_concurrency_op",
            max_active_tis_per_dag=0,
            max_active_runs=1,
            max_active_tasks=2,
        )
        with create_session() as session:
            # NOTE(review): ``test_pool`` was created in the fixture's own
            # session; this relies on the slot change being visible to the
            # session used by ti.run() -- TODO confirm.
            test_pool.slots = 0
            session.flush()
            ti.run()
            assert ti.state == State.NONE
| |
    @pytest.mark.usefixtures("test_pool")
    def test_not_requeue_non_requeueable_task_instance(self, dag_maker):
        """A QUEUED TI must stay QUEUED when a non-requeueable dep fails."""
        # Use BaseSensorOperator because sensor got
        # one additional DEP in BaseSensorOperator().deps
        with dag_maker(dag_id="test_not_requeue_non_requeueable_task_instance"):
            task = BaseSensorOperator(
                task_id="test_not_requeue_non_requeueable_task_instance_op",
                pool="test_pool",
            )
        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
        ti.task = task
        ti.state = State.QUEUED
        with create_session() as session:
            session.add(ti)
            session.commit()

        # Patch every non-requeueable dep to report "passed" initially.
        all_deps = RUNNING_DEPS | task.deps
        all_non_requeueable_deps = all_deps - REQUEUEABLE_DEPS
        patch_dict = {}
        for dep in all_non_requeueable_deps:
            class_name = dep.__class__.__name__
            dep_patch = patch(f"{dep.__module__}.{class_name}.{dep._get_dep_statuses.__name__}")
            method_patch = dep_patch.start()
            method_patch.return_value = iter([TIDepStatus("mock_" + class_name, True, "mock")])
            patch_dict[class_name] = (dep_patch, method_patch)

        # Fail the deps one at a time; each failure must leave the TI QUEUED.
        for class_name, (dep_patch, method_patch) in patch_dict.items():
            method_patch.return_value = iter([TIDepStatus("mock_" + class_name, False, "mock")])
            ti.run()
            assert ti.state == State.QUEUED
            dep_patch.return_value = TIDepStatus("mock_" + class_name, True, "mock")

        # Undo all patches started above.
        for dep_patch, _ in patch_dict.values():
            dep_patch.stop()
| |
| def test_mark_non_runnable_task_as_success(self, create_task_instance): |
| """ |
| test that running task with mark_success param update task state |
| as SUCCESS without running task despite it fails dependency checks. |
| """ |
| non_runnable_state = (set(State.task_states) - RUNNABLE_STATES - set(State.SUCCESS)).pop() |
| ti = create_task_instance( |
| dag_id="test_mark_non_runnable_task_as_success", |
| task_id="test_mark_non_runnable_task_as_success_op", |
| state=non_runnable_state, |
| ) |
| ti.run(mark_success=True) |
| assert ti.state == State.SUCCESS |
| |
| @pytest.mark.usefixtures("test_pool") |
| def test_run_pooling_task(self, create_task_instance): |
| """ |
| test that running a task in an existing pool update task state as SUCCESS. |
| """ |
| ti = create_task_instance( |
| dag_id="test_run_pooling_task", |
| task_id="test_run_pooling_task_op", |
| pool="test_pool", |
| ) |
| |
| ti.run() |
| |
| assert ti.state == State.SUCCESS |
| |
| @pytest.mark.usefixtures("test_pool") |
| def test_pool_slots_property(self): |
| """ |
| test that try to create a task with pool_slots less than 1 |
| """ |
| |
| dag = DAG(dag_id="test_run_pooling_task") |
| with pytest.raises(ValueError, match="pool slots .* cannot be less than 1"): |
| EmptyOperator( |
| task_id="test_run_pooling_task_op", |
| dag=dag, |
| pool="test_pool", |
| pool_slots=0, |
| ) |
| |
    @provide_session
    def test_ti_updates_with_task(self, create_task_instance, session=None):
        """
        test that updating the executor_config propagates to the TaskInstance DB
        """
        ti = create_task_instance(
            dag_id="test_run_pooling_task",
            task_id="test_run_pooling_task_op",
            executor_config={"foo": "bar"},
        )
        dag = ti.task.dag

        ti.run(session=session)
        tis = dag.get_task_instances()
        assert {"foo": "bar"} == tis[0].executor_config
        # A second task in the same run, carrying a different executor_config.
        task2 = EmptyOperator(
            task_id="test_run_pooling_task_op2",
            executor_config={"bar": "baz"},
            start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
            dag=dag,
        )

        ti2 = TI(task=task2, run_id=ti.run_id)
        session.add(ti2)
        session.flush()

        ti2.run(session=session)
        # Ensure it's reloaded
        ti2.executor_config = None
        ti2.refresh_from_db(session)
        assert {"bar": "baz"} == ti2.executor_config
        session.rollback()
| |
| def test_run_pooling_task_with_mark_success(self, create_task_instance): |
| """ |
| test that running task in an existing pool with mark_success param |
| update task state as SUCCESS without running task |
| despite it fails dependency checks. |
| """ |
| ti = create_task_instance( |
| dag_id="test_run_pooling_task_with_mark_success", |
| task_id="test_run_pooling_task_with_mark_success_op", |
| ) |
| |
| ti.run(mark_success=True) |
| assert ti.state == State.SUCCESS |
| |
| def test_run_pooling_task_with_skip(self, dag_maker): |
| """ |
| test that running task which returns AirflowSkipOperator will end |
| up in a SKIPPED state. |
| """ |
| |
| def raise_skip_exception(): |
| raise AirflowSkipException |
| |
| with dag_maker(dag_id="test_run_pooling_task_with_skip"): |
| task = PythonOperator( |
| task_id="test_run_pooling_task_with_skip", |
| python_callable=raise_skip_exception, |
| ) |
| |
| dr = dag_maker.create_dagrun(execution_date=timezone.utcnow()) |
| ti = dr.task_instances[0] |
| ti.task = task |
| ti.run() |
| assert State.SKIPPED == ti.state |
| |
    def test_task_sigterm_calls_on_failure_callback(self, dag_maker, caplog):
        """
        Test that ensures that tasks call on_failure_callback when they receive sigterm
        """

        def task_function(ti):
            # The task SIGTERMs its own process to simulate an external kill.
            os.kill(ti.pid, signal.SIGTERM)

        with dag_maker():
            task_ = PythonOperator(
                task_id="test_on_failure",
                python_callable=task_function,
                on_failure_callback=lambda context: context["ti"].log.info("on_failure_callback called"),
            )

        dr = dag_maker.create_dagrun()
        ti = dr.task_instances[0]
        ti.task = task_
        with pytest.raises(AirflowTaskTerminated):
            ti.run()
        # The failure callback must have logged before the exception propagated.
        assert "on_failure_callback called" in caplog.text
| |
    def test_task_sigterm_works_with_retries(self, dag_maker):
        """
        Test that ensures that tasks are retried when they receive sigterm
        """

        def task_function(ti):
            # Simulate an external kill of the running task process.
            os.kill(ti.pid, signal.SIGTERM)

        with dag_maker("test_mark_failure_2"):
            task = PythonOperator(
                task_id="test_on_failure",
                python_callable=task_function,
                retries=1,
                retry_delay=datetime.timedelta(seconds=2),
            )

        dr = dag_maker.create_dagrun()
        ti = dr.task_instances[0]
        ti.task = task
        with pytest.raises(AirflowTaskTerminated):
            ti.run()
        ti.refresh_from_db()
        # With retry budget remaining, the terminated task must be retried.
        assert ti.state == State.UP_FOR_RETRY
| |
| @pytest.mark.parametrize("state", [State.SUCCESS, State.FAILED, State.SKIPPED]) |
| def test_task_sigterm_doesnt_change_state_of_finished_tasks(self, state, dag_maker): |
| session = settings.Session() |
| |
| def task_function(ti): |
| ti.state = state |
| session.merge(ti) |
| session.commit() |
| raise AirflowException() |
| |
| with dag_maker("test_mark_failure_2"): |
| task = PythonOperator( |
| task_id="test_on_failure", |
| python_callable=task_function, |
| ) |
| |
| dr = dag_maker.create_dagrun() |
| ti = dr.task_instances[0] |
| ti.task = task |
| |
| ti.run() |
| ti.refresh_from_db() |
| ti.state == state |
| |
    @pytest.mark.parametrize(
        "state, exception, retries",
        [
            (State.FAILED, AirflowException, 0),
            (State.SKIPPED, AirflowSkipException, 0),
            (State.SUCCESS, None, 0),
            (State.UP_FOR_RESCHEDULE, AirflowRescheduleException(timezone.utcnow()), 0),
            (State.UP_FOR_RETRY, AirflowException, 1),
        ],
    )
    def test_task_wipes_next_fields(self, session, dag_maker, state, exception, retries):
        """
        Test that ensures that tasks wipe their next_method and next_kwargs
        when the TI enters one of the configured states.
        """

        def _raise_if_exception():
            if exception:
                raise exception

        with dag_maker("test_deferred_method_clear"):
            task = PythonOperator(
                task_id="test_deferred_method_clear_task",
                python_callable=_raise_if_exception,
                retries=retries,
                retry_delay=datetime.timedelta(seconds=2),
            )

        dr = dag_maker.create_dagrun()
        ti = dr.task_instances[0]
        # Pretend the TI is resuming from a deferral.
        ti.next_method = "execute"
        ti.next_kwargs = {}
        session.merge(ti)
        session.commit()

        ti.task = task
        if state in [State.FAILED, State.UP_FOR_RETRY]:
            with pytest.raises(exception):
                ti.run()
        else:
            ti.run()
        ti.refresh_from_db()

        # Regardless of the outcome, the deferral bookkeeping must be cleared.
        assert ti.next_method is None
        assert ti.next_kwargs is None
        assert ti.state == state
| |
    def test_retry_delay(self, dag_maker, time_machine):
        """
        Test that retry delays are respected.

        The wall clock is frozen and shifted manually, so the test never sleeps.
        """
        time_machine.move_to("2021-09-19 04:56:35", tick=False)
        with dag_maker(dag_id="test_retry_handling"):
            task = BashOperator(
                task_id="test_retry_handling_op",
                bash_command="exit 1",
                retries=1,
                retry_delay=datetime.timedelta(seconds=3),
            )

        def run_with_error(ti):
            # Every run fails; swallow the error so we can inspect the state.
            with contextlib.suppress(AirflowException):
                ti.run()

        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
        ti.task = task

        assert ti.try_number == 1
        # first run -- up for retry
        run_with_error(ti)
        assert ti.state == State.UP_FOR_RETRY
        assert ti.try_number == 2

        # second run -- still up for retry because retry_delay hasn't expired
        time_machine.coordinates.shift(3)
        run_with_error(ti)
        assert ti.state == State.UP_FOR_RETRY

        # third run -- failed
        time_machine.coordinates.shift(datetime.datetime.resolution)
        run_with_error(ti)
        assert ti.state == State.FAILED
| |
    def test_retry_handling(self, dag_maker):
        """
        Test that task retries are handled properly.

        Checks the public ``try_number`` / private ``_try_number`` bookkeeping
        across two fail->retry->fail cycles separated by a ``dag.clear()``, and
        that the rendered template fields are persisted.
        """
        expected_rendered_ti_fields = {
            "env": None,
            "bash_command": "echo test_retry_handling; exit 1",
            "cwd": None,
        }

        with dag_maker(dag_id="test_retry_handling") as dag:
            task = BashOperator(
                task_id="test_retry_handling_op",
                bash_command="echo {{dag.dag_id}}; exit 1",
                retries=1,
                retry_delay=datetime.timedelta(seconds=0),
            )

        def run_with_error(ti):
            # Every run fails; swallow the error so we can inspect the state.
            with contextlib.suppress(AirflowException):
                ti.run()

        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
        ti.task = task
        assert ti.try_number == 1

        # first run -- up for retry
        run_with_error(ti)
        assert ti.state == State.UP_FOR_RETRY
        assert ti._try_number == 1
        assert ti.try_number == 2

        # second run -- fail
        run_with_error(ti)
        assert ti.state == State.FAILED
        assert ti._try_number == 2
        assert ti.try_number == 3

        # Clear the TI state since you can't run a task with a FAILED state without
        # clearing it first
        dag.clear()

        # third run -- up for retry
        run_with_error(ti)
        assert ti.state == State.UP_FOR_RETRY
        assert ti._try_number == 3
        assert ti.try_number == 4

        # fourth run -- fail
        run_with_error(ti)
        ti.refresh_from_db()
        assert ti.state == State.FAILED
        assert ti._try_number == 4
        assert ti.try_number == 5
        assert RenderedTaskInstanceFields.get_templated_fields(ti) == expected_rendered_ti_fields
| |
    def test_next_retry_datetime(self, dag_maker):
        """Exponential backoff yields randomized, growing delays capped at max_retry_delay."""
        delay = datetime.timedelta(seconds=30)
        max_delay = datetime.timedelta(minutes=60)

        with dag_maker(dag_id="fail_dag"):
            task = BashOperator(
                task_id="task_with_exp_backoff_and_max_delay",
                bash_command="exit 1",
                retries=3,
                retry_delay=delay,
                retry_exponential_backoff=True,
                max_retry_delay=max_delay,
            )
        ti = dag_maker.create_dagrun().task_instances[0]
        ti.task = task
        ti.end_date = pendulum.instance(timezone.utcnow())

        date = ti.next_retry_datetime()
        # randomized backoff window: between 15s and 30s after end_date
        # (``date in period`` checks containment in that datetime range)
        period = ti.end_date.add(seconds=30) - ti.end_date.add(seconds=15)
        assert date in period

        ti.try_number = 3
        date = ti.next_retry_datetime()
        # between 30 * 2^2 and 30 * 2^3 (120 and 240)
        period = ti.end_date.add(seconds=240) - ti.end_date.add(seconds=120)
        assert date in period

        ti.try_number = 5
        date = ti.next_retry_datetime()
        # between 30 * 2^4 and 30 * 2^5 (480 and 960)
        period = ti.end_date.add(seconds=960) - ti.end_date.add(seconds=480)
        assert date in period

        # Beyond the cap, the delay is exactly max_retry_delay.
        ti.try_number = 9
        date = ti.next_retry_datetime()
        assert date == ti.end_date + max_delay

        ti.try_number = 50
        date = ti.next_retry_datetime()
        assert date == ti.end_date + max_delay
| |
| @pytest.mark.parametrize("seconds", [0, 0.5, 1]) |
| def test_next_retry_datetime_short_or_zero_intervals(self, dag_maker, seconds): |
| delay = datetime.timedelta(seconds=seconds) |
| max_delay = datetime.timedelta(minutes=60) |
| |
| with dag_maker(dag_id="fail_dag"): |
| task = BashOperator( |
| task_id="task_with_exp_backoff_and_short_or_zero_time_interval", |
| bash_command="exit 1", |
| retries=3, |
| retry_delay=delay, |
| retry_exponential_backoff=True, |
| max_retry_delay=max_delay, |
| ) |
| ti = dag_maker.create_dagrun().task_instances[0] |
| ti.task = task |
| ti.end_date = pendulum.instance(timezone.utcnow()) |
| |
| date = ti.next_retry_datetime() |
| assert date == ti.end_date + datetime.timedelta(seconds=1) |
| |
    def test_reschedule_handling(self, dag_maker, task_reschedules_for_ti):
        """
        Test that task reschedules are handled properly
        """
        # Return values of the python sensor callable, modified during tests
        done = False
        fail = False

        def func():
            if fail:
                raise AirflowException()
            return done

        with dag_maker(dag_id="test_reschedule_handling") as dag:
            task = PythonSensor(
                task_id="test_reschedule_handling_sensor",
                poke_interval=0,
                mode="reschedule",
                python_callable=func,
                retries=1,
                retry_delay=datetime.timedelta(seconds=0),
            )

        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
        ti.task = task
        assert ti._try_number == 0
        assert ti.try_number == 1

        def run_ti_and_assert(
            run_date,
            expected_start_date,
            expected_end_date,
            expected_duration,
            expected_state,
            expected_try_number,
            expected_task_reschedule_count,
        ):
            """Run the TI with the clock frozen at ``run_date`` and verify its bookkeeping."""
            with time_machine.travel(run_date, tick=False):
                try:
                    ti.run()
                except AirflowException:
                    # Failure is only expected on the "fail" iterations.
                    if not fail:
                        raise
            ti.refresh_from_db()
            assert ti.state == expected_state
            assert ti._try_number == expected_try_number
            assert ti.try_number == expected_try_number + 1
            assert ti.start_date == expected_start_date
            assert ti.end_date == expected_end_date
            assert ti.duration == expected_duration
            assert len(task_reschedules_for_ti(ti)) == expected_task_reschedule_count

        date1 = timezone.utcnow()
        date2 = date1 + datetime.timedelta(minutes=1)
        date3 = date2 + datetime.timedelta(minutes=1)
        date4 = date3 + datetime.timedelta(minutes=1)

        # Run with multiple reschedules.
        # During reschedule the try number remains the same, but each reschedule is recorded.
        # The start date is expected to remain the initial date, hence the duration increases.
        # When finished the try number is incremented and there is no reschedule expected
        # for this try.

        done, fail = False, False
        run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 0, 1)

        done, fail = False, False
        run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RESCHEDULE, 0, 2)

        done, fail = False, False
        run_ti_and_assert(date3, date1, date3, 120, State.UP_FOR_RESCHEDULE, 0, 3)

        done, fail = True, False
        run_ti_and_assert(date4, date1, date4, 180, State.SUCCESS, 1, 0)

        # Clear the task instance.
        dag.clear()
        ti.refresh_from_db()
        assert ti.state == State.NONE
        assert ti._try_number == 1

        # Run again after clearing with reschedules and a retry.
        # The retry increments the try number, and for that try no reschedule is expected.
        # After the retry the start date is reset, hence the duration is also reset.

        done, fail = False, False
        run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 1, 1)

        done, fail = False, True
        run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RETRY, 2, 0)

        done, fail = False, False
        run_ti_and_assert(date3, date3, date3, 0, State.UP_FOR_RESCHEDULE, 2, 1)

        done, fail = True, False
        run_ti_and_assert(date4, date3, date4, 60, State.SUCCESS, 3, 0)
| |
    def test_mapped_reschedule_handling(self, dag_maker, task_reschedules_for_ti):
        """
        Test that mapped task reschedules are handled properly
        """
        # Return values of the python sensor callable, modified during tests
        done = False
        fail = False

        def func():
            if fail:
                raise AirflowException()
            return done

        with dag_maker(dag_id="test_reschedule_handling") as dag:
            task = PythonSensor.partial(
                task_id="test_reschedule_handling_sensor",
                mode="reschedule",
                python_callable=func,
                retries=1,
                retry_delay=datetime.timedelta(seconds=0),
            ).expand(poke_interval=[0])

        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]

        ti.task = task
        assert ti._try_number == 0
        assert ti.try_number == 1

        def run_ti_and_assert(
            run_date,
            expected_start_date,
            expected_end_date,
            expected_duration,
            expected_state,
            expected_try_number,
            expected_task_reschedule_count,
        ):
            """Run the TI with the clock frozen at ``run_date`` and verify its bookkeeping."""
            # Re-bind the (mapped) task before each run.
            ti.refresh_from_task(task)
            with time_machine.travel(run_date, tick=False):
                try:
                    ti.run()
                except AirflowException:
                    # Failure is only expected on the "fail" iterations.
                    if not fail:
                        raise
            ti.refresh_from_db()
            assert ti.state == expected_state
            assert ti._try_number == expected_try_number
            assert ti.try_number == expected_try_number + 1
            assert ti.start_date == expected_start_date
            assert ti.end_date == expected_end_date
            assert ti.duration == expected_duration
            assert len(task_reschedules_for_ti(ti)) == expected_task_reschedule_count

        date1 = timezone.utcnow()
        date2 = date1 + datetime.timedelta(minutes=1)
        date3 = date2 + datetime.timedelta(minutes=1)
        date4 = date3 + datetime.timedelta(minutes=1)

        # Run with multiple reschedules.
        # During reschedule the try number remains the same, but each reschedule is recorded.
        # The start date is expected to remain the initial date, hence the duration increases.
        # When finished the try number is incremented and there is no reschedule expected
        # for this try.

        done, fail = False, False
        run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 0, 1)

        done, fail = False, False
        run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RESCHEDULE, 0, 2)

        done, fail = False, False
        run_ti_and_assert(date3, date1, date3, 120, State.UP_FOR_RESCHEDULE, 0, 3)

        done, fail = True, False
        run_ti_and_assert(date4, date1, date4, 180, State.SUCCESS, 1, 0)

        # Clear the task instance.
        dag.clear()
        ti.refresh_from_db()
        assert ti.state == State.NONE
        assert ti._try_number == 1

        # Run again after clearing with reschedules and a retry.
        # The retry increments the try number, and for that try no reschedule is expected.
        # After the retry the start date is reset, hence the duration is also reset.

        done, fail = False, False
        run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 1, 1)

        done, fail = False, True
        run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RETRY, 2, 0)

        done, fail = False, False
        run_ti_and_assert(date3, date3, date3, 0, State.UP_FOR_RESCHEDULE, 2, 1)

        done, fail = True, False
        run_ti_and_assert(date4, date3, date4, 60, State.SUCCESS, 3, 0)
| |
    @pytest.mark.usefixtures("test_pool")
    def test_mapped_task_reschedule_handling_clear_reschedules(self, dag_maker, task_reschedules_for_ti):
        """
        Test that mapped task reschedules clearing are handled properly
        """
        # Return values of the python sensor callable, modified during tests
        done = False
        fail = False

        def func():
            if fail:
                raise AirflowException()
            return done

        with dag_maker(dag_id="test_reschedule_handling") as dag:
            task = PythonSensor.partial(
                task_id="test_reschedule_handling_sensor",
                mode="reschedule",
                python_callable=func,
                retries=1,
                retry_delay=datetime.timedelta(seconds=0),
                pool="test_pool",
            ).expand(poke_interval=[0])
        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
        ti.task = task
        assert ti._try_number == 0
        assert ti.try_number == 1

        def run_ti_and_assert(
            run_date,
            expected_start_date,
            expected_end_date,
            expected_duration,
            expected_state,
            expected_try_number,
            expected_task_reschedule_count,
        ):
            """Run the TI with the clock frozen at ``run_date`` and verify its bookkeeping."""
            # Re-bind the (mapped) task before each run.
            ti.refresh_from_task(task)
            with time_machine.travel(run_date, tick=False):
                try:
                    ti.run()
                except AirflowException:
                    # Failure is only expected on the "fail" iterations.
                    if not fail:
                        raise
            ti.refresh_from_db()
            assert ti.state == expected_state
            assert ti._try_number == expected_try_number
            assert ti.try_number == expected_try_number + 1
            assert ti.start_date == expected_start_date
            assert ti.end_date == expected_end_date
            assert ti.duration == expected_duration
            assert len(task_reschedules_for_ti(ti)) == expected_task_reschedule_count

        date1 = timezone.utcnow()

        done, fail = False, False
        run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 0, 1)

        # Clear the task instance.
        dag.clear()
        ti.refresh_from_db()
        assert ti.state == State.NONE
        assert ti._try_number == 0
        # Check that reschedules for ti have also been cleared.
        assert not task_reschedules_for_ti(ti)
| |
    @pytest.mark.usefixtures("test_pool")
    def test_reschedule_handling_clear_reschedules(self, dag_maker, task_reschedules_for_ti):
        """
        Test that task reschedules clearing are handled properly.

        Unmapped counterpart of the mapped-sensor test: run a
        reschedule-mode sensor once, clear the dag, and verify state,
        ``_try_number`` and TaskReschedule rows are all reset.
        """
        # Return values of the python sensor callable, modified during tests
        done = False
        fail = False

        def func():
            # Sensor callable: raise to simulate failure, otherwise report
            # whether the sensor condition is met (closure over done/fail).
            if fail:
                raise AirflowException()
            return done

        with dag_maker(dag_id="test_reschedule_handling") as dag:
            task = PythonSensor(
                task_id="test_reschedule_handling_sensor",
                poke_interval=0,
                mode="reschedule",
                python_callable=func,
                retries=1,
                retry_delay=datetime.timedelta(seconds=0),
                pool="test_pool",
            )
        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
        ti.task = task
        assert ti._try_number == 0
        assert ti.try_number == 1

        def run_ti_and_assert(
            run_date,
            expected_start_date,
            expected_end_date,
            expected_duration,
            expected_state,
            expected_try_number,
            expected_task_reschedule_count,
        ):
            # Run the TI at a frozen point in time and verify the resulting
            # DB state (state, try numbers, timing, reschedule rows).
            with time_machine.travel(run_date, tick=False):
                try:
                    ti.run()
                except AirflowException:
                    # Only swallow the exception the callable raised on purpose.
                    if not fail:
                        raise
            ti.refresh_from_db()
            assert ti.state == expected_state
            assert ti._try_number == expected_try_number
            assert ti.try_number == expected_try_number + 1
            assert ti.start_date == expected_start_date
            assert ti.end_date == expected_end_date
            assert ti.duration == expected_duration
            assert len(task_reschedules_for_ti(ti)) == expected_task_reschedule_count

        date1 = timezone.utcnow()

        # First poke: not done -> UP_FOR_RESCHEDULE with one reschedule row.
        done, fail = False, False
        run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 0, 1)

        # Clear the task instance.
        dag.clear()
        ti.refresh_from_db()
        assert ti.state == State.NONE
        assert ti._try_number == 0
        # Check that reschedules for ti have also been cleared.
        assert not task_reschedules_for_ti(ti)
| |
| def test_depends_on_past(self, dag_maker): |
| with dag_maker(dag_id="test_depends_on_past"): |
| task = EmptyOperator( |
| task_id="test_dop_task", |
| depends_on_past=True, |
| ) |
| dag_maker.create_dagrun( |
| state=State.FAILED, |
| run_type=DagRunType.SCHEDULED, |
| ) |
| |
| run_date = task.start_date + datetime.timedelta(days=5) |
| |
| dr = dag_maker.create_dagrun( |
| execution_date=run_date, |
| run_type=DagRunType.SCHEDULED, |
| ) |
| |
| ti = dr.task_instances[0] |
| ti.task = task |
| |
| # depends_on_past prevents the run |
| task.run(start_date=run_date, end_date=run_date, ignore_first_depends_on_past=False) |
| ti.refresh_from_db() |
| assert ti.state is None |
| |
| # ignore first depends_on_past to allow the run |
| task.run(start_date=run_date, end_date=run_date, ignore_first_depends_on_past=True) |
| ti.refresh_from_db() |
| assert ti.state == State.SUCCESS |
| |
    # Parameterized tests to check for the correct firing
    # of the trigger_rule under various circumstances
    # Numeric fields are in order:
    # successes, skipped, failed, upstream_failed, removed, done, success_setup, skipped_setup
    @pytest.mark.parametrize(
        "trigger_rule, upstream_setups, upstream_states, flag_upstream_failed, expect_state, expect_passed",
        [
            #
            # Tests for all_success
            #
            ["all_success", 0, _UpstreamTIStates(5, 0, 0, 0, 0, 5, 0, 0), True, None, True],
            ["all_success", 0, _UpstreamTIStates(2, 0, 0, 0, 0, 2, 0, 0), True, None, False],
            ["all_success", 0, _UpstreamTIStates(2, 0, 1, 0, 0, 3, 0, 0), True, State.UPSTREAM_FAILED, False],
            ["all_success", 0, _UpstreamTIStates(2, 1, 0, 0, 0, 3, 0, 0), True, State.SKIPPED, False],
            #
            # Tests for one_success
            #
            ["one_success", 0, _UpstreamTIStates(5, 0, 0, 0, 0, 5, 0, 0), True, None, True],
            ["one_success", 0, _UpstreamTIStates(2, 0, 0, 0, 0, 2, 0, 0), True, None, True],
            ["one_success", 0, _UpstreamTIStates(2, 0, 1, 0, 0, 3, 0, 0), True, None, True],
            ["one_success", 0, _UpstreamTIStates(2, 1, 0, 0, 0, 3, 0, 0), True, None, True],
            ["one_success", 0, _UpstreamTIStates(0, 5, 0, 0, 0, 5, 0, 0), True, State.SKIPPED, False],
            ["one_success", 0, _UpstreamTIStates(0, 4, 1, 0, 0, 5, 0, 0), True, State.UPSTREAM_FAILED, False],
            ["one_success", 0, _UpstreamTIStates(0, 3, 1, 1, 0, 5, 0, 0), True, State.UPSTREAM_FAILED, False],
            ["one_success", 0, _UpstreamTIStates(0, 4, 0, 1, 0, 5, 0, 0), True, State.UPSTREAM_FAILED, False],
            ["one_success", 0, _UpstreamTIStates(0, 0, 5, 0, 0, 5, 0, 0), True, State.UPSTREAM_FAILED, False],
            ["one_success", 0, _UpstreamTIStates(0, 0, 4, 1, 0, 5, 0, 0), True, State.UPSTREAM_FAILED, False],
            ["one_success", 0, _UpstreamTIStates(0, 0, 0, 5, 0, 5, 0, 0), True, State.UPSTREAM_FAILED, False],
            #
            # Tests for all_failed
            #
            ["all_failed", 0, _UpstreamTIStates(5, 0, 0, 0, 0, 5, 0, 0), True, State.SKIPPED, False],
            ["all_failed", 0, _UpstreamTIStates(0, 0, 5, 0, 0, 5, 0, 0), True, None, True],
            ["all_failed", 0, _UpstreamTIStates(2, 0, 0, 0, 0, 2, 0, 0), True, State.SKIPPED, False],
            ["all_failed", 0, _UpstreamTIStates(2, 0, 1, 0, 0, 3, 0, 0), True, State.SKIPPED, False],
            ["all_failed", 0, _UpstreamTIStates(2, 1, 0, 0, 0, 3, 0, 0), True, State.SKIPPED, False],
            #
            # Tests for one_failed
            #
            ["one_failed", 0, _UpstreamTIStates(5, 0, 0, 0, 0, 5, 0, 0), True, State.SKIPPED, False],
            ["one_failed", 0, _UpstreamTIStates(2, 0, 0, 0, 0, 2, 0, 0), True, None, False],
            ["one_failed", 0, _UpstreamTIStates(2, 0, 1, 0, 0, 3, 0, 0), True, None, True],
            ["one_failed", 0, _UpstreamTIStates(2, 1, 0, 0, 0, 3, 0, 0), True, None, False],
            ["one_failed", 0, _UpstreamTIStates(2, 3, 0, 0, 0, 5, 0, 0), True, State.SKIPPED, False],
            #
            # Tests for done
            #
            ["all_done", 0, _UpstreamTIStates(5, 0, 0, 0, 0, 5, 0, 0), True, None, True],
            ["all_done", 0, _UpstreamTIStates(2, 0, 0, 0, 0, 2, 0, 0), True, None, False],
            ["all_done", 0, _UpstreamTIStates(2, 0, 1, 0, 0, 3, 0, 0), True, None, False],
            ["all_done", 0, _UpstreamTIStates(2, 1, 0, 0, 0, 3, 0, 0), True, None, False],
            #
            # Tests for all_done_setup_success: no upstream setups -> same as all_done
            #
            ["all_done_setup_success", 0, _UpstreamTIStates(5, 0, 0, 0, 0, 5, 0, 0), True, None, True],
            ["all_done_setup_success", 0, _UpstreamTIStates(2, 0, 0, 0, 0, 2, 0, 0), True, None, False],
            ["all_done_setup_success", 0, _UpstreamTIStates(2, 0, 1, 0, 0, 3, 0, 0), True, None, False],
            ["all_done_setup_success", 0, _UpstreamTIStates(2, 1, 0, 0, 0, 3, 0, 0), True, None, False],
            #
            # Tests for all_done_setup_success: with upstream setups -> different from all_done
            #
            # params:
            #   trigger_rule
            #   upstream_setups
            #   upstream_states
            #   flag_upstream_failed
            #   expect_state  (may be a (set_teardown, expected_state) tuple; see test body)
            #   expect_passed
            # states: success, skipped, failed, upstream_failed, removed, done, success_setup, skipped_setup
            # all setups succeeded - one
            pytest.param(
                "all_done_setup_success",
                1,
                _UpstreamTIStates(6, 0, 0, 0, 0, 6, 1, 0),
                True,
                None,
                True,
                id="all setups succeeded - one",
            ),
            pytest.param(
                "all_done_setup_success",
                2,
                _UpstreamTIStates(7, 0, 0, 0, 0, 7, 2, 0),
                True,
                None,
                True,
                id="all setups succeeded - two",
            ),
            pytest.param(
                "all_done_setup_success",
                1,
                _UpstreamTIStates(5, 0, 1, 0, 0, 6, 0, 0),
                True,
                State.UPSTREAM_FAILED,
                False,
                id="setups failed - one",
            ),
            pytest.param(
                "all_done_setup_success",
                2,
                _UpstreamTIStates(5, 0, 2, 0, 0, 7, 0, 0),
                True,
                State.UPSTREAM_FAILED,
                False,
                id="setups failed - two",
            ),
            pytest.param(
                "all_done_setup_success",
                1,
                _UpstreamTIStates(5, 1, 0, 0, 0, 6, 0, 1),
                True,
                State.SKIPPED,
                False,
                id="setups skipped - one",
            ),
            pytest.param(
                "all_done_setup_success",
                2,
                _UpstreamTIStates(5, 2, 0, 0, 0, 7, 0, 2),
                True,
                State.SKIPPED,
                False,
                id="setups skipped - two",
            ),
            pytest.param(
                "all_done_setup_success",
                2,
                _UpstreamTIStates(5, 1, 1, 0, 0, 7, 0, 1),
                True,
                State.UPSTREAM_FAILED,
                False,
                id="one setup failed one setup skipped",
            ),
            pytest.param(
                "all_done_setup_success",
                2,
                _UpstreamTIStates(6, 0, 1, 0, 0, 7, 1, 0),
                True,
                (True, None),  # is_teardown=True, expect_state=None
                True,
                id="is teardown one setup failed one setup success",
            ),
            pytest.param(
                "all_done_setup_success",
                2,
                _UpstreamTIStates(6, 0, 1, 0, 0, 7, 1, 0),
                True,
                (False, "upstream_failed"),  # is_teardown=False, expect_state="upstream_failed"
                True,
                id="not teardown one setup failed one setup success",
            ),
            pytest.param(
                "all_done_setup_success",
                2,
                _UpstreamTIStates(6, 1, 0, 0, 0, 7, 1, 1),
                True,
                (True, None),  # is_teardown=True, expect_state=None
                True,
                id="is teardown one setup success one setup skipped",
            ),
            pytest.param(
                "all_done_setup_success",
                2,
                _UpstreamTIStates(6, 1, 0, 0, 0, 7, 1, 1),
                True,
                (False, "skipped"),  # is_teardown=False, expect_state="skipped"
                True,
                id="not teardown one setup success one setup skipped",
            ),
            pytest.param(
                "all_done_setup_success",
                1,
                _UpstreamTIStates(3, 0, 0, 0, 0, 3, 1, 0),
                True,
                None,
                False,
                id="not all done",
            ),
            pytest.param(
                "all_done_setup_success",
                1,
                _UpstreamTIStates(3, 0, 1, 0, 0, 4, 1, 0),
                True,
                (True, None),  # is_teardown=True, expect_state=None
                False,
                id="is teardown not all done one failed",
            ),
            pytest.param(
                "all_done_setup_success",
                1,
                _UpstreamTIStates(3, 0, 1, 0, 0, 4, 1, 0),
                True,
                (False, "upstream_failed"),  # is_teardown=False, expect_state="upstream_failed"
                False,
                id="not teardown not all done one failed",
            ),
            # NOTE(review): the next two params share the id "not all done one skipped";
            # pytest will disambiguate them with an index suffix, but the first one
            # probably should be "is teardown not all done one skipped" — confirm.
            pytest.param(
                "all_done_setup_success",
                1,
                _UpstreamTIStates(3, 1, 0, 0, 0, 4, 1, 0),
                True,
                (True, None),  # is_teardown=True, expect_state=None
                False,
                id="not all done one skipped",
            ),
            pytest.param(
                "all_done_setup_success",
                1,
                _UpstreamTIStates(3, 1, 0, 0, 0, 4, 1, 0),
                True,
                (False, "skipped"),  # is_teardown=False, expect_state="skipped"
                False,
                id="not all done one skipped",
            ),
        ],
    )
    def test_check_task_dependencies(
        self,
        monkeypatch,
        dag_maker,
        trigger_rule: str,
        upstream_setups: int,
        upstream_states: _UpstreamTIStates,
        flag_upstream_failed: bool,
        expect_state: State,
        expect_passed: bool,
    ):
        """Verify TriggerRuleDep pass/state evaluation for each trigger rule.

        ``expect_state`` may be a ``(set_teardown, expected_state)`` tuple; in
        that case the downstream task is first converted to a teardown when
        the flag is true, so the same upstream states can be checked with and
        without teardown semantics.
        """
        # this allows us to change the expected state depending on whether the
        # task is a teardown
        set_teardown = False
        if isinstance(expect_state, tuple):
            set_teardown, expect_state = expect_state
            assert isinstance(set_teardown, bool)

        # Bypass real upstream-TI counting: return the parametrized states directly.
        monkeypatch.setattr(_UpstreamTIStates, "calculate", lambda *_: upstream_states)

        # sanity checks
        s = upstream_states
        assert s.skipped >= s.skipped_setup
        assert s.success >= s.success_setup
        assert s.done == s.failed + s.success + s.removed + s.upstream_failed + s.skipped

        with dag_maker() as dag:
            downstream = EmptyOperator(task_id="downstream", trigger_rule=trigger_rule)
            if set_teardown:
                downstream.as_teardown()
            # Five regular upstream work tasks plus the parametrized number of setups.
            for i in range(5):
                task = EmptyOperator(task_id=f"work_{i}", dag=dag)
                task.set_downstream(downstream)
            for i in range(upstream_setups):
                task = EmptyOperator(task_id=f"setup_{i}", dag=dag).as_setup()
                task.set_downstream(downstream)
        assert task.start_date is not None
        run_date = task.start_date + datetime.timedelta(days=5)

        ti = dag_maker.create_dagrun(execution_date=run_date).get_task_instance(downstream.task_id)
        ti.task = downstream

        dep_results = TriggerRuleDep()._evaluate_trigger_rule(
            ti=ti,
            dep_context=DepContext(flag_upstream_failed=flag_upstream_failed),
            session=dag_maker.session,
        )
        completed = all(dep.passed for dep in dep_results)

        assert completed == expect_passed
        assert ti.state == expect_state
| |
    # Parameterized tests to check for the correct firing
    # of the trigger_rule under various circumstances of mapped task
    # Numeric fields are in order:
    # successes, skipped, failed, upstream_failed, removed, done, success_setup, skipped_setup
    @pytest.mark.parametrize(
        "trigger_rule, upstream_states, flag_upstream_failed, expect_state, expect_completed",
        [
            #
            # Tests for all_success
            #
            ["all_success", _UpstreamTIStates(5, 0, 0, 0, 0, 0, 0, 0), True, None, True],
            ["all_success", _UpstreamTIStates(2, 0, 0, 0, 0, 0, 0, 0), True, None, False],
            ["all_success", _UpstreamTIStates(2, 0, 1, 0, 0, 0, 0, 0), True, State.UPSTREAM_FAILED, False],
            ["all_success", _UpstreamTIStates(2, 1, 0, 0, 0, 0, 0, 0), True, State.SKIPPED, False],
            # ti.map_index >= success
            ["all_success", _UpstreamTIStates(3, 0, 0, 0, 2, 0, 0, 0), True, State.REMOVED, True],
            #
            # Tests for one_success
            #
            ["one_success", _UpstreamTIStates(5, 0, 0, 0, 0, 5, 0, 0), True, None, True],
            ["one_success", _UpstreamTIStates(2, 0, 0, 0, 0, 2, 0, 0), True, None, True],
            ["one_success", _UpstreamTIStates(2, 0, 1, 0, 0, 3, 0, 0), True, None, True],
            ["one_success", _UpstreamTIStates(2, 1, 0, 0, 0, 3, 0, 0), True, None, True],
            ["one_success", _UpstreamTIStates(0, 5, 0, 0, 0, 5, 0, 0), True, State.SKIPPED, False],
            ["one_success", _UpstreamTIStates(0, 4, 1, 0, 0, 5, 0, 0), True, State.UPSTREAM_FAILED, False],
            ["one_success", _UpstreamTIStates(0, 3, 1, 1, 0, 5, 0, 0), True, State.UPSTREAM_FAILED, False],
            ["one_success", _UpstreamTIStates(0, 4, 0, 1, 0, 5, 0, 0), True, State.UPSTREAM_FAILED, False],
            ["one_success", _UpstreamTIStates(0, 0, 5, 0, 0, 5, 0, 0), True, State.UPSTREAM_FAILED, False],
            ["one_success", _UpstreamTIStates(0, 0, 4, 1, 0, 5, 0, 0), True, State.UPSTREAM_FAILED, False],
            ["one_success", _UpstreamTIStates(0, 0, 0, 5, 0, 5, 0, 0), True, State.UPSTREAM_FAILED, False],
            #
            # Tests for all_failed
            #
            ["all_failed", _UpstreamTIStates(5, 0, 0, 0, 0, 5, 0, 0), True, State.SKIPPED, False],
            ["all_failed", _UpstreamTIStates(0, 0, 5, 0, 0, 5, 0, 0), True, None, True],
            ["all_failed", _UpstreamTIStates(2, 0, 0, 0, 0, 2, 0, 0), True, State.SKIPPED, False],
            ["all_failed", _UpstreamTIStates(2, 0, 1, 0, 0, 3, 0, 0), True, State.SKIPPED, False],
            ["all_failed", _UpstreamTIStates(2, 1, 0, 0, 0, 3, 0, 0), True, State.SKIPPED, False],
            [
                "all_failed",
                _UpstreamTIStates(2, 1, 0, 0, 1, 4, 0, 0),
                True,
                State.SKIPPED,
                False,
            ],  # One removed
            #
            # Tests for one_failed
            #
            ["one_failed", _UpstreamTIStates(5, 0, 0, 0, 0, 0, 0, 0), True, None, False],
            ["one_failed", _UpstreamTIStates(2, 0, 0, 0, 0, 0, 0, 0), True, None, False],
            ["one_failed", _UpstreamTIStates(2, 0, 1, 0, 0, 0, 0, 0), True, None, True],
            ["one_failed", _UpstreamTIStates(2, 1, 0, 0, 0, 3, 0, 0), True, None, False],
            ["one_failed", _UpstreamTIStates(2, 3, 0, 0, 0, 5, 0, 0), True, State.SKIPPED, False],
            [
                "one_failed",
                _UpstreamTIStates(2, 2, 0, 0, 1, 5, 0, 0),
                True,
                State.SKIPPED,
                False,
            ],  # One removed
            #
            # Tests for done
            #
            ["all_done", _UpstreamTIStates(5, 0, 0, 0, 0, 5, 0, 0), True, None, True],
            ["all_done", _UpstreamTIStates(2, 0, 0, 0, 0, 2, 0, 0), True, None, False],
            ["all_done", _UpstreamTIStates(2, 0, 1, 0, 0, 3, 0, 0), True, None, False],
            ["all_done", _UpstreamTIStates(2, 1, 0, 0, 0, 3, 0, 0), True, None, False],
        ],
    )
    def test_check_task_dependencies_for_mapped(
        self,
        monkeypatch,
        dag_maker,
        session,
        trigger_rule: str,
        upstream_states: _UpstreamTIStates,
        flag_upstream_failed: bool,
        expect_state: State,
        expect_completed: bool,
    ):
        """Verify TriggerRuleDep evaluation for a mapped downstream task.

        Five map indexes of ``do_something_else`` are materialized and the
        dep is evaluated for map_index=3 with the parametrized (monkeypatched)
        upstream state counts.
        """
        from airflow.decorators import task

        @task
        def do_something(i):
            return 1

        @task(trigger_rule=trigger_rule)
        def do_something_else(i):
            return 1

        with dag_maker(dag_id="test_dag", session=session):
            nums = do_something.expand(i=[i + 1 for i in range(5)])
            do_something_else.expand(i=nums)

        dr = dag_maker.create_dagrun()

        # Bypass real upstream-TI counting: return the parametrized states directly.
        monkeypatch.setattr(_UpstreamTIStates, "calculate", lambda *_: upstream_states)
        ti = dr.get_task_instance("do_something_else", session=session)
        ti.map_index = 0
        # Create the remaining map indexes (1-4) by hand so the dep can pick
        # any of them; the dagrun initially materializes only one TI.
        for map_index in range(1, 5):
            ti = TaskInstance(ti.task, run_id=dr.run_id, map_index=map_index)
            session.add(ti)
            ti.dag_run = dr
        session.flush()
        downstream = ti.task
        ti = dr.get_task_instance(task_id="do_something_else", map_index=3, session=session)
        ti.task = downstream
        dep_results = TriggerRuleDep()._evaluate_trigger_rule(
            ti=ti,
            dep_context=DepContext(flag_upstream_failed=flag_upstream_failed),
            session=session,
        )
        completed = all(dep.passed for dep in dep_results)

        assert completed == expect_completed
        assert ti.state == expect_state
| |
| def test_respects_prev_dagrun_dep(self, create_task_instance): |
| ti = create_task_instance() |
| failing_status = [TIDepStatus("test fail status name", False, "test fail reason")] |
| passing_status = [TIDepStatus("test pass status name", True, "test passing reason")] |
| with patch( |
| "airflow.ti_deps.deps.prev_dagrun_dep.PrevDagrunDep.get_dep_statuses", return_value=failing_status |
| ): |
| assert not ti.are_dependencies_met() |
| with patch( |
| "airflow.ti_deps.deps.prev_dagrun_dep.PrevDagrunDep.get_dep_statuses", return_value=passing_status |
| ): |
| assert ti.are_dependencies_met() |
| |
| @pytest.mark.parametrize( |
| "downstream_ti_state, expected_are_dependents_done", |
| [ |
| (State.SUCCESS, True), |
| (State.SKIPPED, True), |
| (State.RUNNING, False), |
| (State.FAILED, False), |
| (State.NONE, False), |
| ], |
| ) |
| @provide_session |
| def test_are_dependents_done( |
| self, downstream_ti_state, expected_are_dependents_done, create_task_instance, session=None |
| ): |
| ti = create_task_instance(session=session) |
| dag = ti.task.dag |
| downstream_task = EmptyOperator(task_id="downstream_task", dag=dag) |
| ti.task >> downstream_task |
| |
| downstream_ti = TI(downstream_task, run_id=ti.run_id) |
| |
| downstream_ti.set_state(downstream_ti_state, session) |
| session.flush() |
| assert ti.are_dependents_done(session) == expected_are_dependents_done |
| |
| def test_xcom_pull(self, dag_maker): |
| """Test xcom_pull, using different filtering methods.""" |
| with dag_maker(dag_id="test_xcom") as dag: |
| task_1 = EmptyOperator(task_id="test_xcom_1") |
| task_2 = EmptyOperator(task_id="test_xcom_2") |
| |
| dagrun = dag_maker.create_dagrun(start_date=timezone.datetime(2016, 6, 1, 0, 0, 0)) |
| ti1 = dagrun.get_task_instance(task_1.task_id) |
| |
| # Push a value |
| ti1.xcom_push(key="foo", value="bar") |
| |
| # Push another value with the same key (but by a different task) |
| XCom.set( |
| key="foo", |
| value="baz", |
| task_id=task_2.task_id, |
| dag_id=dag.dag_id, |
| execution_date=dagrun.execution_date, |
| ) |
| |
| # Pull with no arguments |
| result = ti1.xcom_pull() |
| assert result is None |
| # Pull the value pushed most recently by any task. |
| result = ti1.xcom_pull(key="foo") |
| assert result in "baz" |
| # Pull the value pushed by the first task |
| result = ti1.xcom_pull(task_ids="test_xcom_1", key="foo") |
| assert result == "bar" |
| # Pull the value pushed by the second task |
| result = ti1.xcom_pull(task_ids="test_xcom_2", key="foo") |
| assert result == "baz" |
| # Pull the values pushed by both tasks & Verify Order of task_ids pass & values returned |
| result = ti1.xcom_pull(task_ids=["test_xcom_1", "test_xcom_2"], key="foo") |
| assert result == ["bar", "baz"] |
| |
    def test_xcom_pull_mapped(self, dag_maker, session):
        """xcom_pull against a mapped upstream returns per-map-index values,
        honoring ``map_indexes`` scalars, lists and their ordering."""
        with dag_maker(dag_id="test_xcom", session=session):
            # Use the private _expand() method to avoid the empty kwargs check.
            # We don't care about how the operator runs here, only its presence.
            task_1 = EmptyOperator.partial(task_id="task_1")._expand(EXPAND_INPUT_EMPTY, strict=False)
            EmptyOperator(task_id="task_2")

        dagrun = dag_maker.create_dagrun(start_date=timezone.datetime(2016, 6, 1, 0, 0, 0))

        # Turn the single materialized TI into map_index 0 and add a second
        # mapped TI (map_index 1) by hand.
        ti_1_0 = dagrun.get_task_instance("task_1", session=session)
        ti_1_0.map_index = 0
        ti_1_1 = session.merge(TI(task_1, run_id=dagrun.run_id, map_index=1, state=ti_1_0.state))
        session.flush()

        # One return-value XCom per map index.
        ti_1_0.xcom_push(key=XCOM_RETURN_KEY, value="a", session=session)
        ti_1_1.xcom_push(key=XCOM_RETURN_KEY, value="b", session=session)

        ti_2 = dagrun.get_task_instance("task_2", session=session)

        assert set(ti_2.xcom_pull(["task_1"], session=session)) == {"a", "b"}  # Ordering not guaranteed.
        assert ti_2.xcom_pull(["task_1"], map_indexes=0, session=session) == ["a"]

        # Explicit map_indexes lists return values in the requested order.
        assert ti_2.xcom_pull(map_indexes=[0, 1], session=session) == ["a", "b"]
        assert ti_2.xcom_pull("task_1", map_indexes=[1, 0], session=session) == ["b", "a"]
        assert ti_2.xcom_pull(["task_1"], map_indexes=[0, 1], session=session) == ["a", "b"]

        # Scalar task_id + scalar map_index yields a scalar, not a list.
        assert ti_2.xcom_pull("task_1", map_indexes=1, session=session) == "b"
        assert list(ti_2.xcom_pull("task_1", session=session)) == ["a", "b"]
| |
    def test_xcom_pull_after_success(self, create_task_instance):
        """
        tests xcom set/clear relative to a task in a 'success' rerun scenario
        """
        key = "xcom_key"
        value = "xcom_value"

        ti = create_task_instance(
            dag_id="test_xcom",
            schedule="@monthly",
            task_id="test_xcom",
            pool="test_xcom",
        )

        # Push an XCom while the TI succeeds via mark_success (no execution).
        ti.run(mark_success=True)
        ti.xcom_push(key=key, value=value)
        assert ti.xcom_pull(task_ids="test_xcom", key=key) == value
        # Re-run with deps unmet: the task does not actually execute here.
        ti.run()
        # Check that we do not clear Xcom until the task is certain to execute
        assert ti.xcom_pull(task_ids="test_xcom", key=key) == value
        # Xcom shouldn't be cleared if the task doesn't execute, even if dependencies are ignored
        ti.run(ignore_all_deps=True, mark_success=True)
        assert ti.xcom_pull(task_ids="test_xcom", key=key) == value
        # Xcom IS finally cleared once task has executed
        ti.run(ignore_all_deps=True)
        assert ti.xcom_pull(task_ids="test_xcom", key=key) is None
| |
| def test_xcom_pull_after_deferral(self, create_task_instance, session): |
| """ |
| tests xcom will not clear before a task runs its next method after deferral. |
| """ |
| |
| key = "xcom_key" |
| value = "xcom_value" |
| |
| ti = create_task_instance( |
| dag_id="test_xcom", |
| schedule="@monthly", |
| task_id="test_xcom", |
| pool="test_xcom", |
| ) |
| |
| ti.run(mark_success=True) |
| ti.xcom_push(key=key, value=value) |
| |
| ti.next_method = "execute" |
| session.merge(ti) |
| session.commit() |
| |
| ti.run(ignore_all_deps=True) |
| assert ti.xcom_pull(task_ids="test_xcom", key=key) == value |
| |
    def test_xcom_pull_different_execution_date(self, create_task_instance):
        """
        tests xcom fetch behavior with different execution dates, using
        both xcom_pull with "include_prior_dates" and without
        """
        key = "xcom_key"
        value = "xcom_value"

        ti = create_task_instance(
            dag_id="test_xcom",
            schedule="@monthly",
            task_id="test_xcom",
            pool="test_xcom",
        )
        exec_date = ti.dag_run.execution_date

        # Push a value in the first run.
        ti.run(mark_success=True)
        ti.xcom_push(key=key, value=value)
        assert ti.xcom_pull(task_ids="test_xcom", key=key) == value
        ti.run()
        # Create a second dagrun one day later and run the same task there.
        exec_date += datetime.timedelta(days=1)
        dr = ti.task.dag.create_dagrun(run_id="test2", execution_date=exec_date, state=None)
        ti = TI(task=ti.task, run_id=dr.run_id)
        ti.run()
        # We have set a new execution date (and did not pass in
        # 'include_prior_dates' which means this task should now have a cleared
        # xcom value
        assert ti.xcom_pull(task_ids="test_xcom", key=key) is None
        # We *should* get a value using 'include_prior_dates'
        assert ti.xcom_pull(task_ids="test_xcom", key=key, include_prior_dates=True) == value
| |
| def test_xcom_push_flag(self, dag_maker): |
| """ |
| Tests the option for Operators to push XComs |
| """ |
| value = "hello" |
| task_id = "test_no_xcom_push" |
| |
| with dag_maker(dag_id="test_xcom"): |
| # nothing saved to XCom |
| task = PythonOperator( |
| task_id=task_id, |
| python_callable=lambda: value, |
| do_xcom_push=False, |
| ) |
| ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0] |
| ti.task = task |
| ti.run() |
| assert ti.xcom_pull(task_ids=task_id, key=XCOM_RETURN_KEY) is None |
| |
| def test_xcom_without_multiple_outputs(self, dag_maker): |
| """ |
| Tests the option for Operators to push XComs without multiple outputs |
| """ |
| value = {"key1": "value1", "key2": "value2"} |
| task_id = "test_xcom_push_without_multiple_outputs" |
| |
| with dag_maker(dag_id="test_xcom"): |
| task = PythonOperator( |
| task_id=task_id, |
| python_callable=lambda: value, |
| do_xcom_push=True, |
| ) |
| ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0] |
| ti.task = task |
| ti.run() |
| assert ti.xcom_pull(task_ids=task_id, key=XCOM_RETURN_KEY) == value |
| |
| def test_xcom_with_multiple_outputs(self, dag_maker): |
| """ |
| Tests the option for Operators to push XComs with multiple outputs |
| """ |
| value = {"key1": "value1", "key2": "value2"} |
| task_id = "test_xcom_push_with_multiple_outputs" |
| |
| with dag_maker(dag_id="test_xcom"): |
| task = PythonOperator( |
| task_id=task_id, |
| python_callable=lambda: value, |
| do_xcom_push=True, |
| multiple_outputs=True, |
| ) |
| ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0] |
| ti.task = task |
| ti.run() |
| assert ti.xcom_pull(task_ids=task_id, key=XCOM_RETURN_KEY) == value |
| assert ti.xcom_pull(task_ids=task_id, key="key1") == "value1" |
| assert ti.xcom_pull(task_ids=task_id, key="key2") == "value2" |
| |
| def test_xcom_with_multiple_outputs_and_no_mapping_result(self, dag_maker): |
| """ |
| Tests the option for Operators to push XComs with multiple outputs and no mapping result |
| """ |
| value = "value" |
| task_id = "test_xcom_push_with_multiple_outputs" |
| |
| with dag_maker(dag_id="test_xcom"): |
| task = PythonOperator( |
| task_id=task_id, |
| python_callable=lambda: value, |
| do_xcom_push=True, |
| multiple_outputs=True, |
| ) |
| ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0] |
| ti.task = task |
| with pytest.raises(AirflowException) as ctx: |
| ti.run() |
| assert f"Returned output was type {type(value)} expected dictionary for multiple_outputs" in str( |
| ctx.value |
| ) |
| |
| def test_post_execute_hook(self, dag_maker): |
| """ |
| Test that post_execute hook is called with the Operator's result. |
| The result ('error') will cause an error to be raised and trapped. |
| """ |
| |
| class TestError(Exception): |
| pass |
| |
| class TestOperator(PythonOperator): |
| def post_execute(self, context, result=None): |
| if result == "error": |
| raise TestError("expected error.") |
| |
| with dag_maker(dag_id="test_post_execute_dag"): |
| task = TestOperator( |
| task_id="test_operator", |
| python_callable=lambda: "error", |
| ) |
| ti = dag_maker.create_dagrun(execution_date=DEFAULT_DATE).task_instances[0] |
| ti.task = task |
| with pytest.raises(TestError): |
| ti.run() |
| |
| def test_check_and_change_state_before_execution(self, create_task_instance): |
| expected_external_executor_id = "banana" |
| ti = create_task_instance( |
| dag_id="test_check_and_change_state_before_execution", |
| external_executor_id=expected_external_executor_id, |
| ) |
| SerializedDagModel.write_dag(ti.task.dag) |
| |
| serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag |
| ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id) |
| |
| assert ti_from_deserialized_task._try_number == 0 |
| assert ti_from_deserialized_task.check_and_change_state_before_execution() |
| # State should be running, and try_number column should be incremented |
| assert ti_from_deserialized_task.external_executor_id == expected_external_executor_id |
| assert ti_from_deserialized_task.state == State.RUNNING |
| assert ti_from_deserialized_task._try_number == 1 |
| |
| def test_check_and_change_state_before_execution_provided_id_overrides(self, create_task_instance): |
| expected_external_executor_id = "banana" |
| ti = create_task_instance( |
| dag_id="test_check_and_change_state_before_execution", |
| external_executor_id="apple", |
| ) |
| assert ti.external_executor_id == "apple" |
| SerializedDagModel.write_dag(ti.task.dag) |
| |
| serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag |
| ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id) |
| |
| assert ti_from_deserialized_task._try_number == 0 |
| assert ti_from_deserialized_task.check_and_change_state_before_execution( |
| external_executor_id=expected_external_executor_id |
| ) |
| # State should be running, and try_number column should be incremented |
| assert ti_from_deserialized_task.external_executor_id == expected_external_executor_id |
| assert ti_from_deserialized_task.state == State.RUNNING |
| assert ti_from_deserialized_task._try_number == 1 |
| |
| def test_check_and_change_state_before_execution_with_exec_id(self, create_task_instance): |
| expected_external_executor_id = "minions" |
| ti = create_task_instance(dag_id="test_check_and_change_state_before_execution") |
| assert ti.external_executor_id is None |
| SerializedDagModel.write_dag(ti.task.dag) |
| |
| serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag |
| ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id) |
| |
| assert ti_from_deserialized_task._try_number == 0 |
| assert ti_from_deserialized_task.check_and_change_state_before_execution( |
| external_executor_id=expected_external_executor_id |
| ) |
| # State should be running, and try_number column should be incremented |
| assert ti_from_deserialized_task.external_executor_id == expected_external_executor_id |
| assert ti_from_deserialized_task.state == State.RUNNING |
| assert ti_from_deserialized_task._try_number == 1 |
| |
| def test_check_and_change_state_before_execution_dep_not_met(self, create_task_instance): |
| ti = create_task_instance(dag_id="test_check_and_change_state_before_execution") |
| task2 = EmptyOperator(task_id="task2", dag=ti.task.dag, start_date=DEFAULT_DATE) |
| ti.task >> task2 |
| SerializedDagModel.write_dag(ti.task.dag) |
| |
| serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag |
| ti2 = TI(task=serialized_dag.get_task(task2.task_id), run_id=ti.run_id) |
| assert not ti2.check_and_change_state_before_execution() |
| |
    def test_check_and_change_state_before_execution_dep_not_met_already_running(self, create_task_instance):
        """return False if the task instance state is running"""
        ti = create_task_instance(dag_id="test_check_and_change_state_before_execution")
        # NOTE(review): the session handle is unused; this presumably relies on
        # the dagrun fixture's session picking up the in-memory state change on
        # commit at context exit — confirm against the fixture implementation.
        with create_session() as _:
            ti.state = State.RUNNING

        SerializedDagModel.write_dag(ti.task.dag)

        serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag
        ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id)

        # An already-RUNNING TI must not be picked up again, and its
        # external_executor_id must remain unset.
        assert not ti_from_deserialized_task.check_and_change_state_before_execution()
        assert ti_from_deserialized_task.state == State.RUNNING
        assert ti_from_deserialized_task.external_executor_id is None
| |
| def test_check_and_change_state_before_execution_dep_not_met_not_runnable_state( |
| self, create_task_instance |
| ): |
| """return False if the task instance state is failed""" |
| ti = create_task_instance(dag_id="test_check_and_change_state_before_execution") |
| with create_session() as _: |
| ti.state = State.FAILED |
| |
| SerializedDagModel.write_dag(ti.task.dag) |
| |
| serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag |
| ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id) |
| |
| assert not ti_from_deserialized_task.check_and_change_state_before_execution() |
| assert ti_from_deserialized_task.state == State.FAILED |
| |
| def test_try_number(self, create_task_instance): |
| """ |
| Test the try_number accessor behaves in various running states |
| """ |
| ti = create_task_instance(dag_id="test_check_and_change_state_before_execution") |
| assert 1 == ti.try_number |
| ti.try_number = 2 |
| ti.state = State.RUNNING |
| assert 2 == ti.try_number |
| ti.state = State.SUCCESS |
| assert 3 == ti.try_number |
| |
    def test_get_num_running_task_instances(self, create_task_instance):
        """Only RUNNING instances of the same task are counted by get_num_running_task_instances.

        Three TIs are set up: two runs of one task in the same dag (ti1/ti2)
        and a task in a second dag (ti3).  ti2 is left QUEUED so it must not
        be counted towards ti1's total.
        """
        session = settings.Session()

        ti1 = create_task_instance(
            dag_id="test_get_num_running_task_instances", task_id="task1", session=session
        )

        # Second run of the same dag; its TI shares ti1's task.
        dr = ti1.task.dag.create_dagrun(
            execution_date=DEFAULT_DATE + datetime.timedelta(days=1),
            state=None,
            run_id="2",
            session=session,
        )
        assert ti1 in session
        ti2 = dr.task_instances[0]
        ti2.task = ti1.task

        # Unrelated dag/task — counted independently of ti1/ti2.
        ti3 = create_task_instance(
            dag_id="test_get_num_running_task_instances_dummy", task_id="task2", session=session
        )
        assert ti3 in session
        assert ti1 in session

        ti1.state = State.RUNNING
        ti2.state = State.QUEUED
        ti3.state = State.RUNNING
        assert ti3 in session
        session.commit()

        # ti2 is QUEUED, so each task sees exactly one RUNNING instance.
        assert 1 == ti1.get_num_running_task_instances(session=session)
        assert 1 == ti2.get_num_running_task_instances(session=session)
        assert 1 == ti3.get_num_running_task_instances(session=session)
| |
| def test_get_num_running_task_instances_per_dagrun(self, create_task_instance, dag_maker): |
| session = settings.Session() |
| |
| with dag_maker(dag_id="test_dag"): |
| MockOperator.partial(task_id="task_1").expand_kwargs([{"a": 1, "b": 2}, {"a": 3, "b": 4}]) |
| MockOperator.partial(task_id="task_2").expand_kwargs([{"a": 1, "b": 2}]) |
| MockOperator.partial(task_id="task_3").expand_kwargs([{"a": 1, "b": 2}]) |
| |
| dr1 = dag_maker.create_dagrun( |
| execution_date=timezone.utcnow(), state=DagRunState.RUNNING, run_id="run_id_1", session=session |
| ) |
| tis1 = {(ti.task_id, ti.map_index): ti for ti in dr1.task_instances} |
| print(f"tis1: {tis1}") |
| |
| dr2 = dag_maker.create_dagrun( |
| execution_date=timezone.utcnow(), state=DagRunState.RUNNING, run_id="run_id_2", session=session |
| ) |
| tis2 = {(ti.task_id, ti.map_index): ti for ti in dr2.task_instances} |
| |
| assert tis1[("task_1", 0)] in session |
| assert tis1[("task_1", 1)] in session |
| assert tis1[("task_2", 0)] in session |
| assert tis1[("task_3", 0)] in session |
| assert tis2[("task_1", 0)] in session |
| assert tis2[("task_1", 1)] in session |
| assert tis2[("task_2", 0)] in session |
| assert tis2[("task_3", 0)] in session |
| |
| tis1[("task_1", 0)].state = State.RUNNING |
| tis1[("task_1", 1)].state = State.QUEUED |
| tis1[("task_2", 0)].state = State.RUNNING |
| tis1[("task_3", 0)].state = State.RUNNING |
| tis2[("task_1", 0)].state = State.RUNNING |
| tis2[("task_1", 1)].state = State.QUEUED |
| tis2[("task_2", 0)].state = State.RUNNING |
| tis2[("task_3", 0)].state = State.RUNNING |
| |
| session.commit() |
| |
| assert 1 == tis1[("task_1", 0)].get_num_running_task_instances(session=session, same_dagrun=True) |
| assert 1 == tis1[("task_1", 1)].get_num_running_task_instances(session=session, same_dagrun=True) |
| assert 2 == tis1[("task_2", 0)].get_num_running_task_instances(session=session) |
| assert 1 == tis1[("task_3", 0)].get_num_running_task_instances(session=session, same_dagrun=True) |
| |
| assert 1 == tis2[("task_1", 0)].get_num_running_task_instances(session=session, same_dagrun=True) |
| assert 1 == tis2[("task_1", 1)].get_num_running_task_instances(session=session, same_dagrun=True) |
| assert 2 == tis2[("task_2", 0)].get_num_running_task_instances(session=session) |
| assert 1 == tis2[("task_3", 0)].get_num_running_task_instances(session=session, same_dagrun=True) |
| |
| def test_log_url(self, create_task_instance): |
| ti = create_task_instance(dag_id="my_dag", task_id="op", execution_date=timezone.datetime(2018, 1, 1)) |
| |
| expected_url = ( |
| "http://localhost:8080" |
| "/dags/my_dag/grid" |
| "?dag_run_id=test" |
| "&task_id=op" |
| "&map_index=-1" |
| "&tab=logs" |
| ) |
| assert ti.log_url == expected_url |
| |
| def test_mark_success_url(self, create_task_instance): |
| now = pendulum.now("Europe/Brussels") |
| ti = create_task_instance(dag_id="dag", task_id="op", execution_date=now) |
| query = urllib.parse.parse_qs( |
| urllib.parse.urlsplit(ti.mark_success_url).query, keep_blank_values=True, strict_parsing=True |
| ) |
| assert query["dag_id"][0] == "dag" |
| assert query["task_id"][0] == "op" |
| assert query["dag_run_id"][0] == "test" |
| assert ti.execution_date == now |
| |
| def test_overwrite_params_with_dag_run_conf(self, create_task_instance): |
| ti = create_task_instance() |
| dag_run = ti.dag_run |
| dag_run.conf = {"override": True} |
| ti.task.params = {"override": False} |
| |
| params = process_params(ti.task.dag, ti.task, dag_run, suppress_exception=False) |
| assert params["override"] is True |
| |
| def test_overwrite_params_with_dag_run_none(self, create_task_instance): |
| ti = create_task_instance() |
| ti.task.params = {"override": False} |
| |
| params = process_params(ti.task.dag, ti.task, None, suppress_exception=False) |
| assert params["override"] is False |
| |
| def test_overwrite_params_with_dag_run_conf_none(self, create_task_instance): |
| ti = create_task_instance() |
| dag_run = ti.dag_run |
| ti.task.params = {"override": False} |
| |
| params = process_params(ti.task.dag, ti.task, dag_run, suppress_exception=False) |
| assert params["override"] is False |
| |
    @pytest.mark.parametrize("use_native_obj", [True, False])
    @patch("airflow.models.taskinstance.send_email")
    def test_email_alert(self, mock_send_email, dag_maker, use_native_obj):
        """A failing task with ``email`` set sends a failure alert via send_email."""
        with dag_maker(dag_id="test_failure_email", render_template_as_native_obj=use_native_obj):
            task = BashOperator(task_id="test_email_alert", bash_command="exit 1", email="to")
        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
        ti.task = task

        # The bash command exits 1; swallow the failure so the email call can be inspected.
        with contextlib.suppress(AirflowException):
            ti.run()

        # send_email(to, subject, body) — unpack the positional args of the single call.
        (email, title, body), _ = mock_send_email.call_args
        assert email == "to"
        assert "test_email_alert" in title
        assert "test_email_alert" in body
        assert "Try 1" in body
| |
    @conf_vars(
        {
            ("email", "subject_template"): "/subject/path",
            ("email", "html_content_template"): "/html_content/path",
        }
    )
    @patch("airflow.models.taskinstance.send_email")
    def test_email_alert_with_config(self, mock_send_email, dag_maker):
        """Configured subject/body template files are rendered for the alert email."""
        with dag_maker(dag_id="test_failure_email"):
            task = BashOperator(
                task_id="test_email_alert_with_config",
                bash_command="exit 1",
                email="to",
            )
        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
        ti.task = task

        # Serve the same fake template content for both configured template paths.
        opener = mock_open(read_data="template: {{ti.task_id}}")
        with patch("airflow.models.taskinstance.open", opener, create=True):
            with contextlib.suppress(AirflowException):
                ti.run()

        (email, title, body), _ = mock_send_email.call_args
        assert email == "to"
        # The {{ti.task_id}} placeholder must be rendered in subject and body.
        assert "template: test_email_alert_with_config" == title
        assert "template: test_email_alert_with_config" == body
| |
    @patch("airflow.models.taskinstance.send_email")
    def test_email_alert_with_filenotfound_config(self, mock_send_email, dag_maker):
        """If the template file cannot be opened, the default subject/body are used."""
        with dag_maker(dag_id="test_failure_email"):
            task = BashOperator(
                task_id="test_email_alert_with_config",
                bash_command="exit 1",
                email="to",
            )
        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
        ti.task = task

        # Run test when the template file is not found
        opener = mock_open(read_data="template: {{ti.task_id}}")
        opener.side_effect = FileNotFoundError
        with patch("airflow.models.taskinstance.open", opener, create=True):
            with contextlib.suppress(AirflowException):
                ti.run()

        (email_error, title_error, body_error), _ = mock_send_email.call_args

        # Rerun task without any error and no template file
        with contextlib.suppress(AirflowException):
            ti.run()

        (email_default, title_default, body_default), _ = mock_send_email.call_args

        # Both runs must have produced the identical (default-template) email.
        assert email_error == email_default == "to"
        assert title_default == title_error
        assert body_default == body_error
| |
    @pytest.mark.parametrize("task_id", ["test_email_alert", "test_email_alert__1"])
    @patch("airflow.models.taskinstance.send_email")
    def test_failure_mapped_taskflow(self, mock_send_email, dag_maker, session, task_id):
        """A failing mapped taskflow instance sends an email and records a TaskFail row."""
        with dag_maker(session=session) as dag:

            @dag.task(email="to")
            def test_email_alert(x):
                raise RuntimeError("Fail please")

            test_email_alert.expand(x=["a", "b"])  # This is 'test_email_alert'.
            test_email_alert.expand(x=[1, 2, 3])  # This is 'test_email_alert__1'.

        dr: DagRun = dag_maker.create_dagrun(execution_date=timezone.utcnow())

        # Only the first mapped instance (map_index=0) of the parametrized task is run.
        ti = dr.get_task_instance(task_id, map_index=0, session=session)
        assert ti is not None

        # The task will fail and trigger email reporting.
        with pytest.raises(RuntimeError, match=r"^Fail please$"):
            ti.run(session=session)

        (email, title, body), _ = mock_send_email.call_args
        assert email == "to"
        # The subject includes the map_index of the failed instance.
        assert title == f"Airflow alert: <TaskInstance: test_dag.{task_id} test map_index=0 [failed]>"
        assert body.startswith("Try 1")
        assert "test_email_alert" in body

        # The failure must also be persisted as a TaskFail record for this TI.
        tf = (
            session.query(TaskFail)
            .filter_by(dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id, map_index=ti.map_index)
            .one_or_none()
        )
        assert tf, "TaskFail was recorded"
| |
| def test_set_duration(self): |
| task = EmptyOperator(task_id="op", email="test@test.test") |
| ti = TI(task=task) |
| ti.start_date = datetime.datetime(2018, 10, 1, 1) |
| ti.end_date = datetime.datetime(2018, 10, 1, 2) |
| ti.set_duration() |
| assert ti.duration == 3600 |
| |
| def test_set_duration_empty_dates(self): |
| task = EmptyOperator(task_id="op", email="test@test.test") |
| ti = TI(task=task) |
| ti.set_duration() |
| assert ti.duration is None |
| |
    def test_success_callback_no_race_condition(self, create_task_instance):
        """The on_success callback must observe the SUCCESS state already persisted.

        Guards against a race where the callback runs before the state is
        committed to the database.
        """
        callback_wrapper = CallbackWrapper()
        ti = create_task_instance(
            on_success_callback=callback_wrapper.success_handler,
            end_date=timezone.utcnow() + datetime.timedelta(days=10),
            execution_date=timezone.utcnow(),
            state=State.RUNNING,
        )

        session = settings.Session()
        session.merge(ti)
        session.commit()

        callback_wrapper.wrap_task_instance(ti)
        ti._run_raw_task()
        assert callback_wrapper.callback_ran
        # The state recorded *inside* the callback must already be SUCCESS.
        assert callback_wrapper.task_state_in_callback == State.SUCCESS
        ti.refresh_from_db()
        assert ti.state == State.SUCCESS
| |
    def test_outlet_datasets(self, create_task_instance):
        """
        Verify that when we have an outlet dataset on a task, and the task
        completes successfully, a DatasetDagRunQueue is logged.
        """
        from airflow.example_dags import example_datasets
        from airflow.example_dags.example_datasets import dag1

        # Load the example dataset dags into the DB so consumer dags are registered.
        session = settings.Session()
        dagbag = DagBag(dag_folder=example_datasets.__file__)
        dagbag.collect_dags(only_if_updated=False, safe_mode=False)
        dagbag.sync_to_db(session=session)
        run_id = str(uuid4())
        dr = DagRun(dag1.dag_id, run_id=run_id, run_type="anything")
        session.merge(dr)
        task = dag1.get_task("producing_task_1")
        task.bash_command = "echo 1"  # make it go faster
        ti = TaskInstance(task, run_id=run_id)
        session.merge(ti)
        session.commit()
        ti._run_raw_task()
        ti.refresh_from_db()
        assert ti.state == TaskInstanceState.SUCCESS

        # check that no other dataset events recorded
        event = (
            session.query(DatasetEvent)
            .join(DatasetEvent.dataset)
            .filter(DatasetEvent.source_task_instance == ti)
            .one()
        )
        assert event
        assert event.dataset

        # check that one queue record created for each dag that depends on dataset 1
        assert session.query(DatasetDagRunQueue.target_dag_id).filter_by(
            dataset_id=event.dataset.id
        ).order_by(DatasetDagRunQueue.target_dag_id).all() == [
            ("conditional_dataset_and_time_based_timetable",),
            ("consume_1_and_2_with_dataset_expressions",),
            ("consume_1_or_2_with_dataset_expressions",),
            ("consume_1_or_both_2_and_3_with_dataset_expressions",),
            ("dataset_consumes_1",),
            ("dataset_consumes_1_and_2",),
            ("dataset_consumes_1_never_scheduled",),
        ]

        # check that one event record created for dataset1 and this TI
        assert session.query(DatasetModel.uri).join(DatasetEvent.dataset).filter(
            DatasetEvent.source_task_instance == ti
        ).one() == ("s3://dag1/output_1.txt",)

        # check that the dataset event has an earlier timestamp than the DDRQ's
        ddrq_timestamps = (
            session.query(DatasetDagRunQueue.created_at).filter_by(dataset_id=event.dataset.id).all()
        )
        assert all([event.timestamp < ddrq_timestamp for (ddrq_timestamp,) in ddrq_timestamps])
| |
    def test_outlet_datasets_failed(self, create_task_instance):
        """
        Verify that when we have an outlet dataset on a task, and the task
        failed, a DatasetDagRunQueue is not logged, and a DatasetEvent is
        not generated
        """
        from tests.dags import test_datasets
        from tests.dags.test_datasets import dag_with_fail_task

        # Register the test dataset dags in the DB.
        session = settings.Session()
        dagbag = DagBag(dag_folder=test_datasets.__file__)
        dagbag.collect_dags(only_if_updated=False, safe_mode=False)
        dagbag.sync_to_db(session=session)
        run_id = str(uuid4())
        dr = DagRun(dag_with_fail_task.dag_id, run_id=run_id, run_type="anything")
        session.merge(dr)
        task = dag_with_fail_task.get_task("fail_task")
        ti = TaskInstance(task, run_id=run_id)
        session.merge(ti)
        session.commit()
        # The task is designed to raise AirflowFailException.
        with pytest.raises(AirflowFailException):
            ti._run_raw_task()
        ti.refresh_from_db()
        assert ti.state == TaskInstanceState.FAILED

        # check that no dagruns were queued
        assert session.query(DatasetDagRunQueue).count() == 0

        # check that no dataset events were generated
        assert session.query(DatasetEvent).count() == 0
| |
| def test_mapped_current_state(self, dag_maker): |
| with dag_maker(dag_id="test_mapped_current_state") as _: |
| from airflow.decorators import task |
| |
| @task() |
| def raise_an_exception(placeholder: int): |
| if placeholder == 0: |
| raise AirflowFailException("failing task") |
| else: |
| pass |
| |
| _ = raise_an_exception.expand(placeholder=[0, 1]) |
| |
| tis = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances |
| for task_instance in tis: |
| if task_instance.map_index == 0: |
| with pytest.raises(AirflowFailException): |
| task_instance.run() |
| assert task_instance.current_state() == TaskInstanceState.FAILED |
| else: |
| task_instance.run() |
| assert task_instance.current_state() == TaskInstanceState.SUCCESS |
| |
    def test_outlet_datasets_skipped(self):
        """
        Verify that when we have an outlet dataset on a task, and the task
        is skipped, a DatasetDagRunQueue is not logged, and a DatasetEvent is
        not generated
        """
        from tests.dags import test_datasets
        from tests.dags.test_datasets import dag_with_skip_task

        # Register the test dataset dags in the DB.
        session = settings.Session()
        dagbag = DagBag(dag_folder=test_datasets.__file__)
        dagbag.collect_dags(only_if_updated=False, safe_mode=False)
        dagbag.sync_to_db(session=session)
        run_id = str(uuid4())
        dr = DagRun(dag_with_skip_task.dag_id, run_id=run_id, run_type="anything")
        session.merge(dr)
        task = dag_with_skip_task.get_task("skip_task")
        ti = TaskInstance(task, run_id=run_id)
        session.merge(ti)
        session.commit()
        # The task raises AirflowSkipException internally, ending up SKIPPED.
        ti._run_raw_task()
        ti.refresh_from_db()
        assert ti.state == TaskInstanceState.SKIPPED

        # check that no dagruns were queued
        assert session.query(DatasetDagRunQueue).count() == 0

        # check that no dataset events were generated
        assert session.query(DatasetEvent).count() == 0
| |
    def test_outlet_dataset_extra(self, dag_maker, session):
        """Extras set via outlet_events (taskflow or post_execute) land on the DatasetEvent."""
        from airflow.datasets import Dataset

        with dag_maker(schedule=None, session=session) as dag:

            # Taskflow style: mutate the outlet_events accessor directly.
            @task(outlets=Dataset("test_outlet_dataset_extra_1"))
            def write1(*, outlet_events):
                outlet_events["test_outlet_dataset_extra_1"].extra = {"foo": "bar"}

            write1()

            # Classic-operator style: set the extra from a post_execute hook.
            def _write2_post_execute(context, _):
                context["outlet_events"]["test_outlet_dataset_extra_2"].extra = {"x": 1}

            BashOperator(
                task_id="write2",
                bash_command=":",
                outlets=Dataset("test_outlet_dataset_extra_2"),
                post_execute=_write2_post_execute,
            )

        dr: DagRun = dag_maker.create_dagrun()
        for ti in dr.get_task_instances(session=session):
            ti.refresh_from_task(dag.get_task(ti.task_id))
            ti.run(session=session)

        # One event per writing task, keyed by source task id.
        events = dict(iter(session.execute(select(DatasetEvent.source_task_id, DatasetEvent))))
        assert set(events) == {"write1", "write2"}

        assert events["write1"].source_dag_id == dr.dag_id
        assert events["write1"].source_run_id == dr.run_id
        assert events["write1"].source_task_id == "write1"
        assert events["write1"].dataset.uri == "test_outlet_dataset_extra_1"
        assert events["write1"].extra == {"foo": "bar"}

        assert events["write2"].source_dag_id == dr.dag_id
        assert events["write2"].source_run_id == dr.run_id
        assert events["write2"].source_task_id == "write2"
        assert events["write2"].dataset.uri == "test_outlet_dataset_extra_2"
        assert events["write2"].extra == {"x": 1}
| |
    def test_outlet_dataset_extra_ignore_different(self, dag_maker, session):
        """Setting extra on a URI that is not a declared outlet is silently ignored."""
        from airflow.datasets import Dataset

        with dag_maker(schedule=None, session=session):

            @task(outlets=Dataset("test_outlet_dataset_extra"))
            def write(*, outlet_events):
                outlet_events["test_outlet_dataset_extra"].extra = {"one": 1}
                outlet_events["different_uri"].extra = {"foo": "bar"}  # Will be silently dropped.

            write()

        dr: DagRun = dag_maker.create_dagrun()
        dr.get_task_instance("write").run(session=session)

        # Exactly one event — only the declared outlet produced one.
        event = session.scalars(select(DatasetEvent)).one()
        assert event.source_dag_id == dr.dag_id
        assert event.source_run_id == dr.run_id
        assert event.source_task_id == "write"
        assert event.extra == {"one": 1}
| |
    def test_outlet_dataset_extra_yield(self, dag_maker, session):
        """Yielding Metadata from a task or post_execute hook attaches extras to events.

        Also verifies that a generator-style taskflow function still pushes its
        return value to XCom.
        """
        from airflow.datasets import Dataset
        from airflow.datasets.metadata import Metadata

        with dag_maker(schedule=None, session=session) as dag:

            @task(outlets=Dataset("test_outlet_dataset_extra_1"))
            def write1():
                result = "write_1 result"
                yield Metadata("test_outlet_dataset_extra_1", {"foo": "bar"})
                return result

            write1()

            def _write2_post_execute(context, result):
                yield Metadata("test_outlet_dataset_extra_2", {"x": 1})

            BashOperator(
                task_id="write2",
                bash_command=":",
                outlets=Dataset("test_outlet_dataset_extra_2"),
                post_execute=_write2_post_execute,
            )

        dr: DagRun = dag_maker.create_dagrun()
        for ti in dr.get_task_instances(session=session):
            ti.refresh_from_task(dag.get_task(ti.task_id))
            ti.run(session=session)

        # The generator's `return` value is still the XCom return_value.
        xcom = session.scalars(
            select(XCom).filter_by(dag_id=dr.dag_id, run_id=dr.run_id, task_id="write1", key="return_value")
        ).one()
        assert xcom.value == "write_1 result"

        events = dict(iter(session.execute(select(DatasetEvent.source_task_id, DatasetEvent))))
        assert set(events) == {"write1", "write2"}

        assert events["write1"].source_dag_id == dr.dag_id
        assert events["write1"].source_run_id == dr.run_id
        assert events["write1"].source_task_id == "write1"
        assert events["write1"].dataset.uri == "test_outlet_dataset_extra_1"
        assert events["write1"].extra == {"foo": "bar"}

        assert events["write2"].source_dag_id == dr.dag_id
        assert events["write2"].source_run_id == dr.run_id
        assert events["write2"].source_task_id == "write2"
        assert events["write2"].dataset.uri == "test_outlet_dataset_extra_2"
        assert events["write2"].extra == {"x": 1}
| |
    def test_inlet_dataset_extra(self, dag_maker, session):
        """inlet_events exposes prior dataset events (indexable, with extras) to consumers."""
        from airflow.datasets import Dataset

        # Flag flipped inside the task body to prove its assertions actually ran.
        read_task_evaluated = False

        with dag_maker(schedule=None, session=session):

            @task(outlets=Dataset("test_inlet_dataset_extra"))
            def write(*, ti, outlet_events):
                outlet_events["test_inlet_dataset_extra"].extra = {"from": ti.task_id}

            @task(inlets=Dataset("test_inlet_dataset_extra"))
            def read(*, inlet_events):
                second_event = inlet_events["test_inlet_dataset_extra"][1]
                assert second_event.uri == "test_inlet_dataset_extra"
                assert second_event.extra == {"from": "write2"}

                last_event = inlet_events["test_inlet_dataset_extra"][-1]
                assert last_event.uri == "test_inlet_dataset_extra"
                assert last_event.extra == {"from": "write3"}

                # Undeclared URIs and out-of-range indexes raise the usual container errors.
                with pytest.raises(KeyError):
                    inlet_events["does_not_exist"]
                with pytest.raises(IndexError):
                    inlet_events["test_inlet_dataset_extra"][5]

                # TODO: Support slices.

                nonlocal read_task_evaluated
                read_task_evaluated = True

            [
                write.override(task_id="write1")(),
                write.override(task_id="write2")(),
                write.override(task_id="write3")(),
            ] >> read()

        dr: DagRun = dag_maker.create_dagrun()

        # Run "write1", "write2", and "write3" (in this order).
        decision = dr.task_instance_scheduling_decisions(session=session)
        for ti in sorted(decision.schedulable_tis, key=operator.attrgetter("task_id")):
            ti.run(session=session)

        # Run "read".
        decision = dr.task_instance_scheduling_decisions(session=session)
        for ti in decision.schedulable_tis:
            ti.run(session=session)

        # Should be done.
        assert not dr.task_instance_scheduling_decisions(session=session).schedulable_tis
        assert read_task_evaluated
| |
    def test_changing_of_dataset_when_ddrq_is_already_populated(self, dag_maker):
        """
        Test that when a task that produces dataset has ran, that changing the consumer
        dag dataset will not cause primary key blank-out
        """
        from airflow.datasets import Dataset

        with dag_maker(schedule=None, serialized=True) as dag1:

            @task(outlets=Dataset("test/1"))
            def test_task1():
                print(1)

            test_task1()

        dr1 = dag_maker.create_dagrun()
        test_task1 = dag1.get_task("test_task1")

        # Consumer dag scheduled on the same dataset, so running the producer
        # populates a DatasetDagRunQueue entry for it.
        with dag_maker(dag_id="testdag", schedule=[Dataset("test/1")], serialized=True):

            @task
            def test_task2():
                print(1)

            test_task2()

        ti = dr1.get_task_instance(task_id="test_task1")
        ti.run()
        # Change the dataset.
        # Re-serializing the consumer with a different dataset must succeed
        # even though a DDRQ row exists for the old one.
        with dag_maker(dag_id="testdag", schedule=[Dataset("test2/1")], serialized=True):

            @task
            def test_task2():
                print(1)

            test_task2()
| |
    @staticmethod
    def _test_previous_dates_setup(
        schedule_interval: str | datetime.timedelta | None,
        catchup: bool,
        scenario: list[TaskInstanceState],
        dag_maker,
    ) -> list:
        """Create one scheduled dag run per entry in ``scenario`` and return their TIs.

        Each run is a day apart starting 2019-01-01; the run's *state* comes
        from ``scenario`` while every TI itself is marked SUCCESS.  Shared
        setup for the ``test_previous_*`` tests below.
        """
        dag_id = "test_previous_dates"
        with dag_maker(dag_id=dag_id, schedule=schedule_interval, catchup=catchup):
            task = EmptyOperator(task_id="task")

        def get_test_ti(execution_date: pendulum.DateTime, state: str) -> TI:
            # One scheduled run per date; its single TI is forced to SUCCESS.
            dr = dag_maker.create_dagrun(
                run_id=f"test__{execution_date.isoformat()}",
                run_type=DagRunType.SCHEDULED,
                state=state,
                execution_date=execution_date,
                start_date=pendulum.now("UTC"),
            )
            ti = dr.task_instances[0]
            ti.task = task
            ti.set_state(state=State.SUCCESS, session=dag_maker.session)
            return ti

        date = cast(pendulum.DateTime, pendulum.parse("2019-01-01T00:00:00+00:00"))

        ret = []

        for idx, state in enumerate(scenario):
            new_date = date.add(days=idx)
            ti = get_test_ti(new_date, state)
            ret.append(ti)

        return ret
| |
    # (schedule, catchup) combinations shared by the test_previous_* tests below.
    # NOTE(review): the first cron string carries a trailing space ("0 0 * * * ");
    # presumably tolerated by the cron parser, but it looks accidental — confirm.
    _prev_dates_param_list = [
        pytest.param("0 0 * * * ", True, id="cron/catchup"),
        pytest.param("0 0 * * *", False, id="cron/no-catchup"),
        pytest.param(None, True, id="no-sched/catchup"),
        pytest.param(None, False, id="no-sched/no-catchup"),
        pytest.param(datetime.timedelta(days=1), True, id="timedelta/catchup"),
        pytest.param(datetime.timedelta(days=1), False, id="timedelta/no-catchup"),
    ]
| |
| @pytest.mark.parametrize("schedule_interval, catchup", _prev_dates_param_list) |
| def test_previous_ti(self, schedule_interval, catchup, dag_maker) -> None: |
| scenario = [State.SUCCESS, State.FAILED, State.SUCCESS] |
| |
| ti_list = self._test_previous_dates_setup(schedule_interval, catchup, scenario, dag_maker) |
| |
| assert ti_list[0].get_previous_ti() is None |
| |
| assert ti_list[2].get_previous_ti().run_id == ti_list[1].run_id |
| |
| assert ti_list[2].get_previous_ti().run_id != ti_list[0].run_id |
| |
| @pytest.mark.parametrize("schedule_interval, catchup", _prev_dates_param_list) |
| def test_previous_ti_success(self, schedule_interval, catchup, dag_maker) -> None: |
| scenario = [State.FAILED, State.SUCCESS, State.FAILED, State.SUCCESS] |
| |
| ti_list = self._test_previous_dates_setup(schedule_interval, catchup, scenario, dag_maker) |
| |
| assert ti_list[0].get_previous_ti(state=State.SUCCESS) is None |
| assert ti_list[1].get_previous_ti(state=State.SUCCESS) is None |
| |
| assert ti_list[3].get_previous_ti(state=State.SUCCESS).run_id == ti_list[1].run_id |
| |
| assert ti_list[3].get_previous_ti(state=State.SUCCESS).run_id != ti_list[2].run_id |
| |
| @pytest.mark.parametrize("schedule_interval, catchup", _prev_dates_param_list) |
| def test_previous_execution_date_success(self, schedule_interval, catchup, dag_maker) -> None: |
| scenario = [State.FAILED, State.SUCCESS, State.FAILED, State.SUCCESS] |
| |
| ti_list = self._test_previous_dates_setup(schedule_interval, catchup, scenario, dag_maker) |
| # vivify |
| for ti in ti_list: |
| ti.execution_date |
| |
| assert ti_list[0].get_previous_execution_date(state=State.SUCCESS) is None |
| assert ti_list[1].get_previous_execution_date(state=State.SUCCESS) is None |
| assert ti_list[3].get_previous_execution_date(state=State.SUCCESS) == ti_list[1].execution_date |
| assert ti_list[3].get_previous_execution_date(state=State.SUCCESS) != ti_list[2].execution_date |
| |
| @pytest.mark.parametrize("schedule_interval, catchup", _prev_dates_param_list) |
| def test_previous_start_date_success(self, schedule_interval, catchup, dag_maker) -> None: |
| scenario = [State.FAILED, State.SUCCESS, State.FAILED, State.SUCCESS] |
| |
| ti_list = self._test_previous_dates_setup(schedule_interval, catchup, scenario, dag_maker) |
| |
| assert ti_list[0].get_previous_start_date(state=State.SUCCESS) is None |
| assert ti_list[1].get_previous_start_date(state=State.SUCCESS) is None |
| assert ti_list[3].get_previous_start_date(state=State.SUCCESS) == ti_list[1].start_date |
| assert ti_list[3].get_previous_start_date(state=State.SUCCESS) != ti_list[2].start_date |
| |
    def test_get_previous_start_date_none(self, dag_maker):
        """
        Test that get_previous_start_date() can handle TaskInstance with no start_date.
        """
        with dag_maker("test_get_previous_start_date_none", schedule=None) as dag:
            task = EmptyOperator(task_id="op")

        day_1 = DEFAULT_DATE
        day_2 = DEFAULT_DATE + datetime.timedelta(days=1)

        # Create a DagRun for day_1 and day_2. Calling ti_2.get_previous_start_date()
        # should return the start_date of ti_1 (which is None because ti_1 was not run).
        # It should not raise an error.
        dagrun_1 = dag_maker.create_dagrun(
            execution_date=day_1,
            state=State.RUNNING,
            run_type=DagRunType.MANUAL,
        )

        # NOTE(review): this second run is created via dag.create_dagrun rather
        # than dag_maker.create_dagrun — presumably intentional to bypass the
        # helper; confirm before unifying.
        dagrun_2 = dag.create_dagrun(
            execution_date=day_2,
            state=State.RUNNING,
            run_type=DagRunType.MANUAL,
        )

        ti_1 = dagrun_1.get_task_instance(task.task_id)
        ti_2 = dagrun_2.get_task_instance(task.task_id)
        ti_1.task = task
        ti_2.task = task

        assert ti_2.get_previous_start_date() == ti_1.start_date
        assert ti_1.start_date is None
| |
| def test_context_triggering_dataset_events_none(self, session, create_task_instance): |
| ti = create_task_instance() |
| template_context = ti.get_template_context() |
| |
| assert ti in session |
| session.expunge_all() |
| |
| assert template_context["triggering_dataset_events"] == {} |
| |
    def test_context_triggering_dataset_events(self, create_dummy_dag, session):
        """triggering_dataset_events groups consumed events by dataset URI."""
        ds1 = DatasetModel(id=1, uri="one")
        ds2 = DatasetModel(id=2, uri="two")
        session.add_all([ds1, ds2])
        session.commit()

        # it's easier to fake a manual run here
        dag, task1 = create_dummy_dag(
            dag_id="test_triggering_dataset_events",
            schedule=None,
            start_date=DEFAULT_DATE,
            task_id="test_context",
            with_dagrun_type=DagRunType.MANUAL,
            session=session,
        )
        dr = dag.create_dagrun(
            run_id="test2",
            run_type=DagRunType.DATASET_TRIGGERED,
            execution_date=timezone.utcnow(),
            state=None,
            session=session,
        )
        # One event for dataset "one", two for dataset "two".
        ds1_event = DatasetEvent(dataset_id=1)
        ds2_event_1 = DatasetEvent(dataset_id=2)
        ds2_event_2 = DatasetEvent(dataset_id=2)
        dr.consumed_dataset_events.append(ds1_event)
        dr.consumed_dataset_events.append(ds2_event_1)
        dr.consumed_dataset_events.append(ds2_event_2)
        session.commit()

        ti = dr.get_task_instance(task1.task_id, session=session)
        ti.refresh_from_task(task1)

        # Check we run this in the same context as the actual task at runtime!
        assert ti in session
        session.expunge(ti)
        session.expunge(dr)

        template_context = ti.get_template_context()

        assert template_context["triggering_dataset_events"] == {
            "one": [ds1_event],
            "two": [ds2_event_1, ds2_event_2],
        }
| |
| def test_pendulum_template_dates(self, create_task_instance): |
| ti = create_task_instance( |
| dag_id="test_pendulum_template_dates", |
| task_id="test_pendulum_template_dates_task", |
| schedule="0 12 * * *", |
| ) |
| |
| template_context = ti.get_template_context() |
| |
| assert isinstance(template_context["data_interval_start"], pendulum.DateTime) |
| assert isinstance(template_context["data_interval_end"], pendulum.DateTime) |
| |
| def test_template_render(self, create_task_instance): |
| ti = create_task_instance( |
| dag_id="test_template_render", |
| task_id="test_template_render_task", |
| schedule="0 12 * * *", |
| ) |
| template_context = ti.get_template_context() |
| result = ti.task.render_template("Task: {{ dag.dag_id }} -> {{ task.task_id }}", template_context) |
| assert result == "Task: test_template_render -> test_template_render_task" |
| |
| def test_template_render_deprecated(self, create_task_instance): |
| ti = create_task_instance( |
| dag_id="test_template_render", |
| task_id="test_template_render_task", |
| schedule="0 12 * * *", |
| ) |
| template_context = ti.get_template_context() |
| with pytest.deprecated_call(): |
| result = ti.task.render_template("Execution date: {{ execution_date }}", template_context) |
| assert result.startswith("Execution date: ") |
| |
    @pytest.mark.parametrize(
        "content, expected_output",
        [
            ('{{ conn.get("a_connection").host }}', "hostvalue"),
            ('{{ conn.get("a_connection", "unused_fallback").host }}', "hostvalue"),
            ('{{ conn.get("missing_connection", {"host": "fallback_host"}).host }}', "fallback_host"),
            ("{{ conn.a_connection.host }}", "hostvalue"),
            ("{{ conn.a_connection.login }}", "loginvalue"),
            ("{{ conn.a_connection.password }}", "passwordvalue"),
            ('{{ conn.a_connection.extra_dejson["extra__asana__workspace"] }}', "extra1"),
            ("{{ conn.a_connection.extra_dejson.extra__asana__workspace }}", "extra1"),
        ],
    )
    def test_template_with_connection(self, content, expected_output, create_task_instance):
        """
        Test the availability of connections in templates via the ``conn`` accessor.

        Covers ``conn.get`` (with and without a fallback) as well as attribute-style
        access, including ``extra_dejson`` lookups.
        """
        with create_session() as session:
            # Drop the default connections so only the one created below can match.
            clear_db_connections(add_default_connections_back=False)
            merge_conn(
                Connection(
                    conn_id="a_connection",
                    conn_type="a_type",
                    description="a_conn_description",
                    host="hostvalue",
                    login="loginvalue",
                    password="passwordvalue",
                    schema="schemavalues",
                    extra={
                        "extra__asana__workspace": "extra1",
                    },
                ),
                session,
            )

        ti = create_task_instance()

        context = ti.get_template_context()
        result = ti.task.render_template(content, context)
        assert result == expected_output
| |
| @pytest.mark.parametrize( |
| "content, expected_output", |
| [ |
| ("{{ var.value.a_variable }}", "a test value"), |
| ('{{ var.value.get("a_variable") }}', "a test value"), |
| ('{{ var.value.get("a_variable", "unused_fallback") }}', "a test value"), |
| ('{{ var.value.get("missing_variable", "fallback") }}', "fallback"), |
| ], |
| ) |
| def test_template_with_variable(self, content, expected_output, create_task_instance): |
| """ |
| Test the availability of variables in templates |
| """ |
| Variable.set("a_variable", "a test value") |
| |
| ti = create_task_instance() |
| context = ti.get_template_context() |
| result = ti.task.render_template(content, context) |
| assert result == expected_output |
| |
| def test_template_with_variable_missing(self, create_task_instance): |
| """ |
| Test the availability of variables in templates |
| """ |
| ti = create_task_instance() |
| context = ti.get_template_context() |
| with pytest.raises(KeyError): |
| ti.task.render_template('{{ var.value.get("missing_variable") }}', context) |
| |
    @pytest.mark.parametrize(
        "content, expected_output",
        [
            ("{{ var.value.a_variable }}", '{\n "a": {\n "test": "value"\n }\n}'),
            ('{{ var.json.a_variable["a"]["test"] }}', "value"),
            ('{{ var.json.get("a_variable")["a"]["test"] }}', "value"),
            ('{{ var.json.get("a_variable", {"a": {"test": "unused_fallback"}})["a"]["test"] }}', "value"),
            ('{{ var.json.get("missing_variable", {"a": {"test": "fallback"}})["a"]["test"] }}', "fallback"),
        ],
    )
    def test_template_with_json_variable(self, content, expected_output, create_task_instance):
        """
        Test the availability of JSON-serialized Variables in templates.

        ``var.value`` yields the raw serialized string while ``var.json`` deserializes
        it, supporting subscripting and ``get`` with an optional fallback.
        """
        Variable.set("a_variable", {"a": {"test": "value"}}, serialize_json=True)

        ti = create_task_instance()
        context = ti.get_template_context()
        result = ti.task.render_template(content, context)
        assert result == expected_output
| |
| def test_template_with_json_variable_missing(self, create_task_instance): |
| ti = create_task_instance() |
| context = ti.get_template_context() |
| with pytest.raises(KeyError): |
| ti.task.render_template('{{ var.json.get("missing_variable") }}', context) |
| |
| @pytest.mark.parametrize( |
| ("field", "expected"), |
| [ |
| ("next_ds", "2016-01-01"), |
| ("next_ds_nodash", "20160101"), |
| ("prev_ds", "2015-12-31"), |
| ("prev_ds_nodash", "20151231"), |
| ("yesterday_ds", "2015-12-31"), |
| ("yesterday_ds_nodash", "20151231"), |
| ("tomorrow_ds", "2016-01-02"), |
| ("tomorrow_ds_nodash", "20160102"), |
| ], |
| ) |
| def test_deprecated_context(self, field, expected, create_task_instance): |
| ti = create_task_instance(execution_date=DEFAULT_DATE) |
| context = ti.get_template_context() |
| with pytest.deprecated_call() as recorder: |
| assert context[field] == expected |
| message_beginning = ( |
| f"Accessing {field!r} from the template is deprecated and " |
| f"will be removed in a future version." |
| ) |
| |
| recorded_message = [str(m.message) for m in recorder] |
| assert len(recorded_message) == 1 |
| assert recorded_message[0].startswith(message_beginning) |
| |
| def test_template_with_custom_timetable_deprecated_context(self, create_task_instance): |
| ti = create_task_instance( |
| start_date=DEFAULT_DATE, |
| timetable=AfterWorkdayTimetable(), |
| run_type=DagRunType.SCHEDULED, |
| execution_date=timezone.datetime(2021, 9, 6), |
| data_interval=(timezone.datetime(2021, 9, 6), timezone.datetime(2021, 9, 7)), |
| ) |
| context = ti.get_template_context() |
| with pytest.deprecated_call(): |
| assert context["execution_date"] == pendulum.DateTime(2021, 9, 6, tzinfo=TIMEZONE) |
| with pytest.deprecated_call(): |
| assert context["next_ds"] == "2021-09-07" |
| with pytest.deprecated_call(): |
| assert context["next_ds_nodash"] == "20210907" |
| with pytest.deprecated_call(): |
| assert context["next_execution_date"] == pendulum.DateTime(2021, 9, 7, tzinfo=TIMEZONE) |
| with pytest.deprecated_call(): |
| assert context["prev_ds"] is None, "Does not make sense for custom timetable" |
| with pytest.deprecated_call(): |
| assert context["prev_ds_nodash"] is None, "Does not make sense for custom timetable" |
| with pytest.deprecated_call(): |
| assert context["prev_execution_date"] is None, "Does not make sense for custom timetable" |
| |
| def test_execute_callback(self, create_task_instance): |
| called = False |
| |
| def on_execute_callable(context): |
| nonlocal called |
| called = True |
| assert context["dag_run"].dag_id == "test_dagrun_execute_callback" |
| |
| for i, callback_input in enumerate([[on_execute_callable], on_execute_callable]): |
| ti = create_task_instance( |
| dag_id=f"test_execute_callback_{i}", |
| on_execute_callback=callback_input, |
| state=State.RUNNING, |
| ) |
| |
| session = settings.Session() |
| |
| session.merge(ti) |
| session.commit() |
| |
| ti._run_raw_task() |
| assert called |
| ti.refresh_from_db() |
| assert ti.state == State.SUCCESS |
| |
| def test_finished_callbacks_handle_and_log_exception(self, caplog): |
| def on_finish_callable(context): |
| nonlocal called, completed |
| called = True |
| raise KeyError |
| completed = True |
| |
| for callback_input in [[on_finish_callable], on_finish_callable]: |
| called = completed = False |
| caplog.clear() |
| _run_finished_callback(callbacks=callback_input, context={}) |
| |
| assert called |
| assert not completed |
| callback_name = callback_input[0] if isinstance(callback_input, list) else callback_input |
| callback_name = qualname(callback_name).split(".")[-1] |
| assert "Executing on_finish_callable callback" in caplog.text |
| assert "Error when executing on_finish_callable callback" in caplog.text |
| |
    @provide_session
    def test_handle_failure(self, create_dummy_dag, session=None):
        """handle_failure fires failure/retry callbacks and notifies registered listeners.

        Three scenarios: no retries left (failure callback), retries remaining
        (retry callback), and force_fail overriding remaining retries.
        """
        start_date = timezone.datetime(2016, 6, 1)
        clear_db_runs()

        from airflow.listeners.listener import get_listener_manager

        # Hook a mock listener so we can verify on_task_instance_failed is notified.
        listener_callback_on_error = mock.MagicMock()
        get_listener_manager().pm.hook.on_task_instance_failed = listener_callback_on_error

        mock_on_failure_1 = mock.MagicMock()
        mock_on_failure_1.__name__ = "mock_on_failure_1"
        mock_on_retry_1 = mock.MagicMock()
        mock_on_retry_1.__name__ = "mock_on_retry_1"
        dag, task1 = create_dummy_dag(
            dag_id="test_handle_failure",
            schedule=None,
            start_date=start_date,
            task_id="test_handle_failure_on_failure",
            with_dagrun_type=DagRunType.MANUAL,
            on_failure_callback=mock_on_failure_1,
            on_retry_callback=mock_on_retry_1,
            session=session,
        )
        dr = dag.create_dagrun(
            run_id="test2",
            run_type=DagRunType.MANUAL,
            execution_date=timezone.utcnow(),
            state=None,
            session=session,
        )
        ti1 = dr.get_task_instance(task1.task_id, session=session)
        ti1.task = task1

        # Scenario 1: no retries configured -> on_failure fires, on_retry does not.
        ti1.state = State.FAILED
        error_message = "test failure handling"
        ti1.handle_failure(error_message)

        # check that the listener callback was called, and that it can access the error
        listener_callback_on_error.assert_called_once()
        callback_args = listener_callback_on_error.call_args.kwargs
        assert "error" in callback_args
        assert callback_args["error"] == error_message

        context_arg_1 = mock_on_failure_1.call_args.args[0]
        assert context_arg_1
        assert "task_instance" in context_arg_1
        mock_on_retry_1.assert_not_called()

        # Scenario 2: retries remain -> on_retry fires instead of on_failure.
        mock_on_failure_2 = mock.MagicMock()
        mock_on_failure_2.__name__ = "mock_on_failure_2"
        mock_on_retry_2 = mock.MagicMock()
        mock_on_retry_2.__name__ = "mock_on_retry_2"
        task2 = EmptyOperator(
            task_id="test_handle_failure_on_retry",
            on_failure_callback=mock_on_failure_2,
            on_retry_callback=mock_on_retry_2,
            retries=1,
            dag=dag,
        )
        ti2 = TI(task=task2, run_id=dr.run_id)
        ti2.state = State.FAILED
        session.add(ti2)
        session.flush()
        ti2.handle_failure("test retry handling")

        mock_on_failure_2.assert_not_called()

        context_arg_2 = mock_on_retry_2.call_args.args[0]
        assert context_arg_2
        assert "task_instance" in context_arg_2

        # test the scenario where normally we would retry but have been asked to fail
        mock_on_failure_3 = mock.MagicMock()
        mock_on_failure_3.__name__ = "mock_on_failure_3"
        mock_on_retry_3 = mock.MagicMock()
        mock_on_retry_3.__name__ = "mock_on_retry_3"
        task3 = EmptyOperator(
            task_id="test_handle_failure_on_force_fail",
            on_failure_callback=mock_on_failure_3,
            on_retry_callback=mock_on_retry_3,
            retries=1,
            dag=dag,
        )
        ti3 = TI(task=task3, run_id=dr.run_id)
        session.add(ti3)
        session.flush()
        ti3.state = State.FAILED
        ti3.handle_failure("test force_fail handling", force_fail=True)

        context_arg_3 = mock_on_failure_3.call_args.args[0]
        assert context_arg_3
        assert "task_instance" in context_arg_3
        mock_on_retry_3.assert_not_called()
| |
    def test_handle_failure_updates_queued_task_try_number(self, dag_maker):
        """Failing a QUEUED ti (one that never started) still bumps its try number."""
        session = settings.Session()
        with dag_maker():
            task = EmptyOperator(task_id="mytask", retries=1)
        dr = dag_maker.create_dagrun()
        ti = TI(task=task, run_id=dr.run_id)
        ti.state = State.QUEUED
        session.merge(ti)
        session.flush()
        assert ti.state == State.QUEUED
        assert ti.try_number == 1
        ti.handle_failure("test queued ti", test_mode=True)
        assert ti.state == State.UP_FOR_RETRY
        # Assert that 'ti._try_number' is bumped from 0 to 1. This is the last/current try
        assert ti._try_number == 1
        # Check 'ti.try_number' is bumped to 2. This is try_number for next run
        assert ti.try_number == 2
| |
    @patch.object(Stats, "incr")
    def test_handle_failure_no_task(self, Stats_incr, dag_maker):
        """
        When a zombie is detected for a DAG with a parse error, we need to be able to run handle_failure
        _without_ ti.task being set
        """
        session = settings.Session()
        with dag_maker():
            task = EmptyOperator(task_id="mytask", retries=1)
        dr = dag_maker.create_dagrun()
        ti = TI(task=task, run_id=dr.run_id)
        ti = session.merge(ti)
        # Simulate the zombie case: the TI row exists but no task object is attached.
        ti.task = None
        ti.state = State.QUEUED
        session.flush()
        expected_stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}

        assert ti.task is None, "Check critical pre-condition"

        assert ti.state == State.QUEUED
        assert ti.try_number == 1

        ti.handle_failure("test queued ti", test_mode=False)
        assert ti.state == State.UP_FOR_RETRY
        # Assert that 'ti._try_number' is bumped from 0 to 1. This is the last/current try
        assert ti._try_number == 1
        # Check 'ti.try_number' is bumped to 2. This is try_number for next run
        assert ti.try_number == 2

        # Failure metrics must still be emitted even though ti.task is missing.
        Stats_incr.assert_any_call("ti_failures", tags=expected_stats_tags)
        Stats_incr.assert_any_call("operator_failures_EmptyOperator", tags=expected_stats_tags)
        Stats_incr.assert_any_call(
            "operator_failures", tags={**expected_stats_tags, "operator": "EmptyOperator"}
        )
| |
| def test_handle_failure_task_undefined(self, create_task_instance): |
| """ |
| When the loaded taskinstance does not use refresh_from_task, the task may be undefined. |
| For example: |
| the DAG file has been deleted before executing _execute_task_callbacks |
| """ |
| ti = create_task_instance() |
| del ti.task |
| ti.handle_failure("test ti.task undefined") |
| |
| @provide_session |
| def test_handle_failure_fail_stop(self, create_dummy_dag, session=None): |
| start_date = timezone.datetime(2016, 6, 1) |
| clear_db_runs() |
| |
| dag, task1 = create_dummy_dag( |
| dag_id="test_handle_failure_fail_stop", |
| schedule=None, |
| start_date=start_date, |
| task_id="task1", |
| trigger_rule="all_success", |
| with_dagrun_type=DagRunType.MANUAL, |
| session=session, |
| fail_stop=True, |
| ) |
| dr = dag.create_dagrun( |
| run_id="test_ff", |
| run_type=DagRunType.MANUAL, |
| execution_date=timezone.utcnow(), |
| state=None, |
| session=session, |
| ) |
| |
| ti1 = dr.get_task_instance(task1.task_id, session=session) |
| ti1.task = task1 |
| ti1.state = State.SUCCESS |
| |
| states = [State.RUNNING, State.FAILED, State.QUEUED, State.SCHEDULED, State.DEFERRED] |
| tasks = [] |
| for i, state in enumerate(states): |
| op = EmptyOperator( |
| task_id=f"reg_Task{i}", |
| dag=dag, |
| ) |
| ti = TI(task=op, run_id=dr.run_id) |
| ti.state = state |
| session.add(ti) |
| tasks.append(ti) |
| |
| fail_task = EmptyOperator( |
| task_id="fail_Task", |
| dag=dag, |
| ) |
| ti_ff = TI(task=fail_task, run_id=dr.run_id) |
| ti_ff.state = State.FAILED |
| session.add(ti_ff) |
| session.flush() |
| ti_ff.handle_failure("test retry handling") |
| |
| assert ti1.state == State.SUCCESS |
| assert ti_ff.state == State.FAILED |
| exp_states = [State.FAILED, State.FAILED, State.SKIPPED, State.SKIPPED, State.SKIPPED] |
| for i in range(len(states)): |
| assert tasks[i].state == exp_states[i] |
| |
| def test_does_not_retry_on_airflow_fail_exception(self, dag_maker): |
| def fail(): |
| raise AirflowFailException("hopeless") |
| |
| with dag_maker(dag_id="test_does_not_retry_on_airflow_fail_exception"): |
| task = PythonOperator( |
| task_id="test_raise_airflow_fail_exception", |
| python_callable=fail, |
| retries=1, |
| ) |
| ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0] |
| ti.task = task |
| with contextlib.suppress(AirflowException): |
| ti.run() |
| assert State.FAILED == ti.state |
| |
| def test_retries_on_other_exceptions(self, dag_maker): |
| def fail(): |
| raise AirflowException("maybe this will pass?") |
| |
| with dag_maker(dag_id="test_retries_on_other_exceptions"): |
| task = PythonOperator( |
| task_id="test_raise_other_exception", |
| python_callable=fail, |
| retries=1, |
| ) |
| ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0] |
| ti.task = task |
| with contextlib.suppress(AirflowException): |
| ti.run() |
| assert State.UP_FOR_RETRY == ti.state |
| |
| @patch.object(TaskInstance, "logger") |
| def test_stacktrace_on_failure_starts_with_task_execute_method(self, mock_get_log, dag_maker): |
| def fail(): |
| raise AirflowException("maybe this will pass?") |
| |
| with dag_maker(dag_id="test_retries_on_other_exceptions"): |
| task = PythonOperator( |
| task_id="test_raise_other_exception", |
| python_callable=fail, |
| retries=1, |
| ) |
| ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0] |
| ti.task = task |
| mock_log = mock.Mock() |
| mock_get_log.return_value = mock_log |
| with pytest.raises(AirflowException): |
| ti.run() |
| mock_log.error.assert_called_once() |
| assert mock_log.error.call_args.args == ("Task failed with exception",) |
| exc_info = mock_log.error.call_args.kwargs["exc_info"] |
| filename = exc_info[2].tb_frame.f_code.co_filename |
| formatted_exc = format_exception(*exc_info) |
| assert sys.modules[TaskInstance.__module__].__file__ == filename, "".join(formatted_exc) |
| |
| def _env_var_check_callback(self): |
| assert "test_echo_env_variables" == os.environ["AIRFLOW_CTX_DAG_ID"] |
| assert "hive_in_python_op" == os.environ["AIRFLOW_CTX_TASK_ID"] |
| assert DEFAULT_DATE.isoformat() == os.environ["AIRFLOW_CTX_EXECUTION_DATE"] |
| assert DagRun.generate_run_id(DagRunType.MANUAL, DEFAULT_DATE) == os.environ["AIRFLOW_CTX_DAG_RUN_ID"] |
| |
| def test_echo_env_variables(self, dag_maker): |
| with dag_maker( |
| "test_echo_env_variables", |
| start_date=DEFAULT_DATE, |
| end_date=DEFAULT_DATE + datetime.timedelta(days=10), |
| ): |
| op = PythonOperator(task_id="hive_in_python_op", python_callable=self._env_var_check_callback) |
| dr = dag_maker.create_dagrun( |
| run_type=DagRunType.MANUAL, |
| external_trigger=False, |
| ) |
| ti = TI(task=op, run_id=dr.run_id) |
| ti.state = State.RUNNING |
| session = settings.Session() |
| session.merge(ti) |
| session.commit() |
| ti._run_raw_task() |
| ti.refresh_from_db() |
| assert ti.state == State.SUCCESS |
| |
| @pytest.mark.parametrize( |
| "code, expected_state", |
| [ |
| pytest.param(1, State.FAILED, id="code-positive-number"), |
| pytest.param(-1, State.FAILED, id="code-negative-number"), |
| pytest.param("error", State.FAILED, id="code-text"), |
| pytest.param(0, State.SUCCESS, id="code-zero"), |
| pytest.param(None, State.SUCCESS, id="code-none"), |
| ], |
| ) |
| def test_handle_system_exit_failed(self, dag_maker, code, expected_state): |
| with dag_maker(): |
| |
| def f(*args, **kwargs): |
| exit(code) |
| |
| task = PythonOperator(task_id="mytask", python_callable=f) |
| |
| dr = dag_maker.create_dagrun() |
| ti = TI(task=task, run_id=dr.run_id) |
| ti.state = State.RUNNING |
| session = settings.Session() |
| session.merge(ti) |
| session.commit() |
| |
| if expected_state == State.SUCCESS: |
| ctx = contextlib.nullcontext() |
| else: |
| ctx = pytest.raises(AirflowException, match=rf"Task failed due to SystemExit\({code}\)") |
| |
| with ctx: |
| ti._run_raw_task() |
| ti.refresh_from_db() |
| assert ti.state == expected_state |
| |
| def test_get_current_context_works_in_template(self, dag_maker): |
| def user_defined_macro(): |
| from airflow.operators.python import get_current_context |
| |
| get_current_context() |
| |
| with dag_maker( |
| "test_context_inside_template", |
| start_date=DEFAULT_DATE, |
| end_date=DEFAULT_DATE + datetime.timedelta(days=10), |
| user_defined_macros={"user_defined_macro": user_defined_macro}, |
| ): |
| |
| def foo(arg): |
| print(arg) |
| |
| PythonOperator( |
| task_id="context_inside_template", |
| python_callable=foo, |
| op_kwargs={"arg": "{{ user_defined_macro() }}"}, |
| ) |
| dagrun = dag_maker.create_dagrun() |
| tis = dagrun.get_task_instances() |
| ti: TaskInstance = next(x for x in tis if x.task_id == "context_inside_template") |
| ti._run_raw_task() |
| assert ti.state == State.SUCCESS |
| |
    @patch.object(Stats, "incr")
    def test_task_stats(self, stats_mock, create_task_instance):
        """Running a ti emits ti.start and ti.finish metrics, including zero samples."""
        ti = create_task_instance(
            dag_id="test_task_start_end_stats",
            end_date=timezone.utcnow() + datetime.timedelta(days=10),
            state=State.RUNNING,
        )
        # Discard metrics emitted during fixture setup.
        stats_mock.reset_mock()

        session = settings.Session()
        session.merge(ti)
        session.commit()
        ti._run_raw_task()
        ti.refresh_from_db()
        # The state actually reached is reported (dotted and tagged forms)...
        stats_mock.assert_any_call(
            f"ti.finish.{ti.dag_id}.{ti.task_id}.{ti.state}",
            tags={"dag_id": ti.dag_id, "task_id": ti.task_id},
        )
        stats_mock.assert_any_call(
            "ti.finish",
            tags={"dag_id": ti.dag_id, "task_id": ti.task_id, "state": ti.state},
        )
        # ...while every possible task state is also reported with count=0, so
        # consumers always see a sample for each state.
        for state in State.task_states:
            assert (
                call(
                    f"ti.finish.{ti.dag_id}.{ti.task_id}.{state}",
                    count=0,
                    tags={"dag_id": ti.dag_id, "task_id": ti.task_id},
                )
                in stats_mock.mock_calls
            )
            assert (
                call(
                    "ti.finish",
                    count=0,
                    tags={"dag_id": ti.dag_id, "task_id": ti.task_id, "state": str(state)},
                )
                in stats_mock.mock_calls
            )
        assert (
            call(f"ti.start.{ti.dag_id}.{ti.task_id}", tags={"dag_id": ti.dag_id, "task_id": ti.task_id})
            in stats_mock.mock_calls
        )
        assert call("ti.start", tags={"dag_id": ti.dag_id, "task_id": ti.task_id}) in stats_mock.mock_calls
        # Total: two zero-count calls per state plus 7 others — presumably the
        # start/finish pairs above and related counters; verify the breakdown
        # against the Stats emission in taskinstance.py if this number changes.
        assert stats_mock.call_count == (2 * len(State.task_states)) + 7
| |
| def test_command_as_list(self, create_task_instance): |
| ti = create_task_instance() |
| ti.task.dag.fileloc = os.path.join(TEST_DAGS_FOLDER, "x.py") |
| assert ti.command_as_list() == [ |
| "airflow", |
| "tasks", |
| "run", |
| ti.dag_id, |
| ti.task_id, |
| ti.run_id, |
| "--subdir", |
| "DAGS_FOLDER/x.py", |
| ] |
| |
| def test_generate_command_default_param(self): |
| dag_id = "test_generate_command_default_param" |
| task_id = "task" |
| assert_command = ["airflow", "tasks", "run", dag_id, task_id, "run_1"] |
| generate_command = TI.generate_command(dag_id=dag_id, task_id=task_id, run_id="run_1") |
| assert assert_command == generate_command |
| |
| def test_generate_command_specific_param(self): |
| dag_id = "test_generate_command_specific_param" |
| task_id = "task" |
| assert_command = [ |
| "airflow", |
| "tasks", |
| "run", |
| dag_id, |
| task_id, |
| "run_1", |
| "--mark-success", |
| "--map-index", |
| "0", |
| ] |
| generate_command = TI.generate_command( |
| dag_id=dag_id, task_id=task_id, run_id="run_1", mark_success=True, map_index=0 |
| ) |
| assert assert_command == generate_command |
| |
    @provide_session
    def test_get_rendered_template_fields(self, dag_maker, session=None):
        """Rendered template fields are served from the RenderedTaskInstanceFields table."""
        with dag_maker("test-dag", session=session) as dag:
            task = BashOperator(task_id="op1", bash_command="{{ task.task_id }}")
        dag.fileloc = TEST_DAGS_FOLDER / "test_get_k8s_pod_yaml.py"
        ti = dag_maker.create_dagrun().task_instances[0]
        ti.task = task

        # Persist rendered fields for op1.
        session.add(RenderedTaskInstanceFields(ti))
        session.flush()

        # Create new TI for the same Task
        new_task = BashOperator(task_id="op12", bash_command="{{ task.task_id }}", dag=dag)

        new_ti = TI(task=new_task, run_id=ti.run_id)
        new_ti.get_rendered_template_fields(session=session)

        # op1's bash_command now holds the rendered value — presumably rendered in
        # place when RenderedTaskInstanceFields(ti) was constructed above; confirm
        # against RenderedTaskInstanceFields.__init__ if this ever changes.
        assert "op1" == ti.task.bash_command

        # CleanUp
        with create_session() as session:
            session.query(RenderedTaskInstanceFields).delete()
| |
| def test_set_state_up_for_retry(self, create_task_instance): |
| ti = create_task_instance(state=State.RUNNING) |
| |
| start_date = timezone.utcnow() |
| ti.start_date = start_date |
| |
| ti.set_state(State.UP_FOR_RETRY) |
| assert ti.state == State.UP_FOR_RETRY |
| assert ti.start_date == start_date, "Start date should have been left alone" |
| assert ti.start_date < ti.end_date |
| assert ti.duration > 0 |
| |
    def test_refresh_from_db(self, create_task_instance):
        """refresh_from_db must reload every mutable TaskInstance column from the database.

        A guard assertion compares ``expected_values`` against the actual table
        columns so a newly added column cannot be silently forgotten here.
        """
        run_date = timezone.utcnow()
        # Hybrid properties cannot be assigned directly; they are skipped on write.
        hybrid_props = ["task_display_name"]
        expected_values = {
            "task_id": "test_refresh_from_db_task",
            "dag_id": "test_refresh_from_db_dag",
            "run_id": "test",
            "map_index": -1,
            "start_date": run_date + datetime.timedelta(days=1),
            "end_date": run_date + datetime.timedelta(days=1, seconds=1, milliseconds=234),
            "duration": 1.234,
            "state": State.SUCCESS,
            "_try_number": 1,
            "max_tries": 1,
            "hostname": "some_unique_hostname",
            "unixname": "some_unique_unixname",
            "job_id": 1234,
            "pool": "some_fake_pool_id",
            "pool_slots": 25,
            "queue": "some_queue_id",
            "priority_weight": 123,
            "operator": "some_custom_operator",
            "custom_operator_name": "some_custom_operator",
            "queued_dttm": run_date + datetime.timedelta(hours=1),
            "rendered_map_index": None,
            "queued_by_job_id": 321,
            "pid": 123,
            "executor": "some_executor",
            "executor_config": {"Some": {"extra": "information"}},
            "external_executor_id": "some_executor_id",
            "trigger_timeout": None,
            "trigger_id": None,
            "next_kwargs": None,
            "next_method": None,
            "updated_at": None,
            "task_display_name": "Test Refresh from DB Task",
        }
        # Make sure we aren't missing any new value in our expected_values list.
        expected_keys = {f"task_instance.{key.lstrip('_')}" for key in expected_values}
        assert {str(c) for c in TI.__table__.columns} == expected_keys, (
            "Please add all non-foreign values of TaskInstance to this list. "
            "This prevents refresh_from_db() from missing a field."
        )

        ti = create_task_instance(
            task_id=expected_values["task_id"],
            task_display_name=expected_values["task_display_name"],
            dag_id=expected_values["dag_id"],
        )
        # Write every expected value to the row, then persist it.
        for key, expected_value in expected_values.items():
            if key not in hybrid_props:
                setattr(ti, key, expected_value)
        with create_session() as session:
            session.merge(ti)
            session.commit()

        mock_task = mock.MagicMock()
        mock_task.task_id = expected_values["task_id"]
        mock_task.dag_id = expected_values["dag_id"]

        # A fresh TI for the same (dag_id, task_id, run_id) must pick everything up.
        ti = TI(task=mock_task, run_id="test")
        ti.refresh_from_db()
        for key, expected_value in expected_values.items():
            assert hasattr(ti, key), f"Key {key} is missing in the TaskInstance."
            if key not in hybrid_props:
                assert (
                    getattr(ti, key) == expected_value
                ), f"Key: {key} had different values. Make sure it loads it in the refresh refresh_from_db()"
| |
| def test_operator_field_with_serialization(self, create_task_instance): |
| ti = create_task_instance() |
| assert ti.task.task_type == "EmptyOperator" |
| assert ti.task.operator_name == "EmptyOperator" |
| |
| # Verify that ti.operator field renders correctly "without" Serialization |
| assert ti.operator == "EmptyOperator" |
| |
| serialized_op = SerializedBaseOperator.serialize_operator(ti.task) |
| deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op) |
| assert deserialized_op.task_type == "EmptyOperator" |
| # Verify that ti.operator field renders correctly "with" Serialization |
| ser_ti = TI(task=deserialized_op, run_id=None) |
| assert ser_ti.operator == "EmptyOperator" |
| assert ser_ti.task.operator_name == "EmptyOperator" |
| |
    def test_clear_db_references(self, session, create_task_instance):
        """clear_db_references deletes TaskFail, RTIF, XCom, and note rows for the ti."""
        tables = [TaskFail, RenderedTaskInstanceFields, XCom]
        ti = create_task_instance()
        ti.note = "sample note"

        session.merge(ti)
        session.commit()
        # Create one referencing row per table.
        for table in [TaskFail, RenderedTaskInstanceFields]:
            session.add(table(ti))
        XCom.set(key="key", value="value", task_id=ti.task_id, dag_id=ti.dag_id, run_id=ti.run_id)
        session.commit()
        for table in tables:
            assert session.query(table).count() == 1

        filter_kwargs = dict(dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id, map_index=ti.map_index)
        ti_note = session.query(TaskInstanceNote).filter_by(**filter_kwargs).one()
        assert ti_note.content == "sample note"

        # Clearing must remove every referencing row, including the note.
        ti.clear_db_references(session)
        for table in tables:
            assert session.query(table).count() == 0

        assert session.query(TaskInstanceNote).filter_by(**filter_kwargs).one_or_none() is None
| |
| def test_skipped_task_call_on_skipped_callback(self, dag_maker): |
| def raise_skip_exception(): |
| raise AirflowSkipException |
| |
| callback_function = mock.MagicMock() |
| callback_function.__name__ = "callback_function" |
| |
| with dag_maker(dag_id="test_skipped_task"): |
| task = PythonOperator( |
| task_id="test_skipped_task", |
| python_callable=raise_skip_exception, |
| on_skipped_callback=callback_function, |
| ) |
| |
| dr = dag_maker.create_dagrun(execution_date=timezone.utcnow()) |
| ti = dr.task_instances[0] |
| ti.task = task |
| ti.run() |
| assert State.SKIPPED == ti.state |
| assert callback_function.called |
| |
| |
@pytest.mark.parametrize("pool_override", [None, "test_pool2"])
@pytest.mark.parametrize("queue_by_policy", [None, "forced_queue"])
def test_refresh_from_task(pool_override, queue_by_policy, monkeypatch):
    """refresh_from_task copies task attributes onto the TI, honoring overrides and policies."""
    default_queue = "test_queue"
    expected_queue = queue_by_policy or default_queue
    if queue_by_policy:
        # Install a dummy cluster policy to check that it is always applied.
        def mock_policy(task_instance: TaskInstance):
            task_instance.queue = queue_by_policy

        monkeypatch.setattr("airflow.models.taskinstance.task_instance_mutation_hook", mock_policy)

    source_task = EmptyOperator(
        task_id="empty",
        queue=default_queue,
        pool="test_pool1",
        pool_slots=3,
        priority_weight=10,
        run_as_user="test",
        retries=30,
        executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}},
    )
    ti = TI(source_task, run_id=None)
    ti.refresh_from_task(source_task, pool_override=pool_override)

    assert ti.queue == expected_queue
    assert ti.pool == (pool_override if pool_override else source_task.pool)
    assert ti.pool_slots == source_task.pool_slots
    assert ti.priority_weight == source_task.priority_weight_total
    assert ti.run_as_user == source_task.run_as_user
    assert ti.max_tries == source_task.retries
    assert ti.executor_config == source_task.executor_config
    assert ti.operator == EmptyOperator.__name__

    # refresh_from_task must not reset an already-bumped max_tries.
    bumped_max_tries = source_task.retries + 10
    ti.max_tries = bumped_max_tries
    ti.refresh_from_task(source_task)
    assert ti.max_tries == bumped_max_tries
| |
| |
class TestRunRawTaskQueriesCount:
    """
    These tests are designed to detect changes in the number of queries executed
    when calling _run_raw_task
    """

    @staticmethod
    def _clean():
        # Reset every table that could influence the query counts being measured.
        for cleaner in (
            db.clear_db_runs,
            db.clear_db_pools,
            db.clear_db_dags,
            db.clear_db_sla_miss,
            db.clear_db_import_errors,
            db.clear_db_datasets,
        ):
            cleaner()

    def setup_method(self) -> None:
        self._clean()

    def teardown_method(self) -> None:
        self._clean()
| |
| |
@pytest.mark.parametrize("mode", ["poke", "reschedule"])
@pytest.mark.parametrize("retries", [0, 1])
def test_sensor_timeout(mode, retries, dag_maker):
    """
    Test that AirflowSensorTimeout does not cause sensor to retry.
    """

    def timeout():
        raise AirflowSensorTimeout

    on_failure = mock.MagicMock()
    on_failure.__name__ = "mock_on_failure"
    with dag_maker(dag_id=f"test_sensor_timeout_{mode}_{retries}"):
        PythonSensor(
            task_id="test_raise_sensor_timeout",
            python_callable=timeout,
            on_failure_callback=on_failure,
            retries=retries,
            mode=mode,
        )
    ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]

    with pytest.raises(AirflowSensorTimeout):
        ti.run()

    # Even with retries configured, a sensor timeout fails the ti outright.
    assert on_failure.called
    assert ti.state == State.FAILED
| |
| |
@pytest.mark.parametrize("mode", ["poke", "reschedule"])
@pytest.mark.parametrize("retries", [0, 1])
def test_mapped_sensor_timeout(mode, retries, dag_maker):
    """
    Test that AirflowSensorTimeout does not cause mapped sensor to retry.
    """

    def timeout():
        raise AirflowSensorTimeout

    mock_on_failure = mock.MagicMock()
    mock_on_failure.__name__ = "mock_on_failure"
    # The dag_id previously read f"test_sensor_timeout_{mode}_{retries}" — a
    # copy-paste from test_sensor_timeout that made both tests share DAG ids.
    with dag_maker(dag_id=f"test_mapped_sensor_timeout_{mode}_{retries}"):
        PythonSensor.partial(
            task_id="test_raise_sensor_timeout",
            python_callable=timeout,
            on_failure_callback=mock_on_failure,
            retries=retries,
        ).expand(mode=[mode])
    ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]

    with pytest.raises(AirflowSensorTimeout):
        ti.run()

    assert mock_on_failure.called
    assert ti.state == State.FAILED
| |
| |
@pytest.mark.parametrize("mode", ["poke", "reschedule"])
@pytest.mark.parametrize("retries", [0, 1])
def test_mapped_sensor_works(mode, retries, dag_maker):
    """
    Test that mapped sensors reaches success state.
    """

    # NOTE: despite the name, this callable *succeeds* — returning a truthy
    # value means the sensor pokes successfully on the first try.
    def timeout(ti):
        return 1

    with dag_maker(dag_id=f"test_sensor_timeout_{mode}_{retries}"):
        PythonSensor.partial(
            task_id="test_raise_sensor_timeout",
            python_callable=timeout,
            retries=retries,
        ).expand(mode=[mode])
    ti = dag_maker.create_dagrun().task_instances[0]

    ti.run()
    assert ti.state == State.SUCCESS
| |
| |
class TestTaskInstanceRecordTaskMapXComPush:
    """Test TI.xcom_push() correctly records return values for task-mapping."""

    def setup_class(self):
        """Ensure we start fresh."""
        with create_session() as session:
            session.query(TaskMap).delete()

    @pytest.mark.parametrize("xcom_value", [[1, 2, 3], {"a": 1, "b": 2}, "abc"])
    def test_not_recorded_if_leaf(self, dag_maker, xcom_value):
        """Return value should not be recorded if there are no downstreams."""
        with dag_maker(dag_id="test_not_recorded_for_unused") as dag:

            @dag.task()
            def push_something():
                return xcom_value

            push_something()

        ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push_something")
        ti.run()

        assert dag_maker.session.query(TaskMap).count() == 0

    @pytest.mark.parametrize("xcom_value", [[1, 2, 3], {"a": 1, "b": 2}, "abc"])
    def test_not_recorded_if_not_used(self, dag_maker, xcom_value):
        """Return value should not be recorded if no downstreams are mapped."""
        with dag_maker(dag_id="test_not_recorded_for_unused") as dag:

            @dag.task()
            def push_something():
                return xcom_value

            @dag.task()
            def completely_different():
                pass

            push_something() >> completely_different()

        ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push_something")
        ti.run()

        assert dag_maker.session.query(TaskMap).count() == 0

    @pytest.mark.parametrize("xcom_1", [[1, 2, 3], {"a": 1, "b": 2}, "abc"])
    @pytest.mark.parametrize("xcom_4", [[1, 2, 3], {"a": 1, "b": 2}])
    def test_not_recorded_if_irrelevant(self, dag_maker, xcom_1, xcom_4):
        """Return value should only be recorded if a mapped downstream uses it."""
        with dag_maker(dag_id="test_not_recorded_for_unused") as dag:

            @dag.task()
            def push_1():
                return xcom_1

            @dag.task()
            def push_2():
                return [-1, -2]

            @dag.task()
            def push_3():
                return ["x", "y"]

            @dag.task()
            def push_4():
                return xcom_4

            @dag.task()
            def show(arg1, arg2):
                print(arg1, arg2)

            @task_group()
            def tg(arg):
                show(arg1=task_3, arg2=arg)

            task_3 = push_3()
            show.partial(arg1=push_1()).expand(arg2=push_2())
            tg.expand(arg=push_4())

        tis = {ti.task_id: ti for ti in dag_maker.create_dagrun().task_instances}

        # push_1 only feeds a partial (non-expanded) argument: no TaskMap row.
        tis["push_1"].run()
        assert dag_maker.session.query(TaskMap).count() == 0

        # push_2 feeds expand(): its return value must be recorded.
        tis["push_2"].run()
        assert dag_maker.session.query(TaskMap).count() == 1

        # push_3 is used inside the group body, not expanded against: no new row.
        tis["push_3"].run()
        assert dag_maker.session.query(TaskMap).count() == 1

        # push_4 feeds the task group's expand(): recorded as well.
        tis["push_4"].run()
        assert dag_maker.session.query(TaskMap).count() == 2

    @pytest.mark.parametrize(
        "return_value, exception_type, error_message",
        [
            ("abc", UnmappableXComTypePushed, "unmappable return type 'str'"),
            (None, XComForMappingNotPushed, "did not push XCom for task mapping"),
        ],
    )
    def test_expand_error_if_unmappable_type(self, dag_maker, return_value, exception_type, error_message):
        """If an unmappable return value is used for expand(), fail the task that pushed the XCom."""
        with dag_maker(dag_id="test_expand_error_if_unmappable_type") as dag:

            @dag.task()
            def push_something():
                return return_value

            @dag.task()
            def pull_something(value):
                print(value)

            pull_something.expand(value=push_something())

        ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push_something")
        with pytest.raises(exception_type) as ctx:
            ti.run()

        assert dag_maker.session.query(TaskMap).count() == 0
        assert ti.state == TaskInstanceState.FAILED
        assert str(ctx.value) == error_message

    @pytest.mark.parametrize(
        "return_value, exception_type, error_message",
        [
            (123, UnmappableXComTypePushed, "unmappable return type 'int'"),
            (None, XComForMappingNotPushed, "did not push XCom for task mapping"),
        ],
    )
    def test_expand_kwargs_error_if_unmappable_type(
        self,
        dag_maker,
        return_value,
        exception_type,
        error_message,
    ):
        """If an unmappable return value is used for expand_kwargs(), fail the task that pushed the XCom."""
        with dag_maker(dag_id="test_expand_kwargs_error_if_unmappable_type") as dag:

            @dag.task()
            def push():
                return return_value

            MockOperator.partial(task_id="pull").expand_kwargs(push())

        ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push")
        with pytest.raises(exception_type) as ctx:
            ti.run()

        assert dag_maker.session.query(TaskMap).count() == 0
        assert ti.state == TaskInstanceState.FAILED
        assert str(ctx.value) == error_message

    @pytest.mark.parametrize(
        "return_value, exception_type, error_message",
        [
            (123, UnmappableXComTypePushed, "unmappable return type 'int'"),
            (None, XComForMappingNotPushed, "did not push XCom for task mapping"),
        ],
    )
    def test_task_group_expand_error_if_unmappable_type(
        self,
        dag_maker,
        return_value,
        exception_type,
        error_message,
    ):
        """If an unmappable return value is used, fail the task that pushed the XCom."""
        with dag_maker(dag_id="test_task_group_expand_error_if_unmappable_type") as dag:

            @dag.task()
            def push():
                return return_value

            @task_group
            def tg(arg):
                MockOperator(task_id="pull", arg1=arg)

            tg.expand(arg=push())

        ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push")
        with pytest.raises(exception_type) as ctx:
            ti.run()

        assert dag_maker.session.query(TaskMap).count() == 0
        assert ti.state == TaskInstanceState.FAILED
        assert str(ctx.value) == error_message

    @pytest.mark.parametrize(
        "return_value, exception_type, error_message",
        [
            (123, UnmappableXComTypePushed, "unmappable return type 'int'"),
            (None, XComForMappingNotPushed, "did not push XCom for task mapping"),
        ],
    )
    def test_task_group_expand_kwargs_error_if_unmappable_type(
        self,
        dag_maker,
        return_value,
        exception_type,
        error_message,
    ):
        """If an unmappable return value is used, fail the task that pushed the XCom."""
        with dag_maker(dag_id="test_task_group_expand_kwargs_error_if_unmappable_type") as dag:

            @dag.task()
            def push():
                return return_value

            @task_group
            def tg(arg):
                MockOperator(task_id="pull", arg1=arg)

            tg.expand_kwargs(push())

        ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push")
        with pytest.raises(exception_type) as ctx:
            ti.run()

        assert dag_maker.session.query(TaskMap).count() == 0
        assert ti.state == TaskInstanceState.FAILED
        assert str(ctx.value) == error_message

    @pytest.mark.parametrize(
        "create_upstream",
        [
            # The task returns an invalid expand_kwargs() input (a list[int] instead of list[dict]).
            pytest.param(lambda: task(task_id="push")(lambda: [0])(), id="normal"),
            # This task returns a list[dict] (correct), but we use map() to transform it to list[int] (wrong).
            pytest.param(lambda: task(task_id="push")(lambda: [{"v": ""}])().map(lambda _: 0), id="mapped"),
        ],
    )
    def test_expand_kwargs_error_if_received_invalid(self, dag_maker, session, create_upstream):
        """The mapped downstream fails with ValueError when expand_kwargs() receives a non-dict item."""
        with dag_maker(dag_id="test_expand_kwargs_error_if_received_invalid", session=session):
            push_task = create_upstream()

            @task()
            def pull(v):
                print(v)

            pull.expand_kwargs(push_task)

        dr = dag_maker.create_dagrun()

        # Run "push".
        decision = dr.task_instance_scheduling_decisions(session=session)
        assert decision.schedulable_tis
        for ti in decision.schedulable_tis:
            ti.run()

        # Run "pull".
        decision = dr.task_instance_scheduling_decisions(session=session)
        assert decision.schedulable_tis
        for ti in decision.schedulable_tis:
            with pytest.raises(ValueError) as ctx:
                ti.run()
            assert str(ctx.value) == "expand_kwargs() expects a list[dict], not list[int]"

    @pytest.mark.parametrize(
        "downstream, error_message",
        [
            ("taskflow", "mapping already partial argument: arg2"),
            ("classic", "unmappable or already specified argument: arg2"),
        ],
        ids=["taskflow", "classic"],
    )
    @pytest.mark.parametrize("strict", [True, False], ids=["strict", "override"])
    def test_expand_kwargs_override_partial(self, dag_maker, session, downstream, error_message, strict):
        """expand_kwargs() overriding a partial argument fails in strict mode, overrides otherwise."""

        class ClassicOperator(MockOperator):
            def execute(self, context):
                return (self.arg1, self.arg2)

        with dag_maker(dag_id="test_expand_kwargs_override_partial", session=session) as dag:

            @dag.task()
            def push():
                return [{"arg1": "a"}, {"arg1": "b", "arg2": "c"}]

            push_task = push()

            ClassicOperator.partial(task_id="classic", arg2="d").expand_kwargs(push_task, strict=strict)

            @dag.task(task_id="taskflow")
            def pull(arg1, arg2):
                return (arg1, arg2)

            pull.partial(arg2="d").expand_kwargs(push_task, strict=strict)

        dr = dag_maker.create_dagrun()
        next(ti for ti in dr.task_instances if ti.task_id == "push").run()

        decision = dr.task_instance_scheduling_decisions(session=session)
        tis = {(ti.task_id, ti.map_index, ti.state): ti for ti in decision.schedulable_tis}
        assert sorted(tis) == [
            ("classic", 0, None),
            ("classic", 1, None),
            ("taskflow", 0, None),
            ("taskflow", 1, None),
        ]

        ti = tis[(downstream, 0, None)]
        ti.run()
        # BUGFIX: these xcom_pull comparisons were previously bare expressions
        # (no "assert"), so they never checked anything. Tuples round-trip as
        # lists through JSON XCom serialization, hence the list expectations.
        assert ti.xcom_pull(task_ids=downstream, map_indexes=0, session=session) == ["a", "d"]

        ti = tis[(downstream, 1, None)]
        if strict:
            with pytest.raises(TypeError) as ctx:
                ti.run()
            assert str(ctx.value) == error_message
        else:
            ti.run()
            assert ti.xcom_pull(task_ids=downstream, map_indexes=1, session=session) == ["b", "c"]

    def test_error_if_upstream_does_not_push(self, dag_maker):
        """Fail the upstream task if it fails to push the XCom used for task mapping."""
        with dag_maker(dag_id="test_not_recorded_for_unused") as dag:

            @dag.task(do_xcom_push=False)
            def push_something():
                return [1, 2]

            @dag.task()
            def pull_something(value):
                print(value)

            pull_something.expand(value=push_something())

        ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push_something")
        with pytest.raises(XComForMappingNotPushed) as ctx:
            ti.run()

        assert dag_maker.session.query(TaskMap).count() == 0
        assert ti.state == TaskInstanceState.FAILED
        assert str(ctx.value) == "did not push XCom for task mapping"

    @conf_vars({("core", "max_map_length"): "1"})
    def test_error_if_unmappable_length(self, dag_maker):
        """If an unmappable return value is used to map, fail the task that pushed the XCom."""
        with dag_maker(dag_id="test_not_recorded_for_unused") as dag:

            @dag.task()
            def push_something():
                return [1, 2]

            @dag.task()
            def pull_something(value):
                print(value)

            pull_something.expand(value=push_something())

        ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push_something")
        with pytest.raises(UnmappableXComLengthPushed) as ctx:
            ti.run()

        assert dag_maker.session.query(TaskMap).count() == 0
        assert ti.state == TaskInstanceState.FAILED
        assert str(ctx.value) == "unmappable return value length: 2 > 1"

    @pytest.mark.parametrize(
        "xcom_value, expected_length, expected_keys",
        [
            ([1, 2, 3], 3, None),
            ({"a": 1, "b": 2}, 2, ["a", "b"]),
        ],
    )
    def test_written_task_map(self, dag_maker, xcom_value, expected_length, expected_keys):
        """Return value should be recorded in TaskMap if it's used by a downstream to map."""
        with dag_maker(dag_id="test_written_task_map") as dag:

            @dag.task()
            def push_something():
                return xcom_value

            @dag.task()
            def pull_something(value):
                print(value)

            pull_something.expand(value=push_something())

        dag_run = dag_maker.create_dagrun()
        ti = next(ti for ti in dag_run.task_instances if ti.task_id == "push_something")
        ti.run()

        task_map = dag_maker.session.query(TaskMap).one()
        assert task_map.dag_id == "test_written_task_map"
        assert task_map.task_id == "push_something"
        assert task_map.run_id == dag_run.run_id
        assert task_map.map_index == -1
        assert task_map.length == expected_length
        assert task_map.keys == expected_keys

    def test_no_error_on_changing_from_non_mapped_to_mapped(self, dag_maker, session):
        """If a task changes from non-mapped to mapped, don't fail on integrity error."""
        with dag_maker(dag_id="test_no_error_on_changing_from_non_mapped_to_mapped") as dag:

            @dag.task()
            def add_one(x):
                return [x + 1]

            @dag.task()
            def add_two(x):
                return x + 2

            task1 = add_one(2)
            add_two.expand(x=task1)

        dr = dag_maker.create_dagrun()
        ti = dr.get_task_instance(task_id="add_one")
        ti.run()
        assert ti.state == TaskInstanceState.SUCCESS
        # Replace the non-mapped "add_one" with a mapped version of the same id.
        dag._remove_task("add_one")
        with dag:
            task1 = add_one.expand(x=[1, 2, 3]).operator
        serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

        dr.dag = serialized_dag
        dr.verify_integrity(session=session)
        ti = dr.get_task_instance(task_id="add_one")
        assert ti.state == TaskInstanceState.REMOVED
        dag.clear()
        ti.refresh_from_task(task1)
        # This should not raise an integrity error
        dr.task_instance_scheduling_decisions()
| |
| |
class TestMappedTaskInstanceReceiveValue:
    """Tests that expanded (mapped) task instances receive the correct per-map-index values."""

    @pytest.mark.parametrize(
        "literal, expected_outputs",
        [
            pytest.param([1, 2, 3], [1, 2, 3], id="list"),
            pytest.param({"a": 1, "b": 2}, [("a", 1), ("b", 2)], id="dict"),
        ],
    )
    def test_map_literal(self, literal, expected_outputs, dag_maker, session):
        # Mapping over a literal: TIs are created up-front at parse time.
        outputs = []

        with dag_maker(dag_id="literal", session=session) as dag:

            @dag.task
            def show(value):
                outputs.append(value)

            show.expand(value=literal)

        dag_run = dag_maker.create_dagrun()
        show_task = dag.get_task("show")
        mapped_tis = (
            session.query(TI)
            .filter_by(task_id="show", dag_id=dag_run.dag_id, run_id=dag_run.run_id)
            .order_by(TI.map_index)
            .all()
        )
        assert len(mapped_tis) == len(literal)

        for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")):
            ti.refresh_from_task(show_task)
            ti.run()
        assert outputs == expected_outputs

    @pytest.mark.parametrize(
        "upstream_return, expected_outputs",
        [
            pytest.param([1, 2, 3], [1, 2, 3], id="list"),
            pytest.param({"a": 1, "b": 2}, [("a", 1), ("b", 2)], id="dict"),
        ],
    )
    def test_map_xcom(self, upstream_return, expected_outputs, dag_maker, session):
        # Mapping over an XCom: TIs only exist after the upstream runs and the
        # mapped task is expanded.
        outputs = []

        with dag_maker(dag_id="xcom", session=session) as dag:

            @dag.task
            def emit():
                return upstream_return

            @dag.task
            def show(value):
                outputs.append(value)

            show.expand(value=emit())

        dag_run = dag_maker.create_dagrun()
        emit_ti = dag_run.get_task_instance("emit", session=session)
        emit_ti.refresh_from_task(dag.get_task("emit"))
        emit_ti.run()

        show_task = dag.get_task("show")
        mapped_tis, max_map_index = show_task.expand_mapped_task(dag_run.run_id, session=session)
        assert max_map_index + 1 == len(mapped_tis) == len(upstream_return)

        for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")):
            ti.refresh_from_task(show_task)
            ti.run()
        assert outputs == expected_outputs

    def test_map_product(self, dag_maker, session):
        # Two expand() sources produce the cross product (2 numbers x 3 letters = 6 TIs).
        outputs = []

        with dag_maker(dag_id="product", session=session) as dag:

            @dag.task
            def emit_numbers():
                return [1, 2]

            @dag.task
            def emit_letters():
                return {"a": "x", "b": "y", "c": "z"}

            @dag.task
            def show(number, letter):
                outputs.append((number, letter))

            show.expand(number=emit_numbers(), letter=emit_letters())

        dag_run = dag_maker.create_dagrun()
        for task_id in ["emit_numbers", "emit_letters"]:
            ti = dag_run.get_task_instance(task_id, session=session)
            ti.refresh_from_task(dag.get_task(task_id))
            ti.run()

        show_task = dag.get_task("show")
        mapped_tis, max_map_index = show_task.expand_mapped_task(dag_run.run_id, session=session)
        assert max_map_index + 1 == len(mapped_tis) == 6

        for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")):
            ti.refresh_from_task(show_task)
            ti.run()
        # Dict sources are delivered as (key, value) tuples.
        assert outputs == [
            (1, ("a", "x")),
            (1, ("b", "y")),
            (1, ("c", "z")),
            (2, ("a", "x")),
            (2, ("b", "y")),
            (2, ("c", "z")),
        ]

    def test_map_product_same(self, dag_maker, session):
        """Test a mapped task can refer to the same source multiple times."""
        outputs = []

        with dag_maker(dag_id="product_same", session=session) as dag:

            @dag.task
            def emit_numbers():
                return [1, 2]

            @dag.task
            def show(a, b):
                outputs.append((a, b))

            emit_task = emit_numbers()
            # The same upstream output feeds both expanded arguments: 2 x 2 = 4 TIs.
            show.expand(a=emit_task, b=emit_task)

        dag_run = dag_maker.create_dagrun()
        ti = dag_run.get_task_instance("emit_numbers", session=session)
        ti.refresh_from_task(dag.get_task("emit_numbers"))
        ti.run()

        show_task = dag.get_task("show")
        # XCom-driven lengths are unknown at parse time.
        with pytest.raises(NotFullyPopulated):
            assert show_task.get_parse_time_mapped_ti_count()
        mapped_tis, max_map_index = show_task.expand_mapped_task(dag_run.run_id, session=session)
        assert max_map_index + 1 == len(mapped_tis) == 4

        for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")):
            ti.refresh_from_task(show_task)
            ti.run()
        assert outputs == [(1, 1), (1, 2), (2, 1), (2, 2)]

    def test_map_literal_cross_product(self, dag_maker, session):
        """Test a mapped task with literal cross product args expand properly."""
        outputs = []

        with dag_maker(dag_id="product_same_types", session=session) as dag:

            @dag.task
            def show(a, b):
                outputs.append((a, b))

            show.expand(a=[2, 4, 8], b=[5, 10])

        dag_run = dag_maker.create_dagrun()

        show_task = dag.get_task("show")
        assert show_task.get_parse_time_mapped_ti_count() == 6
        mapped_tis, max_map_index = show_task.expand_mapped_task(dag_run.run_id, session=session)
        assert len(mapped_tis) == 0  # Expanded at parse!
        assert max_map_index == 5

        tis = (
            session.query(TaskInstance)
            .filter(
                TaskInstance.dag_id == dag.dag_id,
                TaskInstance.task_id == "show",
                TaskInstance.run_id == dag_run.run_id,
            )
            .order_by(TaskInstance.map_index)
            .all()
        )
        for ti in tis:
            ti.refresh_from_task(show_task)
            ti.run()
        assert outputs == [(2, 5), (2, 10), (4, 5), (4, 10), (8, 5), (8, 10)]

    def test_map_in_group(self, tmp_path: pathlib.Path, dag_maker, session):
        # Mapped BashOperator inside a TaskGroup; outputs are observed through a file.
        out = tmp_path.joinpath("out")
        out.touch()

        with dag_maker(dag_id="in_group", session=session) as dag:

            @dag.task
            def envs():
                return [{"VAR1": "FOO"}, {"VAR1": "BAR"}]

            @dag.task
            def cmds():
                return [f'echo "hello $VAR1" >> {out}', f'echo "goodbye $VAR1" >> {out}']

            with TaskGroup(group_id="dynamic"):
                BashOperator.partial(task_id="bash", do_xcom_push=False).expand(
                    env=envs(),
                    bash_command=cmds(),
                )

        dag_run: DagRun = dag_maker.create_dagrun()
        original_tis = {ti.task_id: ti for ti in dag_run.get_task_instances(session=session)}

        for task_id in ["dynamic.envs", "dynamic.cmds"]:
            ti = original_tis[task_id]
            ti.refresh_from_task(dag.get_task(task_id))
            ti.run()

        bash_task = dag.get_task("dynamic.bash")
        mapped_bash_tis, max_map_index = bash_task.expand_mapped_task(dag_run.run_id, session=session)
        assert max_map_index == 3  # 2 * 2 mapped tasks.
        for ti in sorted(mapped_bash_tis, key=operator.attrgetter("map_index")):
            ti.refresh_from_task(bash_task)
            ti.run()

        with out.open() as f:
            out_lines = [line.strip() for line in f]
        assert out_lines == ["hello FOO", "goodbye FOO", "hello BAR", "goodbye BAR"]
| |
| |
| def _get_lazy_xcom_access_expected_sql_lines() -> list[str]: |
| backend = os.environ.get("BACKEND") |
| if backend == "mysql": |
| return [ |
| "SELECT xcom.value", |
| "FROM xcom", |
| "WHERE xcom.dag_id = 'test_dag' AND xcom.run_id = 'test' " |
| "AND xcom.task_id = 't' AND xcom.map_index = -1 AND xcom.`key` = 'xxx'", |
| ] |
| elif backend == "postgres": |
| return [ |
| "SELECT xcom.value", |
| "FROM xcom", |
| "WHERE xcom.dag_id = 'test_dag' AND xcom.run_id = 'test' " |
| "AND xcom.task_id = 't' AND xcom.map_index = -1 AND xcom.key = 'xxx'", |
| ] |
| elif backend == "sqlite": |
| return [ |
| "SELECT xcom.value", |
| "FROM xcom", |
| "WHERE xcom.dag_id = 'test_dag' AND xcom.run_id = 'test' " |
| "AND xcom.task_id = 't' AND xcom.map_index = -1 AND xcom.\"key\" = 'xxx'", |
| ] |
| else: |
| raise RuntimeError(f"unknown backend {backend!r}") |
| |
| |
def test_lazy_xcom_access_does_not_pickle_session(dag_maker, session):
    """A pickled LazyXComSelectSequence must be self-contained (no live ORM session)."""
    with dag_maker(session=session):
        EmptyOperator(task_id="t")

    run: DagRun = dag_maker.create_dagrun()
    run.get_task_instance("t", session=session).xcom_push("xxx", 123, session=session)

    with set_current_task_instance_session(session=session):
        original = LazyXComSelectSequence.from_select(
            select(XCom.value).filter_by(
                dag_id=run.dag_id,
                run_id=run.run_id,
                task_id="t",
                map_index=-1,
                key="xxx",
            ),
            order_by=(),
        )
        processed = pickle.loads(pickle.dumps(original))

    # After the object went through pickling, the underlying ORM query should be
    # replaced by one backed by a literal SQL string with all variables bound.
    sql_lines = [line.strip() for line in str(processed._select_asc.compile(None)).splitlines()]
    assert sql_lines == _get_lazy_xcom_access_expected_sql_lines()

    # The round-tripped sequence must still support len() and iteration.
    assert len(processed) == 1
    assert list(processed) == [123]
| |
| |
@mock.patch("airflow.models.taskinstance.XCom.deserialize_value", side_effect=XCom.deserialize_value)
def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_value, dag_maker, session):
    """Ensure we access XCom lazily when pulling from a mapped operator."""
    with dag_maker(dag_id="test_xcom", session=session):
        # Use the private _expand() method to avoid the empty kwargs check.
        # We don't care about how the operator runs here, only its presence.
        task_1 = EmptyOperator.partial(task_id="task_1")._expand(EXPAND_INPUT_EMPTY, strict=False)
        EmptyOperator(task_id="task_2")

    dagrun = dag_maker.create_dagrun()

    # Fabricate two mapped TIs (map_index 0 and 1) for task_1 directly in the DB.
    ti_1_0 = dagrun.get_task_instance("task_1", session=session)
    ti_1_0.map_index = 0
    ti_1_1 = session.merge(TaskInstance(task_1, run_id=dagrun.run_id, map_index=1, state=ti_1_0.state))
    session.flush()

    ti_1_0.xcom_push(key=XCOM_RETURN_KEY, value="a", session=session)
    ti_1_1.xcom_push(key=XCOM_RETURN_KEY, value="b", session=session)

    ti_2 = dagrun.get_task_instance("task_2", session=session)

    # Simply pulling the joined XCom value should not deserialize.
    joined = ti_2.xcom_pull("task_1", session=session)
    assert isinstance(joined, LazyXComSelectSequence)
    assert mock_deserialize_value.call_count == 0

    # Only when we go through the iterable does deserialization happen.
    it = iter(joined)
    assert next(it) == "a"
    assert mock_deserialize_value.call_count == 1
    assert next(it) == "b"
    assert mock_deserialize_value.call_count == 2
    with pytest.raises(StopIteration):
        next(it)
| |
| |
def test_ti_mapped_depends_on_mapped_xcom_arg(dag_maker, session):
    """A mapped task can expand against the XCom output of another mapped task."""
    with dag_maker(session=session) as dag:

        @dag.task
        def add_one(x):
            return x + 1

        two_three_four = add_one.expand(x=[1, 2, 3])
        add_one.expand(x=two_three_four)

    dagrun = dag_maker.create_dagrun()
    # Run the first (literal-mapped) layer: produces [2, 3, 4].
    for map_index in range(3):
        ti = dagrun.get_task_instance("add_one", map_index=map_index, session=session)
        ti.refresh_from_task(dag.get_task("add_one"))
        ti.run()

    # The second add_one call gets the auto-suffixed task_id "add_one__1".
    task_345 = dag.get_task("add_one__1")
    for ti in task_345.expand_mapped_task(dagrun.run_id, session=session)[0]:
        ti.refresh_from_task(task_345)
        ti.run()

    query = XCom.get_many(run_id=dagrun.run_id, task_ids=["add_one__1"], session=session)
    assert [x.value for x in query.order_by(None).order_by(XCom.map_index)] == [3, 4, 5]
| |
| |
def test_mapped_upstream_return_none_should_skip(dag_maker, session):
    """A None pushed by a mapped upstream is dropped from the downstream expansion."""
    results = set()

    with dag_maker(dag_id="test_mapped_upstream_return_none_should_skip", session=session) as dag:

        @dag.task()
        def transform(value):
            if value == "b":  # Now downstream doesn't map against this!
                return None
            return value

        @dag.task()
        def pull(value):
            results.add(value)

        original = ["a", "b", "c"]
        transformed = transform.expand(value=original)  # ["a", None, "c"]
        pull.expand(value=transformed)  # ["a", "c"]

    dr = dag_maker.create_dagrun()

    # First layer: all three transform TIs are schedulable.
    decision = dr.task_instance_scheduling_decisions(session=session)
    tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis}
    assert sorted(tis) == [("transform", 0), ("transform", 1), ("transform", 2)]
    for ti in tis.values():
        ti.run()

    # Second layer: only two pull TIs — the None entry was skipped.
    decision = dr.task_instance_scheduling_decisions(session=session)
    tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis}
    assert sorted(tis) == [("pull", 0), ("pull", 1)]
    for ti in tis.values():
        ti.run()

    assert results == {"a", "c"}
| |
| |
def test_expand_non_templated_field(dag_maker, session):
    """Test expand on non-templated fields sets upstream deps properly."""

    # Strip template_fields so "env" is definitely non-templated.
    class SimpleBashOperator(BashOperator):
        template_fields = ()

    # NOTE(review): dag_id "product_same_types" appears to be copied from
    # test_map_literal_cross_product — presumably harmless, but confirm.
    with dag_maker(dag_id="product_same_types", session=session) as dag:

        @dag.task
        def get_extra_env():
            return [{"foo": "bar"}, {"foo": "biz"}]

        SimpleBashOperator.partial(task_id="echo", bash_command="echo $FOO").expand(env=get_extra_env())

    dag_maker.create_dagrun()

    # Expanding over the XCom must still register the upstream dependency.
    echo_task = dag.get_task("echo")
    assert "get_extra_env" in echo_task.upstream_task_ids
| |
| |
def test_mapped_task_does_not_error_in_mini_scheduler_if_upstreams_are_not_done(dag_maker, caplog, session):
    """
    This tests that when scheduling child tasks of a task and there's a mapped downstream task,
    if the mapped downstream task has upstreams that are not yet done, the mapped downstream task is
    not marked as `upstream_failed`
    """
    with dag_maker() as dag:

        @dag.task
        def second_task():
            return [0, 1, 2]

        @dag.task
        def first_task():
            print(2)

        @dag.task
        def middle_task(id):
            return id

        middle = middle_task.expand(id=second_task())

        @dag.task
        def last_task():
            print(3)

        [first_task(), middle] >> last_task()

    dag_run = dag_maker.create_dagrun()
    first_ti = dag_run.get_task_instance(task_id="first_task")
    second_ti = dag_run.get_task_instance(task_id="second_task")
    # first_task is done, but second_task (middle_task's expansion source) is still running.
    first_ti.state = State.SUCCESS
    second_ti.state = State.RUNNING
    session.merge(first_ti)
    session.merge(second_ti)
    session.commit()
    # Mini-scheduler pass from first_task must not touch the unexpanded middle_task.
    first_ti.schedule_downstream_tasks(session=session)
    middle_ti = dag_run.get_task_instance(task_id="middle_task")
    assert middle_ti.state != State.UPSTREAM_FAILED
    assert "0 downstream tasks scheduled from follow-on schedule" in caplog.text
| |
| |
def test_empty_operator_is_not_considered_in_mini_scheduler(dag_maker, caplog, session):
    """
    This test verifies that operators with inherits_from_empty_operator are not considered by mini scheduler.
    Such operators should not run on workers thus the mini scheduler optimization should skip them and not
    submit them directly to worker.
    """
    with dag_maker() as dag:

        @dag.task
        def first_task():
            print(2)

        @dag.task
        def second_task():
            print(2)

        # A bare EmptyOperator is skipped by the mini scheduler; one with a
        # callback must still be scheduled (the callback has to run somewhere).
        third_task = EmptyOperator(task_id="third_task")
        forth_task = EmptyOperator(task_id="forth_task", on_success_callback=lambda x: print("hi"))

        first_task() >> [second_task(), third_task, forth_task]
    dag_run = dag_maker.create_dagrun()
    first_ti = dag_run.get_task_instance(task_id="first_task")
    second_ti = dag_run.get_task_instance(task_id="second_task")
    third_ti = dag_run.get_task_instance(task_id="third_task")
    forth_ti = dag_run.get_task_instance(task_id="forth_task")
    first_ti.state = State.SUCCESS
    second_ti.state = State.NONE
    third_ti.state = State.NONE
    forth_ti.state = State.NONE
    session.merge(first_ti)
    session.merge(second_ti)
    session.merge(third_ti)
    session.merge(forth_ti)
    session.commit()
    first_ti.schedule_downstream_tasks(session=session)
    second_task = dag_run.get_task_instance(task_id="second_task")
    third_task = dag_run.get_task_instance(task_id="third_task")
    forth_task = dag_run.get_task_instance(task_id="forth_task")
    # Only the real task and the callback-bearing EmptyOperator get scheduled.
    assert second_task.state == State.SCHEDULED
    assert third_task.state == State.NONE
    assert forth_task.state == State.SCHEDULED
    assert "2 downstream tasks scheduled from follow-on schedule" in caplog.text
| |
| |
def test_mapped_task_expands_in_mini_scheduler_if_upstreams_are_done(dag_maker, caplog, session):
    """Test that mini scheduler expands mapped task"""
    with dag_maker() as dag:

        @dag.task
        def second_task():
            return [0, 1, 2]

        @dag.task
        def first_task():
            print(2)

        @dag.task
        def middle_task(id):
            return id

        middle = middle_task.expand(id=second_task())

        @dag.task
        def last_task():
            print(3)

        [first_task(), middle] >> last_task()

    dr = dag_maker.create_dagrun()

    first_ti = dr.get_task_instance(task_id="first_task")
    first_ti.state = State.SUCCESS
    session.merge(first_ti)
    session.commit()
    second_task = dag.get_task("second_task")
    second_ti = dr.get_task_instance(task_id="second_task")
    second_ti.refresh_from_task(second_task)
    second_ti.run()
    # With second_task done, the mini scheduler should expand middle_task into
    # three mapped TIs and schedule all of them.
    second_ti.schedule_downstream_tasks(session=session)
    for i in range(3):
        middle_ti = dr.get_task_instance(task_id="middle_task", map_index=i)
        assert middle_ti.state == State.SCHEDULED
    assert "3 downstream tasks scheduled from follow-on schedule" in caplog.text
| |
| |
def test_mini_scheduler_not_skip_mapped_downstream_until_all_upstreams_finish(dag_maker, session):
    """Mini scheduler must not skip zero-length mapped downstreams before every upstream is done."""
    with dag_maker(session=session):

        @task
        def generate() -> list[list[int]]:
            # Empty list: the mapped downstreams will eventually be skipped.
            return []

        @task
        def a_sum(numbers: list[int]) -> int:
            return sum(numbers)

        @task
        def b_double(summed: int) -> int:
            return summed * 2

        @task
        def c_gather(result) -> None:
            pass

        static = EmptyOperator(task_id="static")

        summed = a_sum.expand(numbers=generate())
        doubled = b_double.expand(summed=summed)
        static >> c_gather(doubled)

    dr: DagRun = dag_maker.create_dagrun()
    tis = {(ti.task_id, ti.map_index): ti for ti in dr.task_instances}

    static_ti = tis[("static", -1)]
    static_ti.run(session=session)
    static_ti.schedule_downstream_tasks(session=session)
    # No tasks should be skipped yet!
    assert not dr.get_task_instances([TaskInstanceState.SKIPPED], session=session)

    generate_ti = tis[("generate", -1)]
    generate_ti.run(session=session)
    generate_ti.schedule_downstream_tasks(session=session)
    # Now downstreams can be skipped.
    assert dr.get_task_instances([TaskInstanceState.SKIPPED], session=session)
| |
| |
def test_taskinstance_with_note(create_task_instance, session):
    """A note set on a TaskInstance is persisted as a TaskInstanceNote and removed with the TI."""
    ti: TaskInstance = create_task_instance(session=session)
    ti.note = "ti with note"

    session.add(ti)
    session.commit()

    lookup = {
        "dag_id": ti.dag_id,
        "task_id": ti.task_id,
        "run_id": ti.run_id,
        "map_index": ti.map_index,
    }

    note_row: TaskInstanceNote = session.query(TaskInstanceNote).filter_by(**lookup).one()
    assert note_row.content == "ti with note"

    session.delete(ti)
    session.commit()

    # Deleting the TI must also remove its note row.
    assert session.query(TaskInstance).filter_by(**lookup).one_or_none() is None
    assert session.query(TaskInstanceNote).filter_by(**lookup).one_or_none() is None
| |
| |
def test__refresh_from_db_should_not_increment_try_number(dag_maker, session):
    """Reloading a TaskInstance from the database must leave try_number untouched."""
    with dag_maker():
        BashOperator(task_id="hello", bash_command="hi")
    dag_maker.create_dagrun(state="success")

    ti = session.scalar(select(TaskInstance))
    assert ti.task_id == "hello"  # sanity check on the fixture
    assert ti.try_number == 1  # initial value
    for _ in range(2):
        ti.refresh_from_db()
        assert ti.try_number == 1  # unchanged after each reload
| |
| |
@pytest.mark.parametrize("state", list(TaskInstanceState))
def test_get_private_try_number(state: str):
    """Round-trip check: ``_get_private_try_number`` recovers the private
    ``_try_number`` from the public ``try_number`` for every TI state."""
    mock_ti = MagicMock()
    mock_ti.state = state
    private_try_number = 2
    mock_ti._try_number = private_try_number
    # Derive the public try_number from the private value (computation may depend on state).
    mock_ti.try_number = _get_try_number(task_instance=mock_ti)
    # Drop the private attribute so the helper must reconstruct it from try_number alone.
    delattr(mock_ti, "_try_number")
    assert _get_private_try_number(task_instance=mock_ti) == private_try_number