| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| from __future__ import annotations |
| |
| import datetime |
| import itertools |
| import logging |
| import os |
| import pickle |
| import re |
| import sys |
| import warnings |
| import weakref |
| from contextlib import redirect_stdout |
| from datetime import timedelta |
| from io import StringIO |
| from pathlib import Path |
| from typing import TYPE_CHECKING |
| from unittest import mock |
| from unittest.mock import patch |
| |
| import jinja2 |
| import pendulum |
| import pytest |
| import time_machine |
| from dateutil.relativedelta import relativedelta |
| from pendulum.tz.timezone import Timezone |
| from sqlalchemy import inspect, select |
| from sqlalchemy.exc import SAWarning |
| |
| from airflow import settings |
| from airflow.configuration import conf |
| from airflow.datasets import Dataset, DatasetAll, DatasetAny |
| from airflow.decorators import setup, task as task_decorator, teardown |
| from airflow.exceptions import ( |
| AirflowException, |
| DuplicateTaskIdFound, |
| ParamValidationError, |
| RemovedInAirflow3Warning, |
| ) |
| from airflow.models.baseoperator import BaseOperator |
| from airflow.models.dag import ( |
| DAG, |
| DagModel, |
| DagOwnerAttributes, |
| DagTag, |
| dag as dag_decorator, |
| get_dataset_triggered_next_run_info, |
| ) |
| from airflow.models.dagrun import DagRun |
| from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel, TaskOutletDatasetReference |
| from airflow.models.param import DagParam, Param, ParamsDict |
| from airflow.models.serialized_dag import SerializedDagModel |
| from airflow.models.taskfail import TaskFail |
| from airflow.models.taskinstance import TaskInstance as TI |
| from airflow.operators.bash import BashOperator |
| from airflow.operators.empty import EmptyOperator |
| from airflow.operators.python import PythonOperator |
| from airflow.operators.subdag import SubDagOperator |
| from airflow.security import permissions |
| from airflow.templates import NativeEnvironment, SandboxedEnvironment |
| from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable |
| from airflow.timetables.simple import ( |
| ContinuousTimetable, |
| DatasetTriggeredTimetable, |
| NullTimetable, |
| OnceTimetable, |
| ) |
| from airflow.utils import timezone |
| from airflow.utils.file import list_py_file_paths |
| from airflow.utils.session import create_session, provide_session |
| from airflow.utils.state import DagRunState, State, TaskInstanceState |
| from airflow.utils.task_group import TaskGroup, TaskGroupContext |
| from airflow.utils.timezone import datetime as datetime_tz |
| from airflow.utils.trigger_rule import TriggerRule |
| from airflow.utils.types import DagRunType |
| from airflow.utils.weight_rule import WeightRule |
| from tests.models import DEFAULT_DATE |
| from tests.plugins.priority_weight_strategy import ( |
| FactorPriorityWeightStrategy, |
| NotRegisteredPriorityWeightStrategy, |
| StaticTestPriorityWeightStrategy, |
| TestPriorityWeightStrategyPlugin, |
| ) |
| from tests.test_utils.asserts import assert_queries_count |
| from tests.test_utils.config import conf_vars |
| from tests.test_utils.db import clear_db_dags, clear_db_datasets, clear_db_runs, clear_db_serialized_dags |
| from tests.test_utils.mapping import expand_mapped_task |
| from tests.test_utils.mock_plugins import mock_plugin_manager |
| from tests.test_utils.timetables import cron_timetable, delta_timetable |
| |
if TYPE_CHECKING:
    from sqlalchemy.orm import Session

# Every test collected from this module is marked as needing the metadata database.
pytestmark = pytest.mark.db_test

# Reference date (2015-01-02 00:00) built with airflow.utils.timezone.datetime;
# the schedule tests below step forward from it.
TEST_DATE = datetime_tz(2015, 1, 2, 0, 0)

# Third ancestor of this file's path (``parents[2]``) -- presumably the repository root.
repo_root = Path(__file__).parents[2]
| |
| |
@pytest.fixture
def clear_dags():
    """Empty the DAG and serialized-DAG tables before the test and again afterwards."""

    def _purge():
        clear_db_dags()
        clear_db_serialized_dags()

    _purge()
    yield
    _purge()
| |
| |
@pytest.fixture
def clear_datasets():
    """Empty the dataset tables before the test and again once it has finished."""
    purge = clear_db_datasets
    purge()  # setup: start from a clean slate
    yield
    purge()  # teardown: leave nothing behind for the next test
| |
| |
| class TestDag: |
| def setup_method(self) -> None: |
| clear_db_runs() |
| clear_db_dags() |
| clear_db_datasets() |
| self.patcher_dag_code = mock.patch("airflow.models.dag.DagCode.bulk_sync_to_db") |
| self.patcher_dag_code.start() |
| |
| def teardown_method(self) -> None: |
| clear_db_runs() |
| clear_db_dags() |
| clear_db_datasets() |
| self.patcher_dag_code.stop() |
| |
| @staticmethod |
| def _clean_up(dag_id: str): |
| with create_session() as session: |
| session.query(DagRun).filter(DagRun.dag_id == dag_id).delete(synchronize_session=False) |
| session.query(TI).filter(TI.dag_id == dag_id).delete(synchronize_session=False) |
| session.query(TaskFail).filter(TaskFail.dag_id == dag_id).delete(synchronize_session=False) |
| |
| @staticmethod |
| def _occur_before(a, b, list_): |
| """ |
| Assert that a occurs before b in the list. |
| """ |
| a_index = -1 |
| b_index = -1 |
| for i, e in enumerate(list_): |
| if e.task_id == a: |
| a_index = i |
| if e.task_id == b: |
| b_index = i |
| return 0 <= a_index < b_index |
| |
| def test_params_not_passed_is_empty_dict(self): |
| """ |
| Test that when 'params' is _not_ passed to a new Dag, that the params |
| attribute is set to an empty dictionary. |
| """ |
| dag = DAG("test-dag") |
| |
| assert isinstance(dag.params, ParamsDict) |
| assert 0 == len(dag.params) |
| |
| def test_params_passed_and_params_in_default_args_no_override(self): |
| """ |
| Test that when 'params' exists as a key passed to the default_args dict |
| in addition to params being passed explicitly as an argument to the |
| dag, that the 'params' key of the default_args dict is merged with the |
| dict of the params argument. |
| """ |
| params1 = {"parameter1": 1} |
| params2 = {"parameter2": 2} |
| |
| dag = DAG("test-dag", default_args={"params": params1}, params=params2) |
| |
| assert params1["parameter1"] == dag.params["parameter1"] |
| assert params2["parameter2"] == dag.params["parameter2"] |
| |
| def test_not_none_schedule_with_non_default_params(self): |
| """ |
| Test if there is a DAG with not None schedule_interval and have some params that |
| don't have a default value raise a error while DAG parsing |
| """ |
| params = {"param1": Param(type="string")} |
| |
| with pytest.raises(AirflowException): |
| DAG("dummy-dag", params=params) |
| |
| def test_dag_invalid_default_view(self): |
| """ |
| Test invalid `default_view` of DAG initialization |
| """ |
| with pytest.raises(AirflowException, match="Invalid values of dag.default_view: only support"): |
| DAG(dag_id="test-invalid-default_view", default_view="airflow") |
| |
| def test_dag_default_view_default_value(self): |
| """ |
| Test `default_view` default value of DAG initialization |
| """ |
| dag = DAG(dag_id="test-default_default_view") |
| assert conf.get("webserver", "dag_default_view").lower() == dag.default_view |
| |
| def test_dag_invalid_orientation(self): |
| """ |
| Test invalid `orientation` of DAG initialization |
| """ |
| with pytest.raises(AirflowException, match="Invalid values of dag.orientation: only support"): |
| DAG(dag_id="test-invalid-orientation", orientation="airflow") |
| |
| def test_dag_orientation_default_value(self): |
| """ |
| Test `orientation` default value of DAG initialization |
| """ |
| dag = DAG(dag_id="test-default_orientation") |
| assert conf.get("webserver", "dag_orientation") == dag.orientation |
| |
| def test_dag_as_context_manager(self): |
| """ |
| Test DAG as a context manager. |
| When used as a context manager, Operators are automatically added to |
| the DAG (unless they specify a different DAG) |
| """ |
| dag = DAG("dag", start_date=DEFAULT_DATE, default_args={"owner": "owner1"}) |
| dag2 = DAG("dag2", start_date=DEFAULT_DATE, default_args={"owner": "owner2"}) |
| |
| with dag: |
| op1 = EmptyOperator(task_id="op1") |
| op2 = EmptyOperator(task_id="op2", dag=dag2) |
| |
| assert op1.dag is dag |
| assert op1.owner == "owner1" |
| assert op2.dag is dag2 |
| assert op2.owner == "owner2" |
| |
| with dag2: |
| op3 = EmptyOperator(task_id="op3") |
| |
| assert op3.dag is dag2 |
| assert op3.owner == "owner2" |
| |
| with dag: |
| with dag2: |
| op4 = EmptyOperator(task_id="op4") |
| op5 = EmptyOperator(task_id="op5") |
| |
| assert op4.dag is dag2 |
| assert op5.dag is dag |
| assert op4.owner == "owner2" |
| assert op5.owner == "owner1" |
| |
| with DAG("creating_dag_in_cm", start_date=DEFAULT_DATE) as dag: |
| EmptyOperator(task_id="op6") |
| |
| assert dag.dag_id == "creating_dag_in_cm" |
| assert dag.tasks[0].task_id == "op6" |
| |
| with dag: |
| with dag: |
| op7 = EmptyOperator(task_id="op7") |
| op8 = EmptyOperator(task_id="op8") |
| op9 = EmptyOperator(task_id="op8") |
| op9.dag = dag2 |
| |
| assert op7.dag == dag |
| assert op8.dag == dag |
| assert op9.dag == dag2 |
| |
| def test_dag_topological_sort_include_subdag_tasks(self): |
| child_dag = DAG( |
| "parent_dag.child_dag", |
| schedule="@daily", |
| start_date=DEFAULT_DATE, |
| ) |
| |
| with child_dag: |
| EmptyOperator(task_id="a_child") |
| EmptyOperator(task_id="b_child") |
| |
| parent_dag = DAG( |
| "parent_dag", |
| schedule="@daily", |
| start_date=DEFAULT_DATE, |
| ) |
| |
| # a_parent -> child_dag -> (a_child | b_child) -> b_parent |
| with parent_dag: |
| op1 = EmptyOperator(task_id="a_parent") |
| op2 = SubDagOperator(task_id="child_dag", subdag=child_dag) |
| op3 = EmptyOperator(task_id="b_parent") |
| |
| op1 >> op2 >> op3 |
| |
| topological_list = parent_dag.topological_sort(include_subdag_tasks=True) |
| |
| assert self._occur_before("a_parent", "child_dag", topological_list) |
| assert self._occur_before("child_dag", "a_child", topological_list) |
| assert self._occur_before("child_dag", "b_child", topological_list) |
| assert self._occur_before("a_child", "b_parent", topological_list) |
| assert self._occur_before("b_child", "b_parent", topological_list) |
| |
| def test_dag_topological_sort_dag_without_tasks(self): |
| dag = DAG("dag", start_date=DEFAULT_DATE, default_args={"owner": "owner1"}) |
| |
| assert () == dag.topological_sort() |
| |
| def test_dag_naive_start_date_string(self): |
| DAG("DAG", default_args={"start_date": "2019-06-01"}) |
| |
| def test_dag_naive_start_end_dates_strings(self): |
| DAG("DAG", default_args={"start_date": "2019-06-01", "end_date": "2019-06-05"}) |
| |
| def test_dag_start_date_propagates_to_end_date(self): |
| """ |
| Tests that a start_date string with a timezone and an end_date string without a timezone |
| are accepted and that the timezone from the start carries over the end |
| |
| This test is a little indirect, it works by setting start and end equal except for the |
| timezone and then testing for equality after the DAG construction. They'll be equal |
| only if the same timezone was applied to both. |
| |
| An explicit check the `tzinfo` attributes for both are the same is an extra check. |
| """ |
| dag = DAG( |
| "DAG", default_args={"start_date": "2019-06-05T00:00:00+05:00", "end_date": "2019-06-05T00:00:00"} |
| ) |
| assert dag.default_args["start_date"] == dag.default_args["end_date"] |
| assert dag.default_args["start_date"].tzinfo == dag.default_args["end_date"].tzinfo |
| |
| def test_dag_naive_default_args_start_date(self): |
| dag = DAG("DAG", default_args={"start_date": datetime.datetime(2018, 1, 1)}) |
| assert dag.timezone == settings.TIMEZONE |
| dag = DAG("DAG", start_date=datetime.datetime(2018, 1, 1)) |
| assert dag.timezone == settings.TIMEZONE |
| |
| def test_dag_none_default_args_start_date(self): |
| """ |
| Tests if a start_date of None in default_args |
| works. |
| """ |
| dag = DAG("DAG", default_args={"start_date": None}) |
| assert dag.timezone == settings.TIMEZONE |
| |
| def test_dag_task_priority_weight_total(self): |
| width = 5 |
| depth = 5 |
| weight = 5 |
| pattern = re.compile("stage(\\d*).(\\d*)") |
| # Fully connected parallel tasks. i.e. every task at each parallel |
| # stage is dependent on every task in the previous stage. |
| # Default weight should be calculated using downstream descendants |
| with DAG("dag", start_date=DEFAULT_DATE, default_args={"owner": "owner1"}) as dag: |
| pipeline = [ |
| [EmptyOperator(task_id=f"stage{i}.{j}", priority_weight=weight) for j in range(width)] |
| for i in range(depth) |
| ] |
| for upstream, downstream in zip(pipeline, pipeline[1:]): |
| for up_task, down_task in itertools.product(upstream, downstream): |
| down_task.set_upstream(up_task) |
| |
| for task in dag.task_dict.values(): |
| match = pattern.match(task.task_id) |
| task_depth = int(match.group(1)) |
| # the sum of each stages after this task + itself |
| correct_weight = ((depth - (task_depth + 1)) * width + 1) * weight |
| |
| calculated_weight = task.priority_weight_total |
| assert calculated_weight == correct_weight |
| |
| def test_dag_task_priority_weight_total_using_upstream(self): |
| # Same test as above except use 'upstream' for weight calculation |
| weight = 3 |
| width = 5 |
| depth = 5 |
| pattern = re.compile("stage(\\d*).(\\d*)") |
| with DAG("dag", start_date=DEFAULT_DATE, default_args={"owner": "owner1"}) as dag: |
| pipeline = [ |
| [ |
| EmptyOperator( |
| task_id=f"stage{i}.{j}", |
| priority_weight=weight, |
| weight_rule=WeightRule.UPSTREAM, |
| ) |
| for j in range(width) |
| ] |
| for i in range(depth) |
| ] |
| for upstream, downstream in zip(pipeline, pipeline[1:]): |
| for up_task, down_task in itertools.product(upstream, downstream): |
| down_task.set_upstream(up_task) |
| |
| for task in dag.task_dict.values(): |
| match = pattern.match(task.task_id) |
| task_depth = int(match.group(1)) |
| # the sum of each stages after this task + itself |
| correct_weight = (task_depth * width + 1) * weight |
| |
| calculated_weight = task.priority_weight_total |
| assert calculated_weight == correct_weight |
| |
| def test_dag_task_priority_weight_total_using_absolute(self): |
| # Same test as above except use 'absolute' for weight calculation |
| weight = 10 |
| width = 5 |
| depth = 5 |
| with DAG("dag", start_date=DEFAULT_DATE, default_args={"owner": "owner1"}) as dag: |
| pipeline = [ |
| [ |
| EmptyOperator( |
| task_id=f"stage{i}.{j}", |
| priority_weight=weight, |
| weight_rule=WeightRule.ABSOLUTE, |
| ) |
| for j in range(width) |
| ] |
| for i in range(depth) |
| ] |
| for upstream, downstream in zip(pipeline, pipeline[1:]): |
| for up_task, down_task in itertools.product(upstream, downstream): |
| down_task.set_upstream(up_task) |
| |
| for task in dag.task_dict.values(): |
| # the sum of each stages after this task + itself |
| correct_weight = weight |
| calculated_weight = task.priority_weight_total |
| assert calculated_weight == correct_weight |
| |
| def test_dag_task_invalid_weight_rule(self): |
| # Test if we enter an invalid weight rule |
| with DAG("dag", start_date=DEFAULT_DATE, default_args={"owner": "owner1"}): |
| with pytest.raises(AirflowException): |
| EmptyOperator(task_id="should_fail", weight_rule="no rule") |
| |
| @pytest.mark.parametrize( |
| "cls, expected", |
| [ |
| (StaticTestPriorityWeightStrategy, 99), |
| (FactorPriorityWeightStrategy, 3), |
| ], |
| ) |
| def test_dag_task_custom_weight_strategy(self, cls, expected): |
| with mock_plugin_manager(plugins=[TestPriorityWeightStrategyPlugin]), DAG( |
| "dag", start_date=DEFAULT_DATE, default_args={"owner": "owner1"} |
| ) as dag: |
| task = EmptyOperator( |
| task_id="empty_task", |
| weight_rule=cls(), |
| ) |
| dr = dag.create_dagrun(state=None, run_id="test", execution_date=DEFAULT_DATE) |
| ti = dr.get_task_instance(task.task_id) |
| assert ti.priority_weight == expected |
| |
| def test_dag_task_not_registered_weight_strategy(self): |
| with mock_plugin_manager(plugins=[TestPriorityWeightStrategyPlugin]), DAG( |
| "dag", start_date=DEFAULT_DATE, default_args={"owner": "owner1"} |
| ): |
| with pytest.raises(AirflowException, match="Unknown priority strategy"): |
| EmptyOperator( |
| task_id="empty_task", |
| weight_rule=NotRegisteredPriorityWeightStrategy(), |
| ) |
| |
| def test_get_num_task_instances(self): |
| test_dag_id = "test_get_num_task_instances_dag" |
| test_task_id = "task_1" |
| |
| test_dag = DAG(dag_id=test_dag_id, start_date=DEFAULT_DATE) |
| test_task = EmptyOperator(task_id=test_task_id, dag=test_dag) |
| |
| dr1 = test_dag.create_dagrun(state=None, run_id="test1", execution_date=DEFAULT_DATE) |
| dr2 = test_dag.create_dagrun( |
| state=None, run_id="test2", execution_date=DEFAULT_DATE + datetime.timedelta(days=1) |
| ) |
| dr3 = test_dag.create_dagrun( |
| state=None, run_id="test3", execution_date=DEFAULT_DATE + datetime.timedelta(days=2) |
| ) |
| dr4 = test_dag.create_dagrun( |
| state=None, run_id="test4", execution_date=DEFAULT_DATE + datetime.timedelta(days=3) |
| ) |
| |
| ti1 = TI(task=test_task, run_id=dr1.run_id) |
| ti1.state = None |
| ti2 = TI(task=test_task, run_id=dr2.run_id) |
| ti2.state = State.RUNNING |
| ti3 = TI(task=test_task, run_id=dr3.run_id) |
| ti3.state = State.QUEUED |
| ti4 = TI(task=test_task, run_id=dr4.run_id) |
| ti4.state = State.RUNNING |
| session = settings.Session() |
| session.merge(ti1) |
| session.merge(ti2) |
| session.merge(ti3) |
| session.merge(ti4) |
| session.commit() |
| |
| assert 0 == DAG.get_num_task_instances(test_dag_id, task_ids=["fakename"], session=session) |
| assert 4 == DAG.get_num_task_instances(test_dag_id, task_ids=[test_task_id], session=session) |
| assert 4 == DAG.get_num_task_instances( |
| test_dag_id, task_ids=["fakename", test_task_id], session=session |
| ) |
| assert 1 == DAG.get_num_task_instances( |
| test_dag_id, task_ids=[test_task_id], states=[None], session=session |
| ) |
| assert 2 == DAG.get_num_task_instances( |
| test_dag_id, task_ids=[test_task_id], states=[State.RUNNING], session=session |
| ) |
| assert 3 == DAG.get_num_task_instances( |
| test_dag_id, task_ids=[test_task_id], states=[None, State.RUNNING], session=session |
| ) |
| assert 4 == DAG.get_num_task_instances( |
| test_dag_id, task_ids=[test_task_id], states=[None, State.QUEUED, State.RUNNING], session=session |
| ) |
| session.close() |
| |
| def test_get_task_instances_before(self): |
| BASE_DATE = timezone.datetime(2022, 7, 20, 20) |
| |
| test_dag_id = "test_get_task_instances_before" |
| test_task_id = "the_task" |
| |
| test_dag = DAG(dag_id=test_dag_id, start_date=BASE_DATE) |
| EmptyOperator(task_id=test_task_id, dag=test_dag) |
| |
| session = settings.Session() |
| |
| def dag_run_before(delta_h=0, type=DagRunType.SCHEDULED): |
| dagrun = test_dag.create_dagrun( |
| state=State.SUCCESS, run_type=type, run_id=f"test_{delta_h}", session=session |
| ) |
| dagrun.start_date = BASE_DATE + timedelta(hours=delta_h) |
| dagrun.execution_date = BASE_DATE + timedelta(hours=delta_h) |
| return dagrun |
| |
| dr1 = dag_run_before(delta_h=-1, type=DagRunType.MANUAL) # H19 |
| dr2 = dag_run_before(delta_h=-2, type=DagRunType.MANUAL) # H18 |
| dr3 = dag_run_before(delta_h=-3, type=DagRunType.MANUAL) # H17 |
| dr4 = dag_run_before(delta_h=-4, type=DagRunType.MANUAL) # H16 |
| dr5 = dag_run_before(delta_h=-5) # H15 |
| dr6 = dag_run_before(delta_h=-6) # H14 |
| dr7 = dag_run_before(delta_h=-7) # H13 |
| dr8 = dag_run_before(delta_h=-8) # H12 |
| |
| session.commit() |
| |
| REF_DATE = BASE_DATE |
| |
| assert set([dr.run_id for dr in [dr1]]) == set( |
| [ |
| ti.run_id |
| for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=1, session=session) |
| ] |
| ) |
| assert set([dr.run_id for dr in [dr1, dr2, dr3]]) == set( |
| [ |
| ti.run_id |
| for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=3, session=session) |
| ] |
| ) |
| assert set([dr.run_id for dr in [dr1, dr2, dr3, dr4, dr5]]) == set( |
| [ |
| ti.run_id |
| for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=5, session=session) |
| ] |
| ) |
| assert set([dr.run_id for dr in [dr1, dr2, dr3, dr4, dr5, dr6, dr7]]) == set( |
| [ |
| ti.run_id |
| for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=7, session=session) |
| ] |
| ) |
| assert set([dr.run_id for dr in [dr1, dr2, dr3, dr4, dr5, dr6, dr7, dr8]]) == set( |
| [ |
| ti.run_id |
| for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=9, session=session) |
| ] |
| ) |
| assert set([dr.run_id for dr in [dr1, dr2, dr3, dr4, dr5, dr6, dr7, dr8]]) == set( |
| [ |
| ti.run_id |
| for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=10, session=session) |
| ] |
| ) # stays constrained to available ones |
| |
| REF_DATE = BASE_DATE + timedelta(hours=-3.5) |
| |
| assert set([dr.run_id for dr in [dr4]]) == set( |
| [ |
| ti.run_id |
| for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=1, session=session) |
| ] |
| ) |
| assert set([dr.run_id for dr in [dr4, dr5, dr6]]) == set( |
| [ |
| ti.run_id |
| for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=3, session=session) |
| ] |
| ) |
| assert set([dr.run_id for dr in [dr4, dr5, dr6, dr7, dr8]]) == set( |
| [ |
| ti.run_id |
| for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=5, session=session) |
| ] |
| ) |
| assert set([dr.run_id for dr in [dr4, dr5, dr6, dr7, dr8]]) == set( |
| [ |
| ti.run_id |
| for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=6, session=session) |
| ] |
| ) # stays constrained to available ones |
| |
| REF_DATE = BASE_DATE + timedelta(hours=-8) |
| |
| assert set([dr.run_id for dr in [dr8]]) == set( |
| [ |
| ti.run_id |
| for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=0, session=session) |
| ] |
| ) |
| assert set([dr.run_id for dr in [dr8]]) == set( |
| [ |
| ti.run_id |
| for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=1, session=session) |
| ] |
| ) |
| assert set([dr.run_id for dr in [dr8]]) == set( |
| [ |
| ti.run_id |
| for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=10, session=session) |
| ] |
| ) |
| |
| session.close() |
| |
| def test_user_defined_filters_macros(self): |
| def jinja_udf(name): |
| return f"Hello {name}" |
| |
| dag = DAG( |
| "test-dag", |
| start_date=DEFAULT_DATE, |
| user_defined_filters={"hello": jinja_udf}, |
| user_defined_macros={"foo": "bar"}, |
| ) |
| jinja_env = dag.get_template_env() |
| |
| assert "hello" in jinja_env.filters |
| assert jinja_env.filters["hello"] == jinja_udf |
| assert jinja_env.globals["foo"] == "bar" |
| |
| def test_set_jinja_env_additional_option(self): |
| dag = DAG("test-dag", jinja_environment_kwargs={"keep_trailing_newline": True, "cache_size": 50}) |
| jinja_env = dag.get_template_env() |
| assert jinja_env.keep_trailing_newline is True |
| assert jinja_env.cache.capacity == 50 |
| |
| assert jinja_env.undefined is jinja2.StrictUndefined |
| |
| def test_template_undefined(self): |
| dag = DAG("test-dag", template_undefined=jinja2.Undefined) |
| jinja_env = dag.get_template_env() |
| assert jinja_env.undefined is jinja2.Undefined |
| |
| @pytest.mark.parametrize( |
| "use_native_obj, force_sandboxed, expected_env", |
| [ |
| (False, True, SandboxedEnvironment), |
| (False, False, SandboxedEnvironment), |
| (True, False, NativeEnvironment), |
| (True, True, SandboxedEnvironment), |
| ], |
| ) |
| def test_template_env(self, use_native_obj, force_sandboxed, expected_env): |
| dag = DAG("test-dag", render_template_as_native_obj=use_native_obj) |
| jinja_env = dag.get_template_env(force_sandboxed=force_sandboxed) |
| assert isinstance(jinja_env, expected_env) |
| |
| def test_resolve_template_files_value(self, tmp_path): |
| path = tmp_path / "testfile.template" |
| path.write_text("{{ ds }}") |
| |
| with DAG("test-dag", start_date=DEFAULT_DATE, template_searchpath=os.fspath(path.parent)): |
| task = EmptyOperator(task_id="op1") |
| |
| task.test_field = path.name |
| task.template_fields = ("test_field",) |
| task.template_ext = (".template",) |
| task.resolve_template_files() |
| |
| assert task.test_field == "{{ ds }}" |
| |
| def test_resolve_template_files_list(self, tmp_path): |
| path = tmp_path / "testfile.template" |
| path.write_text("{{ ds }}") |
| |
| with DAG("test-dag", start_date=DEFAULT_DATE, template_searchpath=os.fspath(path.parent)): |
| task = EmptyOperator(task_id="op1") |
| |
| task.test_field = [path.name, "some_string"] |
| task.template_fields = ("test_field",) |
| task.template_ext = (".template",) |
| task.resolve_template_files() |
| |
| assert task.test_field == ["{{ ds }}", "some_string"] |
| |
| def test_following_previous_schedule(self): |
| """ |
| Make sure DST transitions are properly observed |
| """ |
| local_tz = Timezone("Europe/Zurich") |
| start = local_tz.convert(datetime.datetime(2018, 10, 28, 2, 55, fold=0)) |
| assert start.isoformat() == "2018-10-28T02:55:00+02:00", "Pre-condition: start date is in DST" |
| |
| utc = timezone.convert_to_utc(start) |
| assert utc.isoformat() == "2018-10-28T00:55:00+00:00", "Pre-condition: correct DST->UTC conversion" |
| |
| dag = DAG("tz_dag", start_date=start, schedule="*/5 * * * *") |
| _next = dag.following_schedule(utc) |
| next_local = local_tz.convert(_next) |
| |
| assert _next.isoformat() == "2018-10-28T01:00:00+00:00" |
| assert next_local.isoformat() == "2018-10-28T02:00:00+01:00" |
| |
| prev = dag.previous_schedule(utc) |
| prev_local = local_tz.convert(prev) |
| |
| assert prev_local.isoformat() == "2018-10-28T02:50:00+02:00" |
| |
| prev = dag.previous_schedule(_next) |
| prev_local = local_tz.convert(prev) |
| |
| assert prev_local.isoformat() == "2018-10-28T02:55:00+02:00" |
| assert prev == utc |
| |
| def test_following_previous_schedule_daily_dag_cest_to_cet(self): |
| """ |
| Make sure DST transitions are properly observed |
| """ |
| local_tz = pendulum.timezone("Europe/Zurich") |
| start = local_tz.convert(datetime.datetime(2018, 10, 27, 3, fold=0)) |
| |
| utc = timezone.convert_to_utc(start) |
| |
| dag = DAG("tz_dag", start_date=start, schedule="0 3 * * *") |
| |
| prev = dag.previous_schedule(utc) |
| prev_local = local_tz.convert(prev) |
| |
| assert prev_local.isoformat() == "2018-10-26T03:00:00+02:00" |
| assert prev.isoformat() == "2018-10-26T01:00:00+00:00" |
| |
| _next = dag.following_schedule(utc) |
| next_local = local_tz.convert(_next) |
| |
| assert next_local.isoformat() == "2018-10-28T03:00:00+01:00" |
| assert _next.isoformat() == "2018-10-28T02:00:00+00:00" |
| |
| prev = dag.previous_schedule(_next) |
| prev_local = local_tz.convert(prev) |
| |
| assert prev_local.isoformat() == "2018-10-27T03:00:00+02:00" |
| assert prev.isoformat() == "2018-10-27T01:00:00+00:00" |
| |
| def test_following_previous_schedule_daily_dag_cet_to_cest(self): |
| """ |
| Make sure DST transitions are properly observed |
| """ |
| local_tz = pendulum.timezone("Europe/Zurich") |
| start = local_tz.convert(datetime.datetime(2018, 3, 25, 2, fold=0)) |
| |
| utc = timezone.convert_to_utc(start) |
| |
| dag = DAG("tz_dag", start_date=start, schedule="0 3 * * *") |
| |
| prev = dag.previous_schedule(utc) |
| prev_local = local_tz.convert(prev) |
| |
| assert prev_local.isoformat() == "2018-03-24T03:00:00+01:00" |
| assert prev.isoformat() == "2018-03-24T02:00:00+00:00" |
| |
| _next = dag.following_schedule(utc) |
| next_local = local_tz.convert(_next) |
| |
| assert next_local.isoformat() == "2018-03-25T03:00:00+02:00" |
| assert _next.isoformat() == "2018-03-25T01:00:00+00:00" |
| |
| prev = dag.previous_schedule(_next) |
| prev_local = local_tz.convert(prev) |
| |
| assert prev_local.isoformat() == "2018-03-24T03:00:00+01:00" |
| assert prev.isoformat() == "2018-03-24T02:00:00+00:00" |
| |
| def test_following_schedule_relativedelta(self): |
| """ |
| Tests following_schedule a dag with a relativedelta schedule |
| """ |
| dag_id = "test_schedule_dag_relativedelta" |
| delta = relativedelta(hours=+1) |
| dag = DAG(dag_id=dag_id, schedule=delta, start_date=TEST_DATE) |
| dag.add_task(BaseOperator(task_id="faketastic", owner="Also fake", start_date=TEST_DATE)) |
| |
| _next = dag.following_schedule(TEST_DATE) |
| assert _next.isoformat() == "2015-01-02T01:00:00+00:00" |
| |
| _next = dag.following_schedule(_next) |
| assert _next.isoformat() == "2015-01-02T02:00:00+00:00" |
| |
| def test_following_schedule_relativedelta_with_deprecated_schedule_interval(self): |
| """ |
| Tests following_schedule a dag with a relativedelta schedule_interval |
| """ |
| dag_id = "test_schedule_dag_relativedelta" |
| delta = relativedelta(hours=+1) |
| dag = DAG(dag_id=dag_id, schedule_interval=delta, start_date=TEST_DATE) |
| dag.add_task(BaseOperator(task_id="faketastic", owner="Also fake", start_date=TEST_DATE)) |
| |
| _next = dag.following_schedule(TEST_DATE) |
| assert _next.isoformat() == "2015-01-02T01:00:00+00:00" |
| |
| _next = dag.following_schedule(_next) |
| assert _next.isoformat() == "2015-01-02T02:00:00+00:00" |
| |
| def test_following_schedule_relativedelta_with_depr_schedule_interval_decorated_dag(self): |
| """ |
| Tests following_schedule a dag with a relativedelta schedule_interval |
| using decorated dag |
| """ |
| from airflow.decorators import dag |
| |
| dag_id = "test_schedule_dag_relativedelta" |
| delta = relativedelta(hours=+1) |
| |
| @dag(dag_id=dag_id, schedule_interval=delta, start_date=TEST_DATE) |
| def mydag(): |
| BaseOperator(task_id="faketastic", owner="Also fake", start_date=TEST_DATE) |
| |
| _dag = mydag() |
| |
| _next = _dag.following_schedule(TEST_DATE) |
| assert _next.isoformat() == "2015-01-02T01:00:00+00:00" |
| |
| _next = _dag.following_schedule(_next) |
| assert _next.isoformat() == "2015-01-02T02:00:00+00:00" |
| |
| def test_previous_schedule_datetime_timezone(self): |
| # Check that we don't get an AttributeError 'name' for self.timezone |
| |
| start = datetime.datetime(2018, 3, 25, 2, tzinfo=datetime.timezone.utc) |
| dag = DAG("tz_dag", start_date=start, schedule="@hourly") |
| when = dag.previous_schedule(start) |
| assert when.isoformat() == "2018-03-25T01:00:00+00:00" |
| |
| def test_following_schedule_datetime_timezone(self): |
| # Check that we don't get an AttributeError 'name' for self.timezone |
| |
| start = datetime.datetime(2018, 3, 25, 2, tzinfo=datetime.timezone.utc) |
| dag = DAG("tz_dag", start_date=start, schedule="@hourly") |
| when = dag.following_schedule(start) |
| assert when.isoformat() == "2018-03-25T03:00:00+00:00" |
| |
| def test_create_dagrun_when_schedule_is_none_and_empty_start_date(self): |
| # Check that we don't get an AttributeError 'start_date' for self.start_date when schedule is none |
| dag = DAG("dag_with_none_schedule_and_empty_start_date") |
| dag.add_task(BaseOperator(task_id="task_without_start_date")) |
| dagrun = dag.create_dagrun( |
| state=State.RUNNING, run_type=DagRunType.MANUAL, execution_date=DEFAULT_DATE |
| ) |
| assert dagrun is not None |
| |
| def test_fail_dag_when_schedule_is_non_none_and_empty_start_date(self): |
| # Check that we get a ValueError 'start_date' for self.start_date when schedule is non-none |
| with pytest.raises(ValueError, match="DAG is missing the start_date parameter"): |
| DAG(dag_id="dag_with_non_none_schedule_and_empty_start_date", schedule="@hourly") |
| |
| def test_following_schedule_datetime_timezone_utc0530(self): |
| # Check that we don't get an AttributeError 'name' for self.timezone |
| class UTC0530(datetime.tzinfo): |
| """tzinfo derived concrete class named "+0530" with offset of 19800""" |
| |
| # can be configured here |
| _offset = datetime.timedelta(seconds=19800) |
| _dst = datetime.timedelta(0) |
| _name = "+0530" |
| |
| def utcoffset(self, dt): |
| return self.__class__._offset |
| |
| def dst(self, dt): |
| return self.__class__._dst |
| |
| def tzname(self, dt): |
| return self.__class__._name |
| |
| start = datetime.datetime(2018, 3, 25, 10, tzinfo=UTC0530()) |
| dag = DAG("tz_dag", start_date=start, schedule="@hourly") |
| when = dag.following_schedule(start) |
| assert when.isoformat() == "2018-03-25T05:30:00+00:00" |
| |
| def test_dagtag_repr(self): |
| clear_db_dags() |
| dag = DAG("dag-test-dagtag", start_date=DEFAULT_DATE, tags=["tag-1", "tag-2"]) |
| dag.sync_to_db() |
| with create_session() as session: |
| assert {"tag-1", "tag-2"} == { |
| repr(t) for t in session.query(DagTag).filter(DagTag.dag_id == "dag-test-dagtag").all() |
| } |
| |
    def test_bulk_write_to_db(self):
        """Exercise DAG.bulk_write_to_db through an initial write, re-syncs and tag changes.

        Each phase pins the number of queries issued via ``assert_queries_count``. It also
        verifies the stored DagModel ids and DagTag rows, and that ``last_parsed_time``
        is populated.
        """
        clear_db_dags()
        dags = [DAG(f"dag-bulk-sync-{i}", start_date=DEFAULT_DATE, tags=["test-dag"]) for i in range(4)]

        # Initial write of four brand-new DAGs.
        with assert_queries_count(5):
            DAG.bulk_write_to_db(dags)
        with create_session() as session:
            assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
                row[0] for row in session.query(DagModel.dag_id).all()
            }
            assert {
                ("dag-bulk-sync-0", "test-dag"),
                ("dag-bulk-sync-1", "test-dag"),
                ("dag-bulk-sync-2", "test-dag"),
                ("dag-bulk-sync-3", "test-dag"),
            } == set(session.query(DagTag.dag_id, DagTag.name).all())

            for row in session.query(DagModel.last_parsed_time).all():
                assert row[0] is not None

        # Re-syncing DAGs that already exist is pinned at 8 queries. That is more than
        # the initial 5, not fewer, and the count is stable across repeated re-syncs.
        with assert_queries_count(8):
            DAG.bulk_write_to_db(dags)
        with assert_queries_count(8):
            DAG.bulk_write_to_db(dags)
        # Adding a second tag to every DAG costs one extra query (9).
        for dag in dags:
            dag.tags.append("test-dag2")
        with assert_queries_count(9):
            DAG.bulk_write_to_db(dags)
        with create_session() as session:
            assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
                row[0] for row in session.query(DagModel.dag_id).all()
            }
            assert {
                ("dag-bulk-sync-0", "test-dag"),
                ("dag-bulk-sync-0", "test-dag2"),
                ("dag-bulk-sync-1", "test-dag"),
                ("dag-bulk-sync-1", "test-dag2"),
                ("dag-bulk-sync-2", "test-dag"),
                ("dag-bulk-sync-2", "test-dag2"),
                ("dag-bulk-sync-3", "test-dag"),
                ("dag-bulk-sync-3", "test-dag2"),
            } == set(session.query(DagTag.dag_id, DagTag.name).all())
        # Removing the original tag must delete its DagTag rows.
        for dag in dags:
            dag.tags.remove("test-dag")
        with assert_queries_count(9):
            DAG.bulk_write_to_db(dags)
        with create_session() as session:
            assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
                row[0] for row in session.query(DagModel.dag_id).all()
            }
            assert {
                ("dag-bulk-sync-0", "test-dag2"),
                ("dag-bulk-sync-1", "test-dag2"),
                ("dag-bulk-sync-2", "test-dag2"),
                ("dag-bulk-sync-3", "test-dag2"),
            } == set(session.query(DagTag.dag_id, DagTag.name).all())

            for row in session.query(DagModel.last_parsed_time).all():
                assert row[0] is not None

        # Setting tags to None must remove every DagTag row while keeping the DAGs.
        for dag in dags:
            dag.tags = None
        with assert_queries_count(9):
            DAG.bulk_write_to_db(dags)
        with create_session() as session:
            assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
                row[0] for row in session.query(DagModel.dag_id).all()
            }
            assert not set(session.query(DagTag.dag_id, DagTag.name).all())

            for row in session.query(DagModel.last_parsed_time).all():
                assert row[0] is not None
| |
| def test_bulk_write_to_db_single_dag(self): |
| """ |
| Test bulk_write_to_db for a single dag using the index optimized query |
| """ |
| clear_db_dags() |
| dags = [DAG(f"dag-bulk-sync-{i}", start_date=DEFAULT_DATE, tags=["test-dag"]) for i in range(1)] |
| |
| with assert_queries_count(5): |
| DAG.bulk_write_to_db(dags) |
| with create_session() as session: |
| assert {"dag-bulk-sync-0"} == {row[0] for row in session.query(DagModel.dag_id).all()} |
| assert { |
| ("dag-bulk-sync-0", "test-dag"), |
| } == set(session.query(DagTag.dag_id, DagTag.name).all()) |
| |
| for row in session.query(DagModel.last_parsed_time).all(): |
| assert row[0] is not None |
| |
| # Re-sync should do fewer queries |
| with assert_queries_count(8): |
| DAG.bulk_write_to_db(dags) |
| with assert_queries_count(8): |
| DAG.bulk_write_to_db(dags) |
| |
| def test_bulk_write_to_db_multiple_dags(self): |
| """ |
| Test bulk_write_to_db for multiple dags which does not use the index optimized query |
| """ |
| clear_db_dags() |
| dags = [DAG(f"dag-bulk-sync-{i}", start_date=DEFAULT_DATE, tags=["test-dag"]) for i in range(4)] |
| |
| with assert_queries_count(5): |
| DAG.bulk_write_to_db(dags) |
| with create_session() as session: |
| assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == { |
| row[0] for row in session.query(DagModel.dag_id).all() |
| } |
| assert { |
| ("dag-bulk-sync-0", "test-dag"), |
| ("dag-bulk-sync-1", "test-dag"), |
| ("dag-bulk-sync-2", "test-dag"), |
| ("dag-bulk-sync-3", "test-dag"), |
| } == set(session.query(DagTag.dag_id, DagTag.name).all()) |
| |
| for row in session.query(DagModel.last_parsed_time).all(): |
| assert row[0] is not None |
| |
| # Re-sync should do fewer queries |
| with assert_queries_count(8): |
| DAG.bulk_write_to_db(dags) |
| with assert_queries_count(8): |
| DAG.bulk_write_to_db(dags) |
| |
| @pytest.mark.parametrize("interval", [None, "@daily"]) |
| def test_bulk_write_to_db_interval_save_runtime(self, interval): |
| mock_active_runs_of_dags = mock.MagicMock(side_effect=DagRun.active_runs_of_dags) |
| with mock.patch.object(DagRun, "active_runs_of_dags", mock_active_runs_of_dags): |
| dags_null_timetable = [ |
| DAG("dag-interval-None", schedule_interval=None, start_date=TEST_DATE), |
| DAG("dag-interval-test", schedule_interval=interval, start_date=TEST_DATE), |
| ] |
| DAG.bulk_write_to_db(dags_null_timetable, session=settings.Session()) |
| if interval: |
| mock_active_runs_of_dags.assert_called_once() |
| else: |
| mock_active_runs_of_dags.assert_not_called() |
| |
    @pytest.mark.parametrize("state", [DagRunState.RUNNING, DagRunState.QUEUED])
    def test_bulk_write_to_db_max_active_runs(self, state):
        """
        Test that DagModel.next_dagrun_create_after is set to NULL when the dag cannot be created due to max
        active runs being hit.

        Parametrized so that a single run in either the RUNNING or the QUEUED state is
        expected to exhaust ``max_active_runs=1``.
        """
        dag = DAG(dag_id="test_scheduler_verify_max_active_runs", start_date=DEFAULT_DATE)
        # One active run is enough to reach the limit.
        dag.max_active_runs = 1

        EmptyOperator(task_id="dummy", dag=dag, owner="airflow")

        session = settings.Session()
        dag.clear()
        DAG.bulk_write_to_db([dag], session=session)

        model = session.get(DagModel, dag.dag_id)

        # With no runs yet, the next run is at the start date and becomes creatable one day later.
        assert model.next_dagrun == DEFAULT_DATE
        assert model.next_dagrun_create_after == DEFAULT_DATE + timedelta(days=1)

        # Occupy the only active-run slot with a run at next_dagrun.
        dr = dag.create_dagrun(
            state=state,
            execution_date=model.next_dagrun,
            run_type=DagRunType.SCHEDULED,
            session=session,
        )
        assert dr is not None
        # NOTE(review): written without passing `session`; the result is re-read through `session` below.
        DAG.bulk_write_to_db([dag])

        model = session.get(DagModel, dag.dag_id)
        # We signal "at max active runs" by saying this run is never eligible to be created
        assert model.next_dagrun_create_after is None
        # test that bulk_write_to_db again doesn't update next_dagrun_create_after
        DAG.bulk_write_to_db([dag])
        model = session.get(DagModel, dag.dag_id)
        assert model.next_dagrun_create_after is None
| |
    def test_bulk_write_to_db_has_import_error(self):
        """
        Test that DagModel.has_import_errors is False after a clean write, and that a
        subsequent bulk_write_to_db resets a previously set flag back to False.
        """
        dag = DAG(dag_id="test_has_import_error", start_date=DEFAULT_DATE)

        EmptyOperator(task_id="dummy", dag=dag, owner="airflow")

        session = settings.Session()
        dag.clear()
        DAG.bulk_write_to_db([dag], session=session)

        model = session.get(DagModel, dag.dag_id)

        assert not model.has_import_errors

        # Simulate the DAG file processor flagging an import error on the stored model.
        model.has_import_errors = True
        session.merge(model)
        session.flush()
        model = session.get(DagModel, dag.dag_id)
        # The flag must now be visible on the stored model.
        assert model.has_import_errors
        # Re-parse: write the (error-free) DAG again.
        DAG.bulk_write_to_db([dag])

        model = session.get(DagModel, dag.dag_id)
        # assert that has_import_errors is now false
        assert not model.has_import_errors
        session.close()
| |
    def test_bulk_write_to_db_datasets(self):
        """
        Ensure that datasets referenced in a dag are correctly loaded into the database.

        Covers consumer references (``schedule=[...]``), producer references (task
        ``outlets``), and the removal of references that disappear on a later write.
        """
        dag_id1 = "test_dataset_dag1"
        dag_id2 = "test_dataset_dag2"
        task_id = "test_dataset_task"
        uri1 = "s3://dataset/1"
        # The same URI appears twice with different extras. Per the assertion below, the extra
        # from dag2's task outlet is the one that ends up stored.
        d1 = Dataset(uri1, extra={"not": "used"})
        d2 = Dataset("s3://dataset/2")
        d3 = Dataset("s3://dataset/3")
        # dag1 consumes d1 and produces d2/d3; dag2 produces uri1.
        dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[d1])
        EmptyOperator(task_id=task_id, dag=dag1, outlets=[d2, d3])
        dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE)
        EmptyOperator(task_id=task_id, dag=dag2, outlets=[Dataset(uri1, extra={"should": "be used"})])
        session = settings.Session()
        dag1.clear()
        DAG.bulk_write_to_db([dag1, dag2], session=session)
        session.commit()
        stored_datasets = {x.uri: x for x in session.query(DatasetModel).all()}
        d1_orm = stored_datasets[d1.uri]
        d2_orm = stored_datasets[d2.uri]
        d3_orm = stored_datasets[d3.uri]
        assert stored_datasets[uri1].extra == {"should": "be used"}
        assert [x.dag_id for x in d1_orm.consuming_dags] == [dag_id1]
        assert [(x.task_id, x.dag_id) for x in d1_orm.producing_tasks] == [(task_id, dag_id2)]
        # Every outlet is recorded as a (task_id, dag_id, dataset_id) reference.
        assert set(
            session.query(
                TaskOutletDatasetReference.task_id,
                TaskOutletDatasetReference.dag_id,
                TaskOutletDatasetReference.dataset_id,
            )
            .filter(TaskOutletDatasetReference.dag_id.in_((dag_id1, dag_id2)))
            .all()
        ) == {
            (task_id, dag_id1, d2_orm.id),
            (task_id, dag_id1, d3_orm.id),
            (task_id, dag_id2, d1_orm.id),
        }

        # now that we have verified that a new dag has its dataset references recorded properly,
        # we need to verify that *changes* are recorded properly.
        # so if any references are *removed*, they should also be deleted from the DB
        # so let's remove some references and see what happens
        # dag1 drops its d1 schedule and the d3 outlet; dag2 drops its only outlet.
        dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=None)
        EmptyOperator(task_id=task_id, dag=dag1, outlets=[d2])
        dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE)
        EmptyOperator(task_id=task_id, dag=dag2)
        DAG.bulk_write_to_db([dag1, dag2], session=session)
        session.commit()
        # Detach all loaded objects so the queries below are answered from the database.
        session.expunge_all()
        stored_datasets = {x.uri: x for x in session.query(DatasetModel).all()}
        d1_orm = stored_datasets[d1.uri]
        d2_orm = stored_datasets[d2.uri]
        assert [x.dag_id for x in d1_orm.consuming_dags] == []
        assert set(
            session.query(
                TaskOutletDatasetReference.task_id,
                TaskOutletDatasetReference.dag_id,
                TaskOutletDatasetReference.dataset_id,
            )
            .filter(TaskOutletDatasetReference.dag_id.in_((dag_id1, dag_id2)))
            .all()
        ) == {(task_id, dag_id1, d2_orm.id)}
| |
    def test_bulk_write_to_db_unorphan_datasets(self):
        """
        Datasets can lose their last reference and be orphaned, but then if a reference to them reappears, we
        need to un-orphan those datasets
        """
        with create_session() as session:
            # Declare four datasets. Only ds2 and ds4 are inserted here, directly as orphaned
            # rows; ds1 and ds3 get their rows when dag1 references them below.
            dataset1 = Dataset(uri="ds1")
            dataset2 = Dataset(uri="ds2")
            session.add(DatasetModel(uri=dataset2.uri, is_orphaned=True))
            dataset3 = Dataset(uri="ds3")
            dataset4 = Dataset(uri="ds4")
            session.add(DatasetModel(uri=dataset4.uri, is_orphaned=True))
            session.flush()

            # dag1 consumes ds1 and produces ds3.
            dag1 = DAG(dag_id="datasets-1", start_date=DEFAULT_DATE, schedule=[dataset1])
            BashOperator(dag=dag1, task_id="task", bash_command="echo 1", outlets=[dataset3])

            DAG.bulk_write_to_db([dag1], session=session)

            # Double check: referenced datasets are not orphaned; the pre-inserted ones still are.
            non_orphaned_datasets = [
                dataset.uri
                for dataset in session.query(DatasetModel.uri)
                .filter(~DatasetModel.is_orphaned)
                .order_by(DatasetModel.uri)
            ]
            assert non_orphaned_datasets == ["ds1", "ds3"]
            orphaned_datasets = [
                dataset.uri
                for dataset in session.query(DatasetModel.uri)
                .filter(DatasetModel.is_orphaned)
                .order_by(DatasetModel.uri)
            ]
            assert orphaned_datasets == ["ds2", "ds4"]

            # Now add references to the two unreferenced datasets
            dag1 = DAG(dag_id="datasets-1", start_date=DEFAULT_DATE, schedule=[dataset1, dataset2])
            BashOperator(dag=dag1, task_id="task", bash_command="echo 1", outlets=[dataset3, dataset4])

            DAG.bulk_write_to_db([dag1], session=session)

            # and count the orphans and non-orphans: all four must now be un-orphaned
            non_orphaned_dataset_count = session.query(DatasetModel).filter(~DatasetModel.is_orphaned).count()
            assert non_orphaned_dataset_count == 4
            orphaned_dataset_count = session.query(DatasetModel).filter(DatasetModel.is_orphaned).count()
            assert orphaned_dataset_count == 0
| |
| def test_sync_to_db(self): |
| dag = DAG( |
| "dag", |
| start_date=DEFAULT_DATE, |
| ) |
| with dag: |
| EmptyOperator(task_id="task", owner="owner1") |
| subdag = DAG( |
| "dag.subtask", |
| start_date=DEFAULT_DATE, |
| ) |
| # parent_dag and is_subdag was set by DagBag. We don't use DagBag, so this value is not set. |
| subdag.parent_dag = dag |
| SubDagOperator(task_id="subtask", owner="owner2", subdag=subdag) |
| session = settings.Session() |
| dag.sync_to_db(session=session) |
| |
| orm_dag = session.query(DagModel).filter(DagModel.dag_id == "dag").one() |
| assert set(orm_dag.owners.split(", ")) == {"owner1", "owner2"} |
| assert orm_dag.is_active |
| assert orm_dag.default_view is not None |
| assert orm_dag.default_view == conf.get("webserver", "dag_default_view").lower() |
| assert orm_dag.safe_dag_id == "dag" |
| |
| orm_subdag = session.query(DagModel).filter(DagModel.dag_id == "dag.subtask").one() |
| assert set(orm_subdag.owners.split(", ")) == {"owner1", "owner2"} |
| assert orm_subdag.is_active |
| assert orm_subdag.safe_dag_id == "dag__dot__subtask" |
| assert orm_subdag.fileloc == orm_dag.fileloc |
| session.close() |
| |
| def test_sync_to_db_default_view(self): |
| dag = DAG( |
| "dag", |
| start_date=DEFAULT_DATE, |
| default_view="graph", |
| ) |
| with dag: |
| EmptyOperator(task_id="task", owner="owner1") |
| SubDagOperator( |
| task_id="subtask", |
| owner="owner2", |
| subdag=DAG( |
| "dag.subtask", |
| start_date=DEFAULT_DATE, |
| ), |
| ) |
| session = settings.Session() |
| dag.sync_to_db(session=session) |
| |
| orm_dag = session.query(DagModel).filter(DagModel.dag_id == "dag").one() |
| assert orm_dag.default_view is not None |
| assert orm_dag.default_view == "graph" |
| session.close() |
| |
| @provide_session |
| def test_is_paused_subdag(self, session): |
| subdag_id = "dag.subdag" |
| subdag = DAG( |
| subdag_id, |
| start_date=DEFAULT_DATE, |
| ) |
| with subdag: |
| EmptyOperator( |
| task_id="dummy_task", |
| ) |
| |
| dag_id = "dag" |
| dag = DAG( |
| dag_id, |
| start_date=DEFAULT_DATE, |
| ) |
| |
| with dag: |
| SubDagOperator(task_id="subdag", subdag=subdag) |
| |
| # parent_dag and is_subdag was set by DagBag. We don't use DagBag, so this value is not set. |
| subdag.parent_dag = dag |
| |
| session.query(DagModel).filter(DagModel.dag_id.in_([subdag_id, dag_id])).delete( |
| synchronize_session=False |
| ) |
| |
| dag.sync_to_db(session=session) |
| |
| unpaused_dags = ( |
| session.query(DagModel.dag_id, DagModel.is_paused) |
| .filter( |
| DagModel.dag_id.in_([subdag_id, dag_id]), |
| ) |
| .all() |
| ) |
| |
| assert { |
| (dag_id, False), |
| (subdag_id, False), |
| } == set(unpaused_dags) |
| |
| DagModel.get_dagmodel(dag.dag_id).set_is_paused(is_paused=True, including_subdags=False) |
| |
| paused_dags = ( |
| session.query(DagModel.dag_id, DagModel.is_paused) |
| .filter( |
| DagModel.dag_id.in_([subdag_id, dag_id]), |
| ) |
| .all() |
| ) |
| |
| assert { |
| (dag_id, True), |
| (subdag_id, False), |
| } == set(paused_dags) |
| |
| DagModel.get_dagmodel(dag.dag_id).set_is_paused(is_paused=True) |
| |
| paused_dags = ( |
| session.query(DagModel.dag_id, DagModel.is_paused) |
| .filter( |
| DagModel.dag_id.in_([subdag_id, dag_id]), |
| ) |
| .all() |
| ) |
| |
| assert { |
| (dag_id, True), |
| (subdag_id, True), |
| } == set(paused_dags) |
| |
| def test_existing_dag_is_paused_upon_creation(self): |
| dag = DAG("dag_paused") |
| dag.sync_to_db() |
| assert not dag.get_is_paused() |
| |
| dag = DAG("dag_paused", is_paused_upon_creation=True) |
| dag.sync_to_db() |
| # Since the dag existed before, it should not follow the pause flag upon creation |
| assert not dag.get_is_paused() |
| |
| def test_new_dag_is_paused_upon_creation(self): |
| dag = DAG("new_nonexisting_dag", is_paused_upon_creation=True) |
| session = settings.Session() |
| dag.sync_to_db(session=session) |
| |
| orm_dag = session.query(DagModel).filter(DagModel.dag_id == "new_nonexisting_dag").one() |
| # Since the dag didn't exist before, it should follow the pause flag upon creation |
| assert orm_dag.is_paused |
| session.close() |
| |
| @mock.patch.dict( |
| os.environ, |
| { |
| "AIRFLOW__CORE__MAX_CONSECUTIVE_FAILED_DAG_RUNS_PER_DAG": "4", |
| }, |
| ) |
| def test_existing_dag_is_paused_config(self): |
| # config should be set properly |
| assert conf.getint("core", "max_consecutive_failed_dag_runs_per_dag") == 4 |
| # checking the default value is coming from config |
| dag = DAG("test_dag") |
| assert dag.max_consecutive_failed_dag_runs == 4 |
| # but we can override the value using params |
| dag = DAG("test_dag2", max_consecutive_failed_dag_runs=2) |
| assert dag.max_consecutive_failed_dag_runs == 2 |
| |
| def test_existing_dag_is_paused_after_limit(self): |
| def add_failed_dag_run(id, execution_date): |
| dr = dag.create_dagrun( |
| run_type=DagRunType.MANUAL, |
| run_id="run_id_" + id, |
| execution_date=execution_date, |
| state=State.FAILED, |
| ) |
| ti_op1 = dr.get_task_instance(task_id=op1.task_id, session=session) |
| ti_op1.set_state(state=TaskInstanceState.FAILED, session=session) |
| dr.update_state(session=session) |
| |
| dag_id = "dag_paused_after_limit" |
| dag = DAG(dag_id, is_paused_upon_creation=False, max_consecutive_failed_dag_runs=2) |
| op1 = BashOperator(task_id="task", bash_command="exit 1;") |
| dag.add_task(op1) |
| session = settings.Session() |
| dag.sync_to_db(session=session) |
| assert not dag.get_is_paused() |
| |
| # dag should be paused after 2 failed dag_runs |
| add_failed_dag_run( |
| "1", |
| TEST_DATE, |
| ) |
| add_failed_dag_run("2", TEST_DATE + timedelta(days=1)) |
| assert dag.get_is_paused() |
| dag.clear() |
| self._clean_up(dag_id) |
| |
| def test_existing_dag_default_view(self): |
| with create_session() as session: |
| session.add(DagModel(dag_id="dag_default_view_old", default_view=None)) |
| session.commit() |
| orm_dag = session.query(DagModel).filter(DagModel.dag_id == "dag_default_view_old").one() |
| assert orm_dag.default_view is None |
| assert orm_dag.get_default_view() == conf.get("webserver", "dag_default_view").lower() |
| |
| def test_dag_is_deactivated_upon_dagfile_deletion(self): |
| dag_id = "old_existing_dag" |
| dag_fileloc = "/usr/local/airflow/dags/non_existing_path.py" |
| dag = DAG( |
| dag_id, |
| is_paused_upon_creation=True, |
| ) |
| dag.fileloc = dag_fileloc |
| session = settings.Session() |
| with mock.patch("airflow.models.dag.DagCode.bulk_sync_to_db"): |
| dag.sync_to_db(session=session, processor_subdir="/usr/local/airflow/dags/") |
| |
| orm_dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one() |
| |
| assert orm_dag.is_active |
| assert orm_dag.fileloc == dag_fileloc |
| |
| DagModel.deactivate_deleted_dags( |
| list_py_file_paths(settings.DAGS_FOLDER), |
| processor_subdir="/usr/local/airflow/dags/", |
| ) |
| |
| orm_dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one() |
| assert not orm_dag.is_active |
| |
| session.execute(DagModel.__table__.delete().where(DagModel.dag_id == dag_id)) |
| session.close() |
| |
| def test_dag_naive_default_args_start_date_with_timezone(self): |
| local_tz = pendulum.timezone("Europe/Zurich") |
| default_args = {"start_date": datetime.datetime(2018, 1, 1, tzinfo=local_tz)} |
| |
| dag = DAG("DAG", default_args=default_args) |
| assert dag.timezone.name == local_tz.name |
| |
| dag = DAG("DAG", default_args=default_args) |
| assert dag.timezone.name == local_tz.name |
| |
| def test_roots(self): |
| """Verify if dag.roots returns the root tasks of a DAG.""" |
| with DAG("test_dag", start_date=DEFAULT_DATE) as dag: |
| op1 = EmptyOperator(task_id="t1") |
| op2 = EmptyOperator(task_id="t2") |
| op3 = EmptyOperator(task_id="t3") |
| op4 = EmptyOperator(task_id="t4") |
| op5 = EmptyOperator(task_id="t5") |
| [op1, op2] >> op3 >> [op4, op5] |
| |
| assert set(dag.roots) == {op1, op2} |
| |
| def test_leaves(self): |
| """Verify if dag.leaves returns the leaf tasks of a DAG.""" |
| with DAG("test_dag", start_date=DEFAULT_DATE) as dag: |
| op1 = EmptyOperator(task_id="t1") |
| op2 = EmptyOperator(task_id="t2") |
| op3 = EmptyOperator(task_id="t3") |
| op4 = EmptyOperator(task_id="t4") |
| op5 = EmptyOperator(task_id="t5") |
| [op1, op2] >> op3 >> [op4, op5] |
| |
| assert set(dag.leaves) == {op4, op5} |
| |
| def test_tree_view(self): |
| """Verify correctness of dag.tree_view().""" |
| with DAG("test_dag", start_date=DEFAULT_DATE) as dag: |
| op1_a = EmptyOperator(task_id="t1_a") |
| op1_b = EmptyOperator(task_id="t1_b") |
| op2 = EmptyOperator(task_id="t2") |
| op3 = EmptyOperator(task_id="t3") |
| op1_b >> op2 |
| op1_a >> op2 >> op3 |
| |
| with redirect_stdout(StringIO()) as stdout: |
| dag.tree_view() |
| stdout = stdout.getvalue() |
| |
| stdout_lines = stdout.splitlines() |
| assert "t1_a" in stdout_lines[0] |
| assert "t2" in stdout_lines[1] |
| assert "t3" in stdout_lines[2] |
| assert "t1_b" in stdout_lines[3] |
| assert dag.get_tree_view() == ( |
| "<Task(EmptyOperator): t1_a>\n" |
| " <Task(EmptyOperator): t2>\n" |
| " <Task(EmptyOperator): t3>\n" |
| "<Task(EmptyOperator): t1_b>\n" |
| " <Task(EmptyOperator): t2>\n" |
| " <Task(EmptyOperator): t3>\n" |
| ) |
| |
| def test_duplicate_task_ids_not_allowed_with_dag_context_manager(self): |
| """Verify tasks with Duplicate task_id raises error""" |
| with DAG("test_dag", start_date=DEFAULT_DATE) as dag: |
| op1 = EmptyOperator(task_id="t1") |
| with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has already been added to the DAG"): |
| BashOperator(task_id="t1", bash_command="sleep 1") |
| |
| assert dag.task_dict == {op1.task_id: op1} |
| |
| def test_duplicate_task_ids_not_allowed_without_dag_context_manager(self): |
| """Verify tasks with Duplicate task_id raises error""" |
| dag = DAG("test_dag", start_date=DEFAULT_DATE) |
| op1 = EmptyOperator(task_id="t1", dag=dag) |
| with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has already been added to the DAG"): |
| EmptyOperator(task_id="t1", dag=dag) |
| |
| assert dag.task_dict == {op1.task_id: op1} |
| |
| def test_duplicate_task_ids_for_same_task_is_allowed(self): |
| """Verify that same tasks with Duplicate task_id do not raise error""" |
| with DAG("test_dag", start_date=DEFAULT_DATE) as dag: |
| op1 = op2 = EmptyOperator(task_id="t1") |
| op3 = EmptyOperator(task_id="t3") |
| op1 >> op3 |
| op2 >> op3 |
| |
| assert op1 == op2 |
| assert dag.task_dict == {op1.task_id: op1, op3.task_id: op3} |
| assert dag.task_dict == {op2.task_id: op2, op3.task_id: op3} |
| |
| def test_partial_subset_updates_all_references_while_deepcopy(self): |
| with DAG("test_dag", start_date=DEFAULT_DATE) as dag: |
| op1 = EmptyOperator(task_id="t1") |
| op2 = EmptyOperator(task_id="t2") |
| op3 = EmptyOperator(task_id="t3") |
| op1 >> op2 |
| op2 >> op3 |
| |
| partial = dag.partial_subset("t2", include_upstream=True, include_downstream=False) |
| assert id(partial.task_dict["t1"].downstream_list[0].dag) == id(partial) |
| |
| # Copied DAG should not include unused task IDs in used_group_ids |
| assert "t3" not in partial.task_group.used_group_ids |
| |
| def test_partial_subset_taskgroup_join_ids(self): |
| with DAG("test_dag", start_date=DEFAULT_DATE) as dag: |
| start = EmptyOperator(task_id="start") |
| with TaskGroup(group_id="outer", prefix_group_id=False) as outer_group: |
| with TaskGroup(group_id="tg1", prefix_group_id=False) as tg1: |
| EmptyOperator(task_id="t1") |
| with TaskGroup(group_id="tg2", prefix_group_id=False) as tg2: |
| EmptyOperator(task_id="t2") |
| |
| start >> tg1 >> tg2 |
| |
| # Pre-condition checks |
| task = dag.get_task("t2") |
| assert task.task_group.upstream_group_ids == {"tg1"} |
| assert isinstance(task.task_group.parent_group, weakref.ProxyType) |
| assert task.task_group.parent_group == outer_group |
| |
| partial = dag.partial_subset(["t2"], include_upstream=True, include_downstream=False) |
| copied_task = partial.get_task("t2") |
| assert copied_task.task_group.upstream_group_ids == {"tg1"} |
| assert isinstance(copied_task.task_group.parent_group, weakref.ProxyType) |
| assert copied_task.task_group.parent_group |
| |
| # Make sure we don't affect the original! |
| assert task.task_group.upstream_group_ids is not copied_task.task_group.upstream_group_ids |
| |
| def test_schedule_dag_no_previous_runs(self): |
| """ |
| Tests scheduling a dag with no previous runs |
| """ |
| dag_id = "test_schedule_dag_no_previous_runs" |
| dag = DAG(dag_id=dag_id) |
| dag.add_task(BaseOperator(task_id="faketastic", owner="Also fake", start_date=TEST_DATE)) |
| |
| dag_run = dag.create_dagrun( |
| run_type=DagRunType.SCHEDULED, |
| execution_date=TEST_DATE, |
| state=State.RUNNING, |
| ) |
| assert dag_run is not None |
| assert dag.dag_id == dag_run.dag_id |
| assert dag_run.run_id is not None |
| assert "" != dag_run.run_id |
| assert ( |
| TEST_DATE == dag_run.execution_date |
| ), f"dag_run.execution_date did not match expectation: {dag_run.execution_date}" |
| assert State.RUNNING == dag_run.state |
| assert not dag_run.external_trigger |
| dag.clear() |
| self._clean_up(dag_id) |
| |
| @patch("airflow.models.dag.Stats") |
| def test_dag_handle_callback_crash(self, mock_stats): |
| """ |
| Tests avoid crashes from calling dag callbacks exceptions |
| """ |
| dag_id = "test_dag_callback_crash" |
| mock_callback_with_exception = mock.MagicMock() |
| mock_callback_with_exception.side_effect = Exception |
| dag = DAG( |
| dag_id=dag_id, |
| # callback with invalid signature should not cause crashes |
| on_success_callback=lambda: 1, |
| on_failure_callback=mock_callback_with_exception, |
| ) |
| when = TEST_DATE |
| dag.add_task(BaseOperator(task_id="faketastic", owner="Also fake", start_date=when)) |
| |
| with create_session() as session: |
| dag_run = dag.create_dagrun(State.RUNNING, when, run_type=DagRunType.MANUAL, session=session) |
| |
| # should not raise any exception |
| dag.handle_callback(dag_run, success=False) |
| dag.handle_callback(dag_run, success=True) |
| |
| mock_stats.incr.assert_called_with( |
| "dag.callback_exceptions", |
| tags={"dag_id": "test_dag_callback_crash"}, |
| ) |
| |
| dag.clear() |
| self._clean_up(dag_id) |
| |
    def test_dag_handle_callback_with_removed_task(self, dag_maker, session):
        """
        Tests avoid crashes when a removed task is the last one in the list of task instance

        handle_callback must not raise, for success or failure, when a task instance of
        the run belongs to a task that has since been removed from the DAG.
        """
        # NOTE(review): the `dag_maker` and `session` fixtures are requested but unused; `session`
        # is shadowed by the `create_session()` block below.
        dag_id = "test_dag_callback_with_removed_task"
        mock_callback = mock.MagicMock()
        with DAG(
            dag_id=dag_id,
            on_success_callback=mock_callback,
            on_failure_callback=mock_callback,
        ) as dag:
            EmptyOperator(task_id="faketastic")
            task_removed = EmptyOperator(task_id="removed_task")

        with create_session() as session:
            dag_run = dag.create_dagrun(State.RUNNING, TEST_DATE, run_type=DagRunType.MANUAL, session=session)
            # Drop the task from the DAG after its task instance has been created.
            dag._remove_task(task_removed.task_id)
            tis = dag_run.get_task_instances(session=session)
            # NOTE(review): assumes get_task_instances returns the removed task's TI last (the
            # scenario in the docstring); the assertion below checks only its state, not its position.
            tis[-1].state = TaskInstanceState.REMOVED
            assert dag_run.get_task_instance(task_removed.task_id).state == TaskInstanceState.REMOVED

            # should not raise any exception
            dag.handle_callback(dag_run, success=True)
            dag.handle_callback(dag_run, success=False)

        dag.clear()
        self._clean_up(dag_id)
| |
| def test_next_dagrun_after_fake_scheduled_previous(self): |
| """ |
| Test scheduling a dag where there is a prior DagRun |
| which has the same run_id as the next run should have |
| """ |
| delta = datetime.timedelta(hours=1) |
| dag_id = "test_schedule_dag_fake_scheduled_previous" |
| dag = DAG(dag_id=dag_id, schedule=delta, start_date=DEFAULT_DATE) |
| dag.add_task(BaseOperator(task_id="faketastic", owner="Also fake", start_date=DEFAULT_DATE)) |
| |
| dag.create_dagrun( |
| run_type=DagRunType.SCHEDULED, |
| execution_date=DEFAULT_DATE, |
| state=State.SUCCESS, |
| external_trigger=True, |
| ) |
| dag.sync_to_db() |
| with create_session() as session: |
| model = session.get(DagModel, dag.dag_id) |
| |
| # Even though there is a run for this date already, it is marked as manual/external, so we should |
| # create a scheduled one anyway! |
| assert model.next_dagrun == DEFAULT_DATE |
| assert model.next_dagrun_create_after == DEFAULT_DATE + delta |
| |
| self._clean_up(dag_id) |
| |
| def test_schedule_dag_once(self): |
| """ |
| Tests scheduling a dag scheduled for @once - should be scheduled the first time |
| it is called, and not scheduled the second. |
| """ |
| dag_id = "test_schedule_dag_once" |
| dag = DAG(dag_id=dag_id, schedule="@once", start_date=TEST_DATE) |
| assert isinstance(dag.timetable, OnceTimetable) |
| dag.add_task(BaseOperator(task_id="faketastic", owner="Also fake", start_date=TEST_DATE)) |
| |
| # Sync once to create the DagModel |
| dag.sync_to_db() |
| |
| dag.create_dagrun(run_type=DagRunType.SCHEDULED, execution_date=TEST_DATE, state=State.SUCCESS) |
| |
| # Then sync again after creating the dag run -- this should update next_dagrun |
| dag.sync_to_db() |
| with create_session() as session: |
| model = session.get(DagModel, dag.dag_id) |
| |
| assert model.next_dagrun is None |
| assert model.next_dagrun_create_after is None |
| self._clean_up(dag_id) |
| |
| def test_fractional_seconds(self): |
| """ |
| Tests if fractional seconds are stored in the database |
| """ |
| dag_id = "test_fractional_seconds" |
| dag = DAG(dag_id=dag_id, schedule="@once", start_date=TEST_DATE) |
| dag.add_task(BaseOperator(task_id="faketastic", owner="Also fake", start_date=TEST_DATE)) |
| |
| start_date = timezone.utcnow() |
| |
| run = dag.create_dagrun( |
| run_id="test_" + start_date.isoformat(), |
| execution_date=start_date, |
| start_date=start_date, |
| state=State.RUNNING, |
| external_trigger=False, |
| ) |
| |
| run.refresh_from_db() |
| |
| assert start_date == run.execution_date, "dag run execution_date loses precision" |
| assert start_date == run.start_date, "dag run start_date loses precision " |
| self._clean_up(dag_id) |
| |
| def test_pickling(self): |
| test_dag_id = "test_pickling" |
| args = {"owner": "airflow", "start_date": DEFAULT_DATE} |
| dag = DAG(test_dag_id, default_args=args) |
| dag_pickle = dag.pickle() |
| assert dag_pickle.pickle.dag_id == dag.dag_id |
| |
| def test_rich_comparison_ops(self): |
| test_dag_id = "test_rich_comparison_ops" |
| |
| class DAGsubclass(DAG): |
| pass |
| |
| args = {"owner": "airflow", "start_date": DEFAULT_DATE} |
| dag = DAG(test_dag_id, default_args=args) |
| |
| dag_eq = DAG(test_dag_id, default_args=args) |
| |
| dag_diff_load_time = DAG(test_dag_id, default_args=args) |
| dag_diff_name = DAG(test_dag_id + "_neq", default_args=args) |
| |
| dag_subclass = DAGsubclass(test_dag_id, default_args=args) |
| dag_subclass_diff_name = DAGsubclass(test_dag_id + "2", default_args=args) |
| |
| for dag_ in [dag_eq, dag_diff_name, dag_subclass, dag_subclass_diff_name]: |
| dag_.last_loaded = dag.last_loaded |
| |
| # test identity equality |
| assert dag == dag |
| |
| # test dag (in)equality based on _comps |
| assert dag_eq == dag |
| assert dag_diff_name != dag |
| assert dag_diff_load_time != dag |
| |
| # test dag inequality based on type even if _comps happen to match |
| assert dag_subclass != dag |
| |
| # a dag should equal an unpickled version of itself |
| dump = pickle.dumps(dag) |
| assert pickle.loads(dump) == dag |
| |
| # dags are ordered based on dag_id no matter what the type is |
| assert dag < dag_diff_name |
| assert dag > dag_diff_load_time |
| assert dag < dag_subclass_diff_name |
| |
| # greater than should have been created automatically by functools |
| assert dag_diff_name > dag |
| |
| # hashes are non-random and match equality |
| assert hash(dag) == hash(dag) |
| assert hash(dag_eq) == hash(dag) |
| assert hash(dag_diff_name) != hash(dag) |
| assert hash(dag_subclass) != hash(dag) |
| |
| def test_get_paused_dag_ids(self): |
| dag_id = "test_get_paused_dag_ids" |
| dag = DAG(dag_id, is_paused_upon_creation=True) |
| dag.sync_to_db() |
| assert DagModel.get_dagmodel(dag_id) is not None |
| |
| paused_dag_ids = DagModel.get_paused_dag_ids([dag_id]) |
| assert paused_dag_ids == {dag_id} |
| |
| with create_session() as session: |
| session.query(DagModel).filter(DagModel.dag_id == dag_id).delete(synchronize_session=False) |
| |
| @pytest.mark.parametrize( |
| "schedule_interval_arg, expected_timetable, interval_description", |
| [ |
| (None, NullTimetable(), "Never, external triggers only"), |
| ("@daily", cron_timetable("0 0 * * *"), "At 00:00"), |
| ("@weekly", cron_timetable("0 0 * * 0"), "At 00:00, only on Sunday"), |
| ("@monthly", cron_timetable("0 0 1 * *"), "At 00:00, on day 1 of the month"), |
| ("@quarterly", cron_timetable("0 0 1 */3 *"), "At 00:00, on day 1 of the month, every 3 months"), |
| ("@yearly", cron_timetable("0 0 1 1 *"), "At 00:00, on day 1 of the month, only in January"), |
| ("5 0 * 8 *", cron_timetable("5 0 * 8 *"), "At 00:05, only in August"), |
| ("@once", OnceTimetable(), "Once, as soon as possible"), |
| (datetime.timedelta(days=1), delta_timetable(datetime.timedelta(days=1)), ""), |
| ("30 21 * * 5 1", cron_timetable("30 21 * * 5 1"), ""), |
| ], |
| ) |
| def test_timetable_and_description_from_schedule_interval_arg( |
| self, schedule_interval_arg, expected_timetable, interval_description |
| ): |
| dag = DAG("test_schedule_interval_arg", schedule=schedule_interval_arg, start_date=TEST_DATE) |
| assert dag.timetable == expected_timetable |
| assert dag.schedule_interval == schedule_interval_arg |
| assert dag.timetable.description == interval_description |
| |
| def test_timetable_and_description_from_dataset(self): |
| dag = DAG("test_schedule_interval_arg", schedule=[Dataset(uri="hello")], start_date=TEST_DATE) |
| assert dag.timetable == DatasetTriggeredTimetable() |
| assert dag.schedule_interval == "Dataset" |
| assert dag.timetable.description == "Triggered by datasets" |
| |
| def test_schedule_interval_still_works(self): |
| dag = DAG("test_schedule_interval_arg", schedule_interval="*/5 * * * *", start_date=TEST_DATE) |
| assert dag.timetable == cron_timetable("*/5 * * * *") |
| assert dag.schedule_interval == "*/5 * * * *" |
| assert dag.timetable.description == "Every 5 minutes" |
| |
| def test_timetable_still_works(self): |
| dag = DAG("test_schedule_interval_arg", timetable=cron_timetable("*/6 * * * *"), start_date=TEST_DATE) |
| assert dag.timetable == cron_timetable("*/6 * * * *") |
| assert dag.schedule_interval == "*/6 * * * *" |
| assert dag.timetable.description == "Every 6 minutes" |
| |
| @pytest.mark.parametrize( |
| "timetable, expected_description", |
| [ |
| (NullTimetable(), "Never, external triggers only"), |
| (cron_timetable("0 0 * * *"), "At 00:00"), |
| (cron_timetable("@daily"), "At 00:00"), |
| (cron_timetable("0 0 * * 0"), "At 00:00, only on Sunday"), |
| (cron_timetable("@weekly"), "At 00:00, only on Sunday"), |
| (cron_timetable("0 0 1 * *"), "At 00:00, on day 1 of the month"), |
| (cron_timetable("@monthly"), "At 00:00, on day 1 of the month"), |
| (cron_timetable("0 0 1 */3 *"), "At 00:00, on day 1 of the month, every 3 months"), |
| (cron_timetable("@quarterly"), "At 00:00, on day 1 of the month, every 3 months"), |
| (cron_timetable("0 0 1 1 *"), "At 00:00, on day 1 of the month, only in January"), |
| (cron_timetable("@yearly"), "At 00:00, on day 1 of the month, only in January"), |
| (cron_timetable("5 0 * 8 *"), "At 00:05, only in August"), |
| (OnceTimetable(), "Once, as soon as possible"), |
| (delta_timetable(datetime.timedelta(days=1)), ""), |
| (cron_timetable("30 21 * * 5 1"), ""), |
| ], |
| ) |
| def test_description_from_timetable(self, timetable, expected_description): |
| dag = DAG("test_schedule_interval_description", timetable=timetable, start_date=TEST_DATE) |
| assert dag.timetable == timetable |
| assert dag.timetable.description == expected_description |
| |
| def test_create_dagrun_run_id_is_generated(self): |
| dag = DAG(dag_id="run_id_is_generated") |
| dr = dag.create_dagrun(run_type=DagRunType.MANUAL, execution_date=DEFAULT_DATE, state=State.NONE) |
| assert dr.run_id == f"manual__{DEFAULT_DATE.isoformat()}" |
| |
| def test_create_dagrun_run_type_is_obtained_from_run_id(self): |
| dag = DAG(dag_id="run_type_is_obtained_from_run_id") |
| dr = dag.create_dagrun(run_id="scheduled__", state=State.NONE) |
| assert dr.run_type == DagRunType.SCHEDULED |
| |
| dr = dag.create_dagrun(run_id="custom_is_set_to_manual", state=State.NONE) |
| assert dr.run_type == DagRunType.MANUAL |
| |
| def test_create_dagrun_job_id_is_set(self): |
| job_id = 42 |
| dag = DAG(dag_id="test_create_dagrun_job_id_is_set") |
| dr = dag.create_dagrun( |
| run_id="test_create_dagrun_job_id_is_set", state=State.NONE, creating_job_id=job_id |
| ) |
| assert dr.creating_job_id == job_id |
| |
| def test_dag_add_task_checks_trigger_rule(self): |
| # A non fail stop dag should allow any trigger rule |
| from airflow.exceptions import FailStopDagInvalidTriggerRule |
| from airflow.utils.trigger_rule import TriggerRule |
| |
| task_with_non_default_trigger_rule = EmptyOperator( |
| task_id="task_with_non_default_trigger_rule", trigger_rule=TriggerRule.ALWAYS |
| ) |
| non_fail_stop_dag = DAG( |
| dag_id="test_dag_add_task_checks_trigger_rule", start_date=DEFAULT_DATE, fail_stop=False |
| ) |
| non_fail_stop_dag.add_task(task_with_non_default_trigger_rule) |
| |
| # a fail stop dag should allow default trigger rule |
| from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE |
| |
| fail_stop_dag = DAG( |
| dag_id="test_dag_add_task_checks_trigger_rule", start_date=DEFAULT_DATE, fail_stop=True |
| ) |
| task_with_default_trigger_rule = EmptyOperator( |
| task_id="task_with_default_trigger_rule", trigger_rule=DEFAULT_TRIGGER_RULE |
| ) |
| fail_stop_dag.add_task(task_with_default_trigger_rule) |
| |
| # a fail stop dag should not allow a non-default trigger rule |
| with pytest.raises(FailStopDagInvalidTriggerRule): |
| fail_stop_dag.add_task(task_with_non_default_trigger_rule) |
| |
| def test_dag_add_task_sets_default_task_group(self): |
| dag = DAG(dag_id="test_dag_add_task_sets_default_task_group", start_date=DEFAULT_DATE) |
| task_without_task_group = EmptyOperator(task_id="task_without_group_id") |
| default_task_group = TaskGroupContext.get_current_task_group(dag) |
| dag.add_task(task_without_task_group) |
| assert default_task_group.get_child_by_label("task_without_group_id") == task_without_task_group |
| |
| task_group = TaskGroup(group_id="task_group", dag=dag) |
| task_with_task_group = EmptyOperator(task_id="task_with_task_group", task_group=task_group) |
| dag.add_task(task_with_task_group) |
| assert task_group.get_child_by_label("task_with_task_group") == task_with_task_group |
| assert dag.get_task("task_group.task_with_task_group") == task_with_task_group |
| |
| @pytest.mark.parametrize("dag_run_state", [DagRunState.QUEUED, DagRunState.RUNNING]) |
| def test_clear_set_dagrun_state(self, dag_run_state): |
| dag_id = "test_clear_set_dagrun_state" |
| self._clean_up(dag_id) |
| task_id = "t1" |
| dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1) |
| t_1 = EmptyOperator(task_id=task_id, dag=dag) |
| |
| session = settings.Session() |
| dagrun_1 = dag.create_dagrun( |
| run_type=DagRunType.BACKFILL_JOB, |
| state=State.FAILED, |
| start_date=DEFAULT_DATE, |
| execution_date=DEFAULT_DATE, |
| ) |
| session.merge(dagrun_1) |
| |
| task_instance_1 = TI(t_1, execution_date=DEFAULT_DATE, state=State.RUNNING) |
| session.merge(task_instance_1) |
| session.commit() |
| |
| dag.clear( |
| start_date=DEFAULT_DATE, |
| end_date=DEFAULT_DATE + datetime.timedelta(days=1), |
| dag_run_state=dag_run_state, |
| include_subdags=False, |
| include_parentdag=False, |
| session=session, |
| ) |
| |
| dagruns = ( |
| session.query( |
| DagRun, |
| ) |
| .filter( |
| DagRun.dag_id == dag_id, |
| ) |
| .all() |
| ) |
| |
| assert len(dagruns) == 1 |
| dagrun: DagRun = dagruns[0] |
| assert dagrun.state == dag_run_state |
| |
| @pytest.mark.parametrize("dag_run_state", [DagRunState.QUEUED, DagRunState.RUNNING]) |
| def test_clear_set_dagrun_state_for_mapped_task(self, dag_run_state): |
| dag_id = "test_clear_set_dagrun_state" |
| self._clean_up(dag_id) |
| task_id = "t1" |
| |
| dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1) |
| |
| @dag.task |
| def make_arg_lists(): |
| return [[1], [2], [{"a": "b"}]] |
| |
| def consumer(value): |
| print(value) |
| |
| mapped = PythonOperator.partial(task_id=task_id, dag=dag, python_callable=consumer).expand( |
| op_args=make_arg_lists() |
| ) |
| |
| session = settings.Session() |
| dagrun_1 = dag.create_dagrun( |
| run_type=DagRunType.BACKFILL_JOB, |
| state=State.FAILED, |
| start_date=DEFAULT_DATE, |
| execution_date=DEFAULT_DATE, |
| session=session, |
| ) |
| expand_mapped_task(mapped, dagrun_1.run_id, "make_arg_lists", length=2, session=session) |
| |
| upstream_ti = dagrun_1.get_task_instance("make_arg_lists", session=session) |
| ti = dagrun_1.get_task_instance(task_id, map_index=0, session=session) |
| ti2 = dagrun_1.get_task_instance(task_id, map_index=1, session=session) |
| upstream_ti.state = State.SUCCESS |
| ti.state = State.SUCCESS |
| ti2.state = State.SUCCESS |
| session.flush() |
| |
| dag.clear( |
| task_ids=[(task_id, 0), ("make_arg_lists")], |
| start_date=DEFAULT_DATE, |
| end_date=DEFAULT_DATE + datetime.timedelta(days=1), |
| dag_run_state=dag_run_state, |
| include_subdags=False, |
| include_parentdag=False, |
| session=session, |
| ) |
| session.refresh(upstream_ti) |
| session.refresh(ti) |
| session.refresh(ti2) |
| assert upstream_ti.state is None # cleared |
| assert ti.state is None # cleared |
| assert ti2.state == State.SUCCESS # not cleared |
| dagruns = ( |
| session.query( |
| DagRun, |
| ) |
| .filter( |
| DagRun.dag_id == dag_id, |
| ) |
| .all() |
| ) |
| |
| assert len(dagruns) == 1 |
| dagrun: DagRun = dagruns[0] |
| assert dagrun.state == dag_run_state |
| |
| def test_dag_test_basic(self): |
| dag = DAG(dag_id="test_local_testing_conn_file", start_date=DEFAULT_DATE) |
| mock_object = mock.MagicMock() |
| |
| @task_decorator |
| def check_task(): |
| # we call a mock object to ensure that this task actually ran. |
| mock_object() |
| |
| with dag: |
| check_task() |
| |
| dag.test() |
| mock_object.assert_called_once() |
| |
| def test_dag_test_with_dependencies(self): |
| dag = DAG(dag_id="test_local_testing_conn_file", start_date=DEFAULT_DATE) |
| mock_object = mock.MagicMock() |
| |
| @task_decorator |
| def check_task(): |
| return "output of first task" |
| |
| @task_decorator |
| def check_task_2(my_input): |
| # we call a mock object to ensure that this task actually ran. |
| mock_object(my_input) |
| |
| with dag: |
| check_task_2(check_task()) |
| |
| dag.test() |
| mock_object.assert_called_with("output of first task") |
| |
| def test_dag_test_with_fail_handler(self): |
| mock_handle_object_1 = mock.MagicMock() |
| mock_handle_object_2 = mock.MagicMock() |
| |
| def handle_task_failure(context): |
| ti = context["task_instance"] |
| mock_handle_object_1(f"task {ti.task_id} failed...") |
| |
| def handle_dag_failure(context): |
| ti = context["task_instance"] |
| mock_handle_object_2(f"dag {ti.dag_id} run failed...") |
| |
| dag = DAG( |
| dag_id="test_local_testing_conn_file", |
| default_args={"on_failure_callback": handle_task_failure}, |
| on_failure_callback=handle_dag_failure, |
| start_date=DEFAULT_DATE, |
| ) |
| |
| mock_task_object_1 = mock.MagicMock() |
| mock_task_object_2 = mock.MagicMock() |
| |
| @task_decorator |
| def check_task(): |
| mock_task_object_1() |
| raise AirflowException("boooom") |
| |
| @task_decorator |
| def check_task_2(my_input): |
| # we call a mock object to ensure that this task actually ran. |
| mock_task_object_2(my_input) |
| |
| with dag: |
| check_task_2(check_task()) |
| |
| dag.test() |
| |
| mock_handle_object_1.assert_called_with("task check_task failed...") |
| mock_handle_object_2.assert_called_with("dag test_local_testing_conn_file run failed...") |
| mock_task_object_1.assert_called() |
| mock_task_object_2.assert_not_called() |
| |
| def test_dag_test_with_task_mapping(self): |
| dag = DAG(dag_id="test_local_testing_conn_file", start_date=DEFAULT_DATE) |
| mock_object = mock.MagicMock() |
| |
| @task_decorator() |
| def get_index(current_val, ti=None): |
| return ti.map_index |
| |
| @task_decorator |
| def check_task(my_input): |
| # we call a mock object with the combined map to ensure all expected indexes are called |
| mock_object(list(my_input)) |
| |
| with dag: |
| mapped_task = get_index.expand(current_val=[1, 1, 1, 1, 1]) |
| check_task(mapped_task) |
| |
| dag.test() |
| mock_object.assert_called_with([0, 1, 2, 3, 4]) |
| |
| def test_dag_connection_file(self, tmp_path): |
| test_connections_string = """ |
| --- |
| my_postgres_conn: |
| - conn_id: my_postgres_conn |
| conn_type: postgres |
| """ |
| dag = DAG(dag_id="test_local_testing_conn_file", start_date=DEFAULT_DATE) |
| |
| @task_decorator |
| def check_task(): |
| from airflow.configuration import secrets_backend_list |
| from airflow.secrets.local_filesystem import LocalFilesystemBackend |
| |
| assert isinstance(secrets_backend_list[0], LocalFilesystemBackend) |
| local_secrets: LocalFilesystemBackend = secrets_backend_list[0] |
| assert local_secrets.get_connection("my_postgres_conn").conn_id == "my_postgres_conn" |
| |
| with dag: |
| check_task() |
| path = tmp_path / "testfile.yaml" |
| path.write_text(test_connections_string) |
| dag.test(conn_file_path=os.fspath(path)) |
| |
| def _make_test_subdag(self, session): |
| dag_id = "test_subdag" |
| self._clean_up(dag_id) |
| task_id = "t1" |
| dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1) |
| t_1 = EmptyOperator(task_id=task_id, dag=dag) |
| subdag = DAG(dag_id + ".test", start_date=DEFAULT_DATE, max_active_runs=1) |
| SubDagOperator(task_id="test", subdag=subdag, dag=dag) |
| t_2 = EmptyOperator(task_id="task", dag=subdag) |
| subdag.parent_dag = dag |
| |
| dag.sync_to_db() |
| |
| session = settings.Session() |
| dag.create_dagrun( |
| run_type=DagRunType.MANUAL, |
| state=State.FAILED, |
| start_date=DEFAULT_DATE, |
| execution_date=DEFAULT_DATE, |
| session=session, |
| ) |
| subdag.create_dagrun( |
| run_type=DagRunType.MANUAL, |
| state=State.FAILED, |
| start_date=DEFAULT_DATE, |
| execution_date=DEFAULT_DATE, |
| session=session, |
| ) |
| task_instance_1 = TI(t_1, execution_date=DEFAULT_DATE, state=State.RUNNING) |
| task_instance_2 = TI(t_2, execution_date=DEFAULT_DATE, state=State.RUNNING) |
| session.merge(task_instance_1) |
| session.merge(task_instance_2) |
| |
| return dag, subdag |
| |
| @pytest.mark.parametrize("dag_run_state", [DagRunState.QUEUED, DagRunState.RUNNING]) |
| def test_clear_set_dagrun_state_for_subdag(self, dag_run_state): |
| session = settings.Session() |
| dag, subdag = self._make_test_subdag(session) |
| session.flush() |
| |
| dag.clear( |
| start_date=DEFAULT_DATE, |
| end_date=DEFAULT_DATE + datetime.timedelta(days=1), |
| dag_run_state=dag_run_state, |
| include_subdags=True, |
| include_parentdag=False, |
| session=session, |
| ) |
| |
| dagrun = ( |
| session.query( |
| DagRun, |
| ) |
| .filter(DagRun.dag_id == subdag.dag_id) |
| .one() |
| ) |
| assert dagrun.state == dag_run_state |
| session.rollback() |
| |
| @pytest.mark.parametrize("dag_run_state", [DagRunState.QUEUED, DagRunState.RUNNING]) |
| def test_clear_set_dagrun_state_for_parent_dag(self, dag_run_state): |
| session = settings.Session() |
| dag, subdag = self._make_test_subdag(session) |
| session.flush() |
| |
| subdag.clear( |
| start_date=DEFAULT_DATE, |
| end_date=DEFAULT_DATE + datetime.timedelta(days=1), |
| dag_run_state=dag_run_state, |
| include_subdags=True, |
| include_parentdag=True, |
| session=session, |
| ) |
| |
| dagrun = ( |
| session.query( |
| DagRun, |
| ) |
| .filter(DagRun.dag_id == dag.dag_id) |
| .one() |
| ) |
| assert dagrun.state == dag_run_state |
| |
| @pytest.mark.parametrize( |
| "ti_state_begin, ti_state_end", |
| [ |
| *((state, None) for state in State.task_states if state != TaskInstanceState.RUNNING), |
| (TaskInstanceState.RUNNING, TaskInstanceState.RESTARTING), |
| ], |
| ) |
| def test_clear_dag( |
| self, |
| ti_state_begin: TaskInstanceState | None, |
| ti_state_end: TaskInstanceState | None, |
| ): |
| dag_id = "test_clear_dag" |
| self._clean_up(dag_id) |
| task_id = "t1" |
| dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1) |
| t_1 = EmptyOperator(task_id=task_id, dag=dag) |
| |
| session = settings.Session() # type: ignore |
| dagrun_1 = dag.create_dagrun( |
| run_type=DagRunType.BACKFILL_JOB, |
| state=DagRunState.RUNNING, |
| start_date=DEFAULT_DATE, |
| execution_date=DEFAULT_DATE, |
| ) |
| session.merge(dagrun_1) |
| |
| task_instance_1 = TI(t_1, execution_date=DEFAULT_DATE, state=ti_state_begin) |
| task_instance_1.job_id = 123 |
| session.merge(task_instance_1) |
| session.commit() |
| |
| dag.clear( |
| start_date=DEFAULT_DATE, |
| end_date=DEFAULT_DATE + datetime.timedelta(days=1), |
| session=session, |
| ) |
| |
| task_instances = ( |
| session.query( |
| TI, |
| ) |
| .filter( |
| TI.dag_id == dag_id, |
| ) |
| .all() |
| ) |
| |
| assert len(task_instances) == 1 |
| task_instance: TI = task_instances[0] |
| assert task_instance.state == ti_state_end |
| self._clean_up(dag_id) |
| |
| def test_next_dagrun_info_once(self): |
| dag = DAG("test_scheduler_dagrun_once", start_date=timezone.datetime(2015, 1, 1), schedule="@once") |
| |
| next_info = dag.next_dagrun_info(None) |
| assert next_info |
| assert next_info.logical_date == timezone.datetime(2015, 1, 1) |
| |
| next_info = dag.next_dagrun_info(next_info.data_interval) |
| assert next_info is None |
| |
| def test_next_dagrun_info_start_end_dates(self): |
| """ |
| Tests that an attempt to schedule a task after the Dag's end_date |
| does not succeed. |
| """ |
| delta = datetime.timedelta(hours=1) |
| runs = 3 |
| start_date = DEFAULT_DATE |
| end_date = start_date + (runs - 1) * delta |
| dag_id = "test_schedule_dag_start_end_dates" |
| dag = DAG(dag_id=dag_id, start_date=start_date, end_date=end_date, schedule=delta) |
| dag.add_task(BaseOperator(task_id="faketastic", owner="Also fake")) |
| |
| # Create and schedule the dag runs |
| dates = [] |
| interval = None |
| for _ in range(runs): |
| next_info = dag.next_dagrun_info(interval) |
| if next_info is None: |
| dates.append(None) |
| else: |
| interval = next_info.data_interval |
| dates.append(interval.start) |
| |
| assert all(date is not None for date in dates) |
| assert dates[-1] == end_date |
| assert dag.next_dagrun_info(interval.start) is None |
| |
| def test_next_dagrun_info_catchup(self): |
| """ |
| Test to check that a DAG with catchup = False only schedules beginning now, not back to the start date |
| """ |
| |
| def make_dag(dag_id, schedule, start_date, catchup): |
| default_args = { |
| "owner": "airflow", |
| "depends_on_past": False, |
| } |
| dag = DAG( |
| dag_id, |
| schedule=schedule, |
| start_date=start_date, |
| catchup=catchup, |
| default_args=default_args, |
| ) |
| |
| op1 = EmptyOperator(task_id="t1", dag=dag) |
| op2 = EmptyOperator(task_id="t2", dag=dag) |
| op3 = EmptyOperator(task_id="t3", dag=dag) |
| op1 >> op2 >> op3 |
| |
| return dag |
| |
| now = timezone.utcnow() |
| six_hours_ago_to_the_hour = (now - datetime.timedelta(hours=6)).replace( |
| minute=0, second=0, microsecond=0 |
| ) |
| half_an_hour_ago = now - datetime.timedelta(minutes=30) |
| two_hours_ago = now - datetime.timedelta(hours=2) |
| |
| dag1 = make_dag( |
| dag_id="dag_without_catchup_ten_minute", |
| schedule="*/10 * * * *", |
| start_date=six_hours_ago_to_the_hour, |
| catchup=False, |
| ) |
| next_date, _ = dag1.next_dagrun_info(None) |
| # The DR should be scheduled in the last half an hour, not 6 hours ago |
| assert next_date > half_an_hour_ago |
| assert next_date < timezone.utcnow() |
| |
| dag2 = make_dag( |
| dag_id="dag_without_catchup_hourly", |
| schedule="@hourly", |
| start_date=six_hours_ago_to_the_hour, |
| catchup=False, |
| ) |
| |
| next_date, _ = dag2.next_dagrun_info(None) |
| # The DR should be scheduled in the last 2 hours, not 6 hours ago |
| assert next_date > two_hours_ago |
| # The DR should be scheduled BEFORE now |
| assert next_date < timezone.utcnow() |
| |
| dag3 = make_dag( |
| dag_id="dag_without_catchup_once", |
| schedule="@once", |
| start_date=six_hours_ago_to_the_hour, |
| catchup=False, |
| ) |
| |
| next_date, _ = dag3.next_dagrun_info(None) |
| # The DR should be scheduled in the last 2 hours, not 6 hours ago |
| assert next_date == six_hours_ago_to_the_hour |
| |
| @time_machine.travel(timezone.datetime(2020, 1, 5)) |
| @pytest.mark.parametrize("schedule", ("@daily", timedelta(days=1), cron_timetable("0 0 * * *"))) |
| def test_next_dagrun_info_timedelta_schedule_and_catchup_false(self, schedule): |
| """ |
| Test that the dag file processor does not create multiple dagruns |
| if a dag is scheduled with 'timedelta' and catchup=False |
| """ |
| dag = DAG( |
| "test_scheduler_dagrun_once_with_timedelta_and_catchup_false", |
| start_date=timezone.datetime(2015, 1, 1), |
| schedule=schedule, |
| catchup=False, |
| ) |
| |
| next_info = dag.next_dagrun_info(None) |
| assert next_info |
| assert next_info.logical_date == timezone.datetime(2020, 1, 4) |
| |
| # The date to create is in the future, this is handled by "DagModel.dags_needing_dagruns" |
| next_info = dag.next_dagrun_info(next_info.data_interval) |
| assert next_info |
| assert next_info.logical_date == timezone.datetime(2020, 1, 5) |
| |
| @time_machine.travel(timezone.datetime(2020, 5, 4)) |
| def test_next_dagrun_info_timedelta_schedule_and_catchup_true(self): |
| """ |
| Test that the dag file processor creates multiple dagruns |
| if a dag is scheduled with 'timedelta' and catchup=True |
| """ |
| dag = DAG( |
| "test_scheduler_dagrun_once_with_timedelta_and_catchup_true", |
| start_date=timezone.datetime(2020, 5, 1), |
| schedule=timedelta(days=1), |
| catchup=True, |
| ) |
| |
| next_info = dag.next_dagrun_info(None) |
| assert next_info |
| assert next_info.logical_date == timezone.datetime(2020, 5, 1) |
| |
| next_info = dag.next_dagrun_info(next_info.data_interval) |
| assert next_info |
| assert next_info.logical_date == timezone.datetime(2020, 5, 2) |
| |
| next_info = dag.next_dagrun_info(next_info.data_interval) |
| assert next_info |
| assert next_info.logical_date == timezone.datetime(2020, 5, 3) |
| |
| # The date to create is in the future, this is handled by "DagModel.dags_needing_dagruns" |
| next_info = dag.next_dagrun_info(next_info.data_interval) |
| assert next_info |
| assert next_info.logical_date == timezone.datetime(2020, 5, 4) |
| |
| def test_next_dagrun_info_timetable_exception(self, caplog): |
| """Test the DAG does not crash the scheduler if the timetable raises an exception.""" |
| |
| class FailingTimetable(Timetable): |
| def next_dagrun_info(self, last_automated_data_interval, restriction): |
| raise RuntimeError("this fails") |
| |
| dag = DAG( |
| "test_next_dagrun_info_timetable_exception", |
| start_date=timezone.datetime(2020, 5, 1), |
| timetable=FailingTimetable(), |
| catchup=True, |
| ) |
| |
| def _check_logs(records: list[logging.LogRecord], data_interval: DataInterval) -> None: |
| assert len(records) == 1 |
| record = records[0] |
| assert record.exc_info is not None, "Should contain exception" |
| assert record.getMessage() == ( |
| f"Failed to fetch run info after data interval {data_interval} " |
| f"for DAG 'test_next_dagrun_info_timetable_exception'" |
| ) |
| |
| with caplog.at_level(level=logging.ERROR): |
| next_info = dag.next_dagrun_info(None) |
| assert next_info is None, "failed next_dagrun_info should return None" |
| _check_logs(caplog.records, data_interval=None) |
| caplog.clear() |
| data_interval = DataInterval(timezone.datetime(2020, 5, 1), timezone.datetime(2020, 5, 2)) |
| with caplog.at_level(level=logging.ERROR): |
| next_info = dag.next_dagrun_info(data_interval) |
| assert next_info is None, "failed next_dagrun_info should return None" |
| _check_logs(caplog.records, data_interval) |
| |
| def test_next_dagrun_after_auto_align(self): |
| """ |
| Test if the schedule_interval will be auto aligned with the start_date |
| such that if the start_date coincides with the schedule the first |
| execution_date will be start_date, otherwise it will be start_date + |
| interval. |
| """ |
| dag = DAG( |
| dag_id="test_scheduler_auto_align_1", |
| start_date=timezone.datetime(2016, 1, 1, 10, 10, 0), |
| schedule="4 5 * * *", |
| ) |
| EmptyOperator(task_id="dummy", dag=dag, owner="airflow") |
| |
| next_info = dag.next_dagrun_info(None) |
| assert next_info |
| assert next_info.logical_date == timezone.datetime(2016, 1, 2, 5, 4) |
| |
| dag = DAG( |
| dag_id="test_scheduler_auto_align_2", |
| start_date=timezone.datetime(2016, 1, 1, 10, 10, 0), |
| schedule="10 10 * * *", |
| ) |
| EmptyOperator(task_id="dummy", dag=dag, owner="airflow") |
| |
| next_info = dag.next_dagrun_info(None) |
| assert next_info |
| assert next_info.logical_date == timezone.datetime(2016, 1, 1, 10, 10) |
| |
| def test_next_dagrun_after_not_for_subdags(self): |
| """ |
| Test the subdags are never marked to have dagruns created, as they are |
| handled by the SubDagOperator, not the scheduler |
| """ |
| |
| def subdag(parent_dag_name, child_dag_name, args): |
| """ |
| Create a subdag. |
| """ |
| dag_subdag = DAG( |
| dag_id=f"{parent_dag_name}.{child_dag_name}", |
| schedule="@daily", |
| default_args=args, |
| ) |
| |
| for i in range(2): |
| EmptyOperator(task_id=f"{child_dag_name}-task-{i + 1}", dag=dag_subdag) |
| |
| return dag_subdag |
| |
| with DAG( |
| dag_id="test_subdag_operator", |
| start_date=datetime.datetime(2019, 1, 1), |
| max_active_runs=1, |
| schedule=timedelta(minutes=1), |
| ) as dag: |
| section_1 = SubDagOperator( |
| task_id="section-1", |
| subdag=subdag(dag.dag_id, "section-1", {"start_date": dag.start_date}), |
| ) |
| |
| subdag = section_1.subdag |
| # parent_dag and is_subdag was set by DagBag. We don't use DagBag, so this value is not set. |
| subdag.parent_dag = dag |
| |
| next_parent_info = dag.next_dagrun_info(None) |
| assert next_parent_info.logical_date == timezone.datetime(2019, 1, 1, 0, 0) |
| |
| next_subdag_info = subdag.next_dagrun_info(None) |
| assert next_subdag_info is None, "SubDags should never have DagRuns created by the scheduler" |
| |
| def test_next_dagrun_info_on_29_feb(self): |
| dag = DAG( |
| "test_scheduler_dagrun_29_feb", start_date=timezone.datetime(2024, 1, 1), schedule="0 0 29 2 *" |
| ) |
| |
| next_info = dag.next_dagrun_info(None) |
| assert next_info |
| assert next_info.logical_date == timezone.datetime(2024, 2, 29) |
| |
| next_info = dag.next_dagrun_info(next_info.data_interval) |
| assert next_info.logical_date == timezone.datetime(2028, 2, 29) |
| assert next_info.data_interval.start == timezone.datetime(2028, 2, 29) |
| assert next_info.data_interval.end == timezone.datetime(2032, 2, 29) |
| |
| def test_replace_outdated_access_control_actions(self): |
| outdated_permissions = { |
| "role1": {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT}, |
| "role2": {permissions.DEPRECATED_ACTION_CAN_DAG_READ, permissions.DEPRECATED_ACTION_CAN_DAG_EDIT}, |
| } |
| updated_permissions = { |
| "role1": {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT}, |
| "role2": {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT}, |
| } |
| |
| with pytest.warns(DeprecationWarning): |
| dag = DAG(dag_id="dag_with_outdated_perms", access_control=outdated_permissions) |
| assert dag.access_control == updated_permissions |
| |
| with pytest.warns(DeprecationWarning): |
| dag.access_control = outdated_permissions |
| assert dag.access_control == updated_permissions |
| |
| def test_validate_params_on_trigger_dag(self): |
| dag = DAG("dummy-dag", schedule=None, params={"param1": Param(type="string")}) |
| with pytest.raises(ParamValidationError, match="No value passed and Param has no default value"): |
| dag.create_dagrun( |
| run_id="test_dagrun_missing_param", |
| state=State.RUNNING, |
| execution_date=TEST_DATE, |
| ) |
| |
| dag = DAG("dummy-dag", schedule=None, params={"param1": Param(type="string")}) |
| with pytest.raises( |
| ParamValidationError, match="Invalid input for param param1: None is not of type 'string'" |
| ): |
| dag.create_dagrun( |
| run_id="test_dagrun_missing_param", |
| state=State.RUNNING, |
| execution_date=TEST_DATE, |
| conf={"param1": None}, |
| ) |
| |
| dag = DAG("dummy-dag", schedule=None, params={"param1": Param(type="string")}) |
| dag.create_dagrun( |
| run_id="test_dagrun_missing_param", |
| state=State.RUNNING, |
| execution_date=TEST_DATE, |
| conf={"param1": "hello"}, |
| ) |
| |
| def test_return_date_range_with_num_method(self): |
| start_date = TEST_DATE |
| delta = timedelta(days=1) |
| |
| dag = DAG("dummy-dag", schedule=delta, start_date=start_date) |
| dag_dates = dag.date_range(start_date=start_date, num=3) |
| |
| assert dag_dates == [ |
| start_date, |
| start_date + delta, |
| start_date + 2 * delta, |
| ] |
| |
| def test_dag_owner_links(self): |
| dag = DAG( |
| "dag", |
| start_date=DEFAULT_DATE, |
| owner_links={"owner1": "https://mylink.com", "owner2": "mailto:someone@yoursite.com"}, |
| ) |
| |
| assert dag.owner_links == {"owner1": "https://mylink.com", "owner2": "mailto:someone@yoursite.com"} |
| session = settings.Session() |
| dag.sync_to_db(session=session) |
| |
| expected_owners = {"dag": {"owner1": "https://mylink.com", "owner2": "mailto:someone@yoursite.com"}} |
| orm_dag_owners = DagOwnerAttributes.get_all(session) |
| assert orm_dag_owners == expected_owners |
| |
| # Test dag owner links are removed completely |
| dag = DAG( |
| "dag", |
| start_date=DEFAULT_DATE, |
| ) |
| dag.sync_to_db(session=session) |
| |
| orm_dag_owners = session.query(DagOwnerAttributes).all() |
| assert not orm_dag_owners |
| |
| # Check wrong formatted owner link |
| with pytest.raises(AirflowException): |
| DAG("dag", start_date=DEFAULT_DATE, owner_links={"owner1": "my-bad-link"}) |
| |
@pytest.mark.parametrize(
    "kwargs",
    [
        pytest.param(
            {"schedule_interval": "@daily", "schedule": "@weekly"},
            id="schedule_interval+schedule",
        ),
        pytest.param(
            {"timetable": NullTimetable(), "schedule": "@weekly"},
            id="timetable+schedule",
        ),
        pytest.param(
            {"timetable": NullTimetable(), "schedule_interval": "@daily"},
            id="timetable+schedule_interval",
        ),
    ],
)
def test_schedule_dag_param(self, kwargs):
    """Supplying two of schedule / schedule_interval / timetable at once raises ValueError."""
    expected_error = pytest.raises(ValueError, match="At most one")
    with expected_error:
        with DAG(dag_id="hello", start_date=TEST_DATE, **kwargs):
            pass
| |
def test_continuous_schedule_interval_linmits_max_active_runs(self):
    """A ``@continuous`` schedule accepts max_active_runs of 1 and 0 but rejects 25."""
    for allowed in (1, 0):
        dag = DAG(
            "continuous",
            start_date=DEFAULT_DATE,
            schedule_interval="@continuous",
            max_active_runs=allowed,
        )
        assert isinstance(dag.timetable, ContinuousTimetable)
        assert dag.max_active_runs == allowed

    # Construction itself must raise; the result is intentionally not bound to a name.
    with pytest.raises(AirflowException):
        DAG("continuous", start_date=DEFAULT_DATE, schedule_interval="@continuous", max_active_runs=25)
| |
| |
class TestDagModel:
    """Tests for ``DagModel``: dagrun-scheduling queries, defaults and file-location helpers."""

    def _clean(self):
        # Wipe every table these tests write to so each test starts from an empty DB.
        clear_db_dags()
        clear_db_datasets()
        clear_db_runs()

    def setup_method(self):
        self._clean()

    def teardown_method(self):
        self._clean()

    def test_dags_needing_dagruns_not_too_early(self):
        """A DAG whose ``next_dagrun_create_after`` is still in the future is not returned."""
        dag = DAG(dag_id="far_future_dag", start_date=timezone.datetime(2038, 1, 1))
        EmptyOperator(task_id="dummy", dag=dag, owner="airflow")

        session = settings.Session()
        try:
            orm_dag = DagModel(
                dag_id=dag.dag_id,
                max_active_tasks=1,
                has_task_concurrency_limits=False,
                next_dagrun=dag.start_date,
                next_dagrun_create_after=timezone.datetime(2038, 1, 2),
                is_active=True,
            )
            session.add(orm_dag)
            session.flush()

            query, _ = DagModel.dags_needing_dagruns(session)
            dag_models = query.all()
            assert dag_models == []
        finally:
            # Roll back the flushed row and release the session even if an assertion fails.
            session.rollback()
            session.close()

    def test_dags_needing_dagruns_datasets(self, dag_maker, session):
        """A dataset-scheduled DAG needs a run only when a queue record exists and max_active_runs allows it."""
        dataset = Dataset(uri="hello")
        with dag_maker(
            session=session,
            dag_id="my_dag",
            max_active_runs=1,
            schedule=[dataset],
            start_date=pendulum.now().add(days=-2),
        ) as dag:
            EmptyOperator(task_id="dummy")

        # there's no queue record yet, so no runs needed at this time.
        query, _ = DagModel.dags_needing_dagruns(session)
        dag_models = query.all()
        assert dag_models == []

        # add queue records so we'll need a run
        dag_model = session.query(DagModel).filter(DagModel.dag_id == dag.dag_id).one()
        dataset_model: DatasetModel = dag_model.schedule_datasets[0]
        session.add(DatasetDagRunQueue(dataset_id=dataset_model.id, target_dag_id=dag_model.dag_id))
        session.flush()
        query, _ = DagModel.dags_needing_dagruns(session)
        dag_models = query.all()
        assert dag_models == [dag_model]

        # create run so we don't need a run anymore (due to max active runs)
        dag_maker.create_dagrun(
            run_type=DagRunType.DATASET_TRIGGERED,
            state=DagRunState.QUEUED,
            execution_date=pendulum.now("UTC"),
        )
        query, _ = DagModel.dags_needing_dagruns(session)
        dag_models = query.all()
        assert dag_models == []

        # increase max active runs and we should now need another run
        dag_maker.dag_model.max_active_runs = 2
        session.flush()
        query, _ = DagModel.dags_needing_dagruns(session)
        dag_models = query.all()
        assert dag_models == [dag_model]

    def test_max_active_runs_not_none(self):
        """``DagModel.max_active_runs`` is populated even when not passed to the constructor."""
        dag = DAG(dag_id="test_max_active_runs_not_none", start_date=timezone.datetime(2038, 1, 1))
        EmptyOperator(task_id="dummy", dag=dag, owner="airflow")

        session = settings.Session()
        try:
            orm_dag = DagModel(
                dag_id=dag.dag_id,
                has_task_concurrency_limits=False,
                next_dagrun=None,
                next_dagrun_create_after=None,
                is_active=True,
            )
            # max_active_runs is filled in at construction time, before any flush.
            # NOTE(review): 16 is presumably the configured default (core.max_active_runs_per_dag) — confirm.
            assert orm_dag.max_active_runs == 16
            session.add(orm_dag)
            session.flush()
            assert orm_dag.max_active_runs is not None
        finally:
            session.rollback()
            session.close()

    def test_dags_needing_dagruns_only_unpaused(self):
        """
        Only unpaused DAGs are returned: we should never create dagruns for paused DAGs.
        """
        dag = DAG(dag_id="test_dags", start_date=DEFAULT_DATE)
        EmptyOperator(task_id="dummy", dag=dag, owner="airflow")

        session = settings.Session()
        try:
            orm_dag = DagModel(
                dag_id=dag.dag_id,
                has_task_concurrency_limits=False,
                next_dagrun=DEFAULT_DATE,
                next_dagrun_create_after=DEFAULT_DATE + timedelta(days=1),
                is_active=True,
            )
            session.add(orm_dag)
            session.flush()

            query, _ = DagModel.dags_needing_dagruns(session)
            needed = query.all()
            assert needed == [orm_dag]

            # Pausing the same DAG must drop it from the result.
            orm_dag.is_paused = True
            session.flush()

            query, _ = DagModel.dags_needing_dagruns(session)
            dag_models = query.all()
            assert dag_models == []
        finally:
            session.rollback()
            session.close()

    def test_dags_needing_dagruns_doesnot_send_dagmodel_with_import_errors(self, session):
        """
        DAGs flagged with ``has_import_errors`` are not sent to the scheduler
        for dagrun creation.
        """
        dag = DAG(dag_id="test_dags", start_date=DEFAULT_DATE)
        EmptyOperator(task_id="dummy", dag=dag, owner="airflow")

        orm_dag = DagModel(
            dag_id=dag.dag_id,
            has_task_concurrency_limits=False,
            next_dagrun=DEFAULT_DATE,
            next_dagrun_create_after=DEFAULT_DATE + timedelta(days=1),
            is_active=True,
        )
        assert not orm_dag.has_import_errors
        session.add(orm_dag)
        session.flush()

        query, _ = DagModel.dags_needing_dagruns(session)
        needed = query.all()
        assert needed == [orm_dag]
        orm_dag.has_import_errors = True
        session.merge(orm_dag)
        session.flush()
        query, _ = DagModel.dags_needing_dagruns(session)
        needed = query.all()
        assert needed == []

    @pytest.mark.parametrize(
        ("fileloc", "expected_relative"),
        [
            (os.path.join(settings.DAGS_FOLDER, "a.py"), Path("a.py")),
            ("/tmp/foo.py", Path("/tmp/foo.py")),
        ],
    )
    def test_relative_fileloc(self, fileloc, expected_relative):
        """relative_fileloc is relative to DAGS_FOLDER for files under it, else the full path."""
        dag = DAG(dag_id="test")
        dag.fileloc = fileloc

        assert dag.relative_fileloc == expected_relative

    @pytest.mark.parametrize(
        "reader_dags_folder", [settings.DAGS_FOLDER, str(repo_root / "airflow/example_dags")]
    )
    @pytest.mark.parametrize(
        ("fileloc", "expected_relative"),
        [
            (str(Path(settings.DAGS_FOLDER, "a.py")), Path("a.py")),
            ("/tmp/foo.py", Path("/tmp/foo.py")),
        ],
    )
    def test_relative_fileloc_serialized(
        self, fileloc, expected_relative, session, clear_dags, reader_dags_folder
    ):
        """
        The serialized dag model includes the dags folder as configured on the thing serializing
        the dag. On the thing deserializing the dag, when determining relative fileloc,
        we should use the dags folder of the processor. So even if the dags folder of
        the deserializer is different (meaning that the full path is no longer relative to
        the dags folder) then we should still get the relative fileloc as it existed on the
        serializer process. When the full path is not relative to the configured dags folder,
        then relative fileloc should just be the full path.
        """
        dag = DAG(dag_id="test")
        dag.fileloc = fileloc
        sdm = SerializedDagModel(dag)
        session.add(sdm)
        session.commit()
        session.expunge_all()
        sdm = SerializedDagModel.get(dag.dag_id, session)
        dag = sdm.dag
        with conf_vars({("core", "dags_folder"): reader_dags_folder}):
            assert dag.relative_fileloc == expected_relative

    def test__processor_dags_folder(self, session):
        """Only populated after deserialization"""
        dag = DAG(dag_id="test")
        dag.fileloc = "/abc/test.py"
        assert dag._processor_dags_folder is None
        sdm = SerializedDagModel(dag)
        assert sdm.dag._processor_dags_folder == settings.DAGS_FOLDER

    @pytest.mark.need_serialized_dag
    def test_dags_needing_dagruns_dataset_triggered_dag_info_queued_times(self, session, dag_maker):
        """The dataset-triggered info maps each DAG to its first and last queue-record times."""
        dataset1 = Dataset(uri="ds1")
        dataset2 = Dataset(uri="ds2")

        for dag_id, dataset in [("datasets-1", dataset1), ("datasets-2", dataset2)]:
            with dag_maker(dag_id=dag_id, start_date=timezone.utcnow(), session=session):
                EmptyOperator(task_id="task", outlets=[dataset])
            dr = dag_maker.create_dagrun()

            ds_id = session.query(DatasetModel.id).filter_by(uri=dataset.uri).scalar()

            session.add(
                DatasetEvent(
                    dataset_id=ds_id,
                    source_task_id="task",
                    source_dag_id=dr.dag_id,
                    source_run_id=dr.run_id,
                    source_map_index=-1,
                )
            )

        ds1_id = session.query(DatasetModel.id).filter_by(uri=dataset1.uri).scalar()
        ds2_id = session.query(DatasetModel.id).filter_by(uri=dataset2.uri).scalar()

        with dag_maker(dag_id="datasets-consumer-multiple", schedule=[dataset1, dataset2]) as dag:
            pass

        session.flush()
        session.add_all(
            [
                DatasetDagRunQueue(dataset_id=ds1_id, target_dag_id=dag.dag_id, created_at=DEFAULT_DATE),
                DatasetDagRunQueue(
                    dataset_id=ds2_id, target_dag_id=dag.dag_id, created_at=DEFAULT_DATE + timedelta(hours=1)
                ),
            ]
        )
        session.flush()

        query, dataset_triggered_dag_info = DagModel.dags_needing_dagruns(session)
        assert 1 == len(dataset_triggered_dag_info)
        assert dag.dag_id in dataset_triggered_dag_info
        first_queued_time, last_queued_time = dataset_triggered_dag_info[dag.dag_id]
        assert first_queued_time == DEFAULT_DATE
        assert last_queued_time == DEFAULT_DATE + timedelta(hours=1)

    def test_dataset_expression(self, session: Session) -> None:
        """bulk_write_to_db stores the nested any/all dataset condition as a URI expression."""
        dag = DAG(
            dag_id="test_dag_dataset_expression",
            schedule=DatasetAny(
                Dataset("s3://dag1/output_1.txt", {"hi": "bye"}),
                DatasetAll(
                    Dataset("s3://dag2/output_1.txt", {"hi": "bye"}),
                    Dataset("s3://dag3/output_3.txt", {"hi": "bye"}),
                ),
            ),
            start_date=datetime.datetime.min,
        )
        DAG.bulk_write_to_db([dag], session=session)

        expression = session.scalars(select(DagModel.dataset_expression).filter_by(dag_id=dag.dag_id)).one()
        assert expression == {
            "any": [
                "s3://dag1/output_1.txt",
                {"all": ["s3://dag2/output_1.txt", "s3://dag3/output_3.txt"]},
            ]
        }
| |
| |
class TestQueries:
    """Guards the number of SQL queries issued when a DAG run is created."""

    def setup_method(self) -> None:
        clear_db_runs()

    def teardown_method(self) -> None:
        clear_db_runs()

    @pytest.mark.parametrize("tasks_count", [3, 12])
    def test_count_number_queries(self, tasks_count):
        dag = DAG("test_dagrun_query_count", start_date=DEFAULT_DATE)
        task_ids = (f"dummy_task_{idx}" for idx in range(tasks_count))
        for task_id in task_ids:
            EmptyOperator(task_id=task_id, owner="test", dag=dag)

        # The expected count is the same for both parametrized task counts.
        with assert_queries_count(2):
            dag.create_dagrun(
                run_id="test_dagrun_query_count",
                state=State.RUNNING,
                execution_date=TEST_DATE,
            )
| |
| |
class TestDagDecorator:
    """Tests for the ``@dag`` decorator (imported here as ``dag_decorator``)."""

    DEFAULT_ARGS = {
        "owner": "test",
        "depends_on_past": True,
        "start_date": timezone.utcnow(),
        "retries": 1,
        "retry_delay": timedelta(minutes=1),
    }
    DEFAULT_DATE = timezone.datetime(2016, 1, 1)
    VALUE = 42

    def setup_method(self):
        # Tests that build a task inside the decorated function stash its operator here.
        self.operator = None

    def teardown_method(self):
        clear_db_runs()

    def test_fileloc(self):
        """The decorated DAG's fileloc is this test file."""

        @dag_decorator(default_args=self.DEFAULT_ARGS)
        def noop_pipeline(): ...

        dag = noop_pipeline()
        assert isinstance(dag, DAG)
        assert dag.dag_id == "noop_pipeline"
        assert dag.fileloc == __file__

    def test_set_dag_id(self):
        """Test that checks you can set dag_id from decorator."""

        @dag_decorator("test", default_args=self.DEFAULT_ARGS)
        def noop_pipeline(): ...

        dag = noop_pipeline()
        assert isinstance(dag, DAG)
        assert dag.dag_id == "test"

    def test_default_dag_id(self):
        """Test that @dag uses function name as default dag id."""

        @dag_decorator(default_args=self.DEFAULT_ARGS)
        def noop_pipeline(): ...

        dag = noop_pipeline()
        assert isinstance(dag, DAG)
        assert dag.dag_id == "noop_pipeline"

    @pytest.mark.parametrize(
        argnames=["dag_doc_md", "expected_doc_md"],
        argvalues=[
            pytest.param("dag docs.", "dag docs.", id="use_dag_doc_md"),
            pytest.param(None, "Regular DAG documentation", id="use_dag_docstring"),
        ],
    )
    def test_documentation_added(self, dag_doc_md, expected_doc_md):
        """Test that @dag uses function docs as doc_md for DAG object if doc_md is not explicitly set."""

        @dag_decorator(default_args=self.DEFAULT_ARGS, doc_md=dag_doc_md)
        def noop_pipeline():
            """Regular DAG documentation"""

        dag = noop_pipeline()
        assert isinstance(dag, DAG)
        assert dag.dag_id == "noop_pipeline"
        assert dag.doc_md == expected_doc_md

    def test_documentation_template_rendered(self):
        """Test that Jinja in the function docstring is rendered when it becomes the DAG's doc_md."""

        @dag_decorator(default_args=self.DEFAULT_ARGS)
        def noop_pipeline():
            """
            {% if True %}
               Regular DAG documentation
            {% endif %}
            """

        dag = noop_pipeline()
        assert isinstance(dag, DAG)
        assert dag.dag_id == "noop_pipeline"
        assert "Regular DAG documentation" in dag.doc_md

    def test_resolve_documentation_template_file_rendered(self, tmp_path):
        """Test that a doc_md naming a template file is resolved via template_searchpath and rendered."""

        path = tmp_path / "testfile.md"
        path.write_text(
            """
            {% if True %}
               External Markdown DAG documentation
            {% endif %}
            """
        )

        @dag_decorator(
            "test-dag", start_date=DEFAULT_DATE, template_searchpath=os.fspath(path.parent), doc_md=path.name
        )
        def markdown_docs(): ...

        dag = markdown_docs()
        assert isinstance(dag, DAG)
        assert dag.dag_id == "test-dag"
        assert dag.doc_md.strip() == "External Markdown DAG documentation"

    def test_fails_if_arg_not_set(self):
        """Test that @dag decorated function fails if positional argument is not set"""

        @dag_decorator(default_args=self.DEFAULT_ARGS)
        def noop_pipeline(value):
            @task_decorator
            def return_num(num):
                return num

            return_num(value)

        # Test that if arg is not passed it raises a type error as expected.
        with pytest.raises(TypeError):
            noop_pipeline()

    def test_dag_param_resolves(self):
        """Test that dag param is correctly resolved by operator"""

        @dag_decorator(default_args=self.DEFAULT_ARGS)
        def xcom_pass_to_op(value=self.VALUE):
            @task_decorator
            def return_num(num):
                return num

            xcom_arg = return_num(value)
            self.operator = xcom_arg.operator

        dag = xcom_pass_to_op()

        dr = dag.create_dagrun(
            run_id=DagRunType.MANUAL.value,
            start_date=timezone.utcnow(),
            execution_date=self.DEFAULT_DATE,
            data_interval=(self.DEFAULT_DATE, self.DEFAULT_DATE),
            state=State.RUNNING,
        )

        self.operator.run(start_date=self.DEFAULT_DATE, end_date=self.DEFAULT_DATE)
        ti = dr.get_task_instances()[0]
        assert ti.xcom_pull() == self.VALUE

    def test_dag_param_dagrun_parameterized(self):
        """Test that dag param is correctly overwritten when set in dag run"""

        @dag_decorator(default_args=self.DEFAULT_ARGS)
        def xcom_pass_to_op(value=self.VALUE):
            @task_decorator
            def return_num(num):
                return num

            # Inside the decorated function, arguments arrive as DagParam placeholders.
            assert isinstance(value, DagParam)

            xcom_arg = return_num(value)
            self.operator = xcom_arg.operator

        dag = xcom_pass_to_op()
        new_value = 52
        dr = dag.create_dagrun(
            run_id=DagRunType.MANUAL.value,
            start_date=timezone.utcnow(),
            execution_date=self.DEFAULT_DATE,
            data_interval=(self.DEFAULT_DATE, self.DEFAULT_DATE),
            state=State.RUNNING,
            conf={"value": new_value},
        )

        self.operator.run(start_date=self.DEFAULT_DATE, end_date=self.DEFAULT_DATE)
        ti = dr.get_task_instances()[0]
        assert ti.xcom_pull() == new_value

    @pytest.mark.parametrize("value", [VALUE, 0])
    def test_set_params_for_dag(self, value):
        """Test that dag param is correctly set when using dag decorator"""

        @dag_decorator(default_args=self.DEFAULT_ARGS)
        def xcom_pass_to_op(value=value):
            @task_decorator
            def return_num(num):
                return num

            xcom_arg = return_num(value)
            self.operator = xcom_arg.operator

        dag = xcom_pass_to_op()
        assert dag.params["value"] == value

    def test_warning_location(self):
        """A warning emitted while building the DAG is attributed to the line that calls the decorated function."""
        # NOTE: This only works as long as there is some warning we can emit from `DAG()`
        @dag_decorator(schedule_interval=None)
        def mydag(): ...

        # Named ``records`` rather than ``warnings`` so the module-level ``warnings`` import is not shadowed.
        with pytest.warns(RemovedInAirflow3Warning) as records:
            line = sys._getframe().f_lineno + 1
            mydag()

        w = records.pop(RemovedInAirflow3Warning)
        assert w.filename == __file__
        assert w.lineno == line
| |
| |
@pytest.mark.parametrize("timetable", [NullTimetable(), OnceTimetable()])
def test_dag_timetable_match_schedule_interval(timetable):
    """A DAG built from a timetable reports a schedule_interval consistent with it."""
    dag_under_test = DAG("my-dag", timetable=timetable, start_date=DEFAULT_DATE)
    consistent = dag_under_test._check_schedule_interval_matches_timetable()
    assert consistent
| |
| |
@pytest.mark.parametrize("schedule_interval", [None, "@once", "@daily", timedelta(days=1)])
def test_dag_schedule_interval_match_timetable(schedule_interval):
    """A DAG built from a schedule value gets a timetable consistent with that value."""
    built = DAG("my-dag", schedule=schedule_interval, start_date=DEFAULT_DATE)
    consistent = built._check_schedule_interval_matches_timetable()
    assert consistent
| |
| |
@pytest.mark.parametrize("schedule_interval", [None, "@daily", timedelta(days=1)])
def test_dag_schedule_interval_change_after_init(schedule_interval):
    """Overwriting schedule_interval after construction breaks consistency with the @once timetable."""
    built = DAG("my-dag", timetable=OnceTimetable(), start_date=DEFAULT_DATE)
    built.schedule_interval = schedule_interval
    consistent = built._check_schedule_interval_matches_timetable()
    assert not consistent
| |
| |
@pytest.mark.parametrize("timetable", [NullTimetable(), OnceTimetable()])
def test_dag_timetable_change_after_init(timetable):
    """Overwriting the timetable after construction breaks consistency with the default schedule."""
    built = DAG("my-dag")  # Default is timedelta(days=1).
    built.timetable = timetable
    consistent = built._check_schedule_interval_matches_timetable()
    assert not consistent
| |
| |
@pytest.mark.parametrize(
    "run_id, execution_date",
    [
        (None, datetime_tz(2020, 1, 1)),
        ("test-run-id", None),
    ],
)
def test_set_task_instance_state(run_id, execution_date, session, dag_maker):
    """Marking a TI successful clears its failed/upstream-failed downstream TIs and requeues the run."""

    start_date = datetime_tz(2020, 1, 1)
    with dag_maker("test_set_task_instance_state", start_date=start_date, session=session) as dag:
        task_1, task_2, task_3, task_4, task_5 = (EmptyOperator(task_id=f"task_{n}") for n in range(1, 6))

        task_1 >> [task_2, task_3, task_4, task_5]

    dagrun = dag_maker.create_dagrun(
        run_id=run_id,
        execution_date=execution_date,
        state=State.FAILED,
        run_type=DagRunType.SCHEDULED,
    )

    def fetch_ti(task):
        matching = session.query(TI).filter(
            TI.dag_id == dag.dag_id,
            TI.task_id == task.task_id,
            TI.run_id == dagrun.run_id,
        )
        return matching.one()

    for task, initial_state in (
        (task_1, State.FAILED),
        (task_2, State.SUCCESS),
        (task_3, State.UPSTREAM_FAILED),
        (task_4, State.FAILED),
        (task_5, State.SKIPPED),
    ):
        fetch_ti(task).state = initial_state

    session.flush()

    altered = dag.set_task_instance_state(
        task_id=task_1.task_id,
        run_id=run_id,
        execution_date=execution_date,
        state=State.SUCCESS,
        session=session,
    )

    # The target TI is now SUCCESS and comes back with its DagRun relationship already loaded.
    first_ti = fetch_ti(task_1)
    assert first_ti.state == State.SUCCESS
    assert isinstance(inspect(first_ti).attrs.dag_run.loaded_value, DagRun)
    # Already-successful downstream work is kept.
    assert fetch_ti(task_2).state == State.SUCCESS
    # FAILED / UPSTREAM_FAILED downstream TIs are cleared.
    for cleared in (task_3, task_4):
        assert fetch_ti(cleared).state == State.NONE
    # SKIPPED downstream TIs are left alone.
    assert fetch_ti(task_5).state == State.SKIPPED

    dagrun.refresh_from_db(session=session)
    assert dagrun.get_state() == State.QUEUED

    expected_keys = {("test_set_task_instance_state", "task_1", dagrun.run_id, 0, -1)}
    assert {tuple(t.key) for t in altered} == expected_keys
| |
| |
def test_set_task_instance_state_mapped(dag_maker, session):
    """Setting one mapped TI's state (future=True) leaves the sibling map indexes untouched."""
    task_id = "t1"

    with dag_maker(session=session) as dag:

        @dag.task
        def make_arg_lists():
            return [[1], [2], [{"a": "b"}]]

        def consumer(value):
            print(value)

        mapped = PythonOperator.partial(task_id=task_id, dag=dag, python_callable=consumer).expand(
            op_args=make_arg_lists()
        )

        mapped >> BaseOperator(task_id="downstream")

    first_run = dag_maker.create_dagrun(
        run_type=DagRunType.SCHEDULED,
        state=DagRunState.FAILED,
    )
    expand_mapped_task(mapped, first_run.run_id, "make_arg_lists", length=2, session=session)

    # set_state(future=True) only applies to scheduled runs
    second_run = dag_maker.create_dagrun(
        run_type=DagRunType.SCHEDULED,
        state=DagRunState.FAILED,
        execution_date=DEFAULT_DATE + datetime.timedelta(days=1),
    )
    expand_mapped_task(mapped, second_run.run_id, "make_arg_lists", length=2, session=session)

    session.query(TI).filter_by(dag_id=dag.dag_id).update({"state": TaskInstanceState.FAILED})

    ti_query = (
        session.query(TI.task_id, TI.map_index, TI.run_id, TI.state)
        .filter(TI.dag_id == dag.dag_id, TI.task_id.in_([task_id, "downstream"]))
        .order_by(TI.run_id, TI.task_id, TI.map_index)
    )

    def expected_rows(downstream_state, mapped_states):
        # Mirrors the query ordering: per run, "downstream" sorts before the mapped task id.
        rows = []
        for run in (first_run, second_run):
            rows.append(("downstream", -1, run.run_id, downstream_state))
            rows.extend((task_id, idx, run.run_id, st) for idx, st in enumerate(mapped_states))
        return rows

    failed = TaskInstanceState.FAILED
    success = TaskInstanceState.SUCCESS

    # Check pre-conditions
    assert ti_query.all() == expected_rows(failed, [failed, failed])

    dag.set_task_instance_state(
        task_id=task_id,
        map_indexes=[1],
        future=True,
        run_id=first_run.run_id,
        state=TaskInstanceState.SUCCESS,
        session=session,
    )
    assert first_run in session, "Check session is passed down all the way"

    # Only map index 1 flips (in both runs, because future=True); downstream is cleared.
    assert ti_query.all() == expected_rows(None, [failed, success])
| |
| |
@pytest.mark.parametrize("run_id, execution_date", [(None, datetime_tz(2020, 1, 1)), ("test-run-id", None)])
def test_set_task_group_state(run_id, execution_date, session, dag_maker):
    """Marking a task group successful updates its members, clears failed downstream TIs and requeues the run."""

    start_date = datetime_tz(2020, 1, 1)
    with dag_maker("test_set_task_group_state", start_date=start_date, session=session) as dag:
        start = EmptyOperator(task_id="start")

        with TaskGroup("section_1", tooltip="Tasks for section_1") as section_1:
            task_1, task_2, task_3 = (EmptyOperator(task_id=f"task_{n}") for n in (1, 2, 3))

            task_1 >> [task_2, task_3]

        downstream = [EmptyOperator(task_id=f"task_{n}") for n in range(4, 9)]
        task_4, task_5, task_6, task_7, _task_8 = downstream

        start >> section_1 >> downstream

    dagrun = dag_maker.create_dagrun(
        run_id=run_id,
        execution_date=execution_date,
        state=State.FAILED,
        run_type=DagRunType.SCHEDULED,
    )

    def fetch_ti(task):
        matching = session.query(TI).filter(
            TI.dag_id == dag.dag_id,
            TI.task_id == task.task_id,
            TI.run_id == dagrun.run_id,
        )
        return matching.one()

    for task, initial_state in (
        (task_1, State.FAILED),
        (task_2, State.SUCCESS),
        (task_3, State.UPSTREAM_FAILED),
        (task_4, State.SUCCESS),
        (task_5, State.UPSTREAM_FAILED),
        (task_6, State.FAILED),
        (task_7, State.SKIPPED),
    ):
        fetch_ti(task).state = initial_state

    session.flush()

    altered = dag.set_task_group_state(
        group_id=section_1.group_id,
        run_id=run_id,
        execution_date=execution_date,
        state=State.SUCCESS,
        session=session,
    )

    # Every group member ends up SUCCESS (task_1 and task_3 change; task_2 already was).
    for member in (task_1, task_2, task_3):
        assert fetch_ti(member).state == State.SUCCESS
    # Successful downstream work is kept.
    assert fetch_ti(task_4).state == State.SUCCESS
    # FAILED / UPSTREAM_FAILED downstream TIs are cleared.
    for cleared in (task_5, task_6):
        assert fetch_ti(cleared).state == State.NONE
    # SKIPPED downstream TIs are left alone.
    assert fetch_ti(task_7).state == State.SKIPPED

    dagrun.refresh_from_db(session=session)
    assert dagrun.get_state() == State.QUEUED

    expected_keys = {
        ("test_set_task_group_state", "section_1.task_1", dagrun.run_id, 0, -1),
        ("test_set_task_group_state", "section_1.task_3", dagrun.run_id, 0, -1),
    }
    assert {t.key for t in altered} == expected_keys
| |
| |
def test_dag_teardowns_property_lists_all_teardown_tasks(dag_maker):
    """``dag.teardowns`` lists every teardown task; the setup and work tasks are upstream of them."""

    @setup
    def setup_task():
        return 1

    @teardown
    def teardown_task():
        return 1

    @teardown
    def teardown_task2():
        return 1

    @teardown
    def teardown_task3():
        return 1

    @task_decorator
    def mytask():
        return 1

    with dag_maker() as dag:
        setup_ref = setup_task()
        first_td = teardown_task()
        second_td = teardown_task2()
        third_td = teardown_task3()
        # A multi-item ``with`` is equivalent to nesting each context in order.
        with setup_ref >> first_td, second_td, third_td:
            mytask()

    teardown_ids = {t.task_id for t in dag.teardowns}
    upstream_ids = {t.task_id for t in dag.tasks_upstream_of_teardowns}
    assert teardown_ids == {"teardown_task", "teardown_task2", "teardown_task3"}
    assert upstream_ids == {"setup_task", "mytask"}
| |
| |
@pytest.mark.parametrize(
    "start_date, expected_infos",
    [
        pytest.param(
            DEFAULT_DATE,
            [DagRunInfo.interval(DEFAULT_DATE, DEFAULT_DATE + datetime.timedelta(hours=1))],
            id="in-dag-restriction",
        ),
        pytest.param(
            DEFAULT_DATE - datetime.timedelta(hours=3),
            # Four consecutive hourly intervals: [D-3h, D-2h) ... [D, D+1h).
            [
                DagRunInfo.interval(
                    DEFAULT_DATE + datetime.timedelta(hours=offset),
                    DEFAULT_DATE + datetime.timedelta(hours=offset + 1),
                )
                for offset in range(-3, 1)
            ],
            id="out-of-dag-restriction",
        ),
    ],
)
def test_iter_dagrun_infos_between(start_date, expected_infos):
    """Aligned iteration yields every hourly interval in the window, even those before the DAG's start_date."""
    dag = DAG(dag_id="test_get_dates", start_date=DEFAULT_DATE, schedule="@hourly")
    EmptyOperator(task_id="dummy", dag=dag)

    produced = list(
        dag.iter_dagrun_infos_between(
            earliest=pendulum.instance(start_date),
            latest=pendulum.instance(DEFAULT_DATE),
            align=True,
        )
    )
    assert expected_infos == produced
| |
| |
def test_iter_dagrun_infos_between_error(caplog):
    """A timetable error stops iteration, keeps earlier results, and is logged with its traceback."""
    end = pendulum.instance(DEFAULT_DATE)
    start = pendulum.instance(DEFAULT_DATE - datetime.timedelta(hours=1))

    class FailingAfterOneTimetable(Timetable):
        def next_dagrun_info(self, last_automated_data_interval, restriction):
            # Succeed on the very first call only; every later call blows up.
            if last_automated_data_interval is not None:
                raise RuntimeError("this fails")
            return DagRunInfo.interval(start, end)

    dag = DAG(
        dag_id="test_iter_dagrun_infos_between_error",
        start_date=DEFAULT_DATE,
        timetable=FailingAfterOneTimetable(),
    )

    iterator = dag.iter_dagrun_infos_between(earliest=start, latest=end, align=True)
    with caplog.at_level(logging.ERROR):
        infos = list(iterator)

    # The second timetable.next_dagrun_info() call raises an exception, so only the first result is returned.
    assert infos == [DagRunInfo.interval(start, end)]

    expected_message = (
        f"Failed to fetch run info after data interval {DataInterval(start, end)} for DAG {dag.dag_id!r}"
    )
    assert caplog.record_tuples == [("airflow.models.dag.DAG", logging.ERROR, expected_message)]
    assert caplog.records[0].exc_info is not None, "should contain exception context"
| |
| |
@pytest.mark.parametrize(
    "logical_date, data_interval_start, data_interval_end, expected_data_interval",
    [
        pytest.param(None, None, None, None, id="no-next-run"),
        pytest.param(
            DEFAULT_DATE,
            DEFAULT_DATE,
            DEFAULT_DATE + timedelta(days=2),
            DataInterval(DEFAULT_DATE, DEFAULT_DATE + timedelta(days=2)),
            id="modern",
        ),
        pytest.param(
            DEFAULT_DATE,
            None,
            None,
            DataInterval(DEFAULT_DATE, DEFAULT_DATE + timedelta(days=1)),
            id="legacy",
        ),
    ],
)
def test_get_next_data_interval(
    logical_date,
    data_interval_start,
    data_interval_end,
    expected_data_interval,
):
    """Uses the stored next data interval when present; legacy rows fall back to the DAG's schedule."""
    model = DagModel(
        dag_id="test_get_next_data_interval",
        next_dagrun=logical_date,
        next_dagrun_data_interval_start=data_interval_start,
        next_dagrun_data_interval_end=data_interval_end,
    )
    dag = DAG(dag_id="test_get_next_data_interval", schedule="@daily", start_date=DEFAULT_DATE)

    computed = dag.get_next_data_interval(model)
    assert computed == expected_data_interval
| |
| |
@pytest.mark.parametrize(
    ("dag_date", "tasks_date", "restrict"),
    [
        [
            (DEFAULT_DATE, None),
            [
                (DEFAULT_DATE + timedelta(days=1), DEFAULT_DATE + timedelta(days=2)),
                (DEFAULT_DATE + timedelta(days=3), DEFAULT_DATE + timedelta(days=4)),
            ],
            TimeRestriction(DEFAULT_DATE, DEFAULT_DATE + timedelta(days=4), True),
        ],
        [
            (DEFAULT_DATE, None),
            [(DEFAULT_DATE, DEFAULT_DATE + timedelta(days=1)), (DEFAULT_DATE, None)],
            TimeRestriction(DEFAULT_DATE, None, True),
        ],
    ],
)
def test__time_restriction(dag_maker, dag_date, tasks_date, restrict):
    """The DAG's time restriction combines the DAG's and its tasks' start/end dates."""
    dag_start, dag_end = dag_date
    (first_start, first_end), (second_start, second_end) = tasks_date

    with dag_maker("test__time_restriction", start_date=dag_start, end_date=dag_end) as dag:
        EmptyOperator(task_id="do1", start_date=first_start, end_date=first_end)
        EmptyOperator(task_id="do2", start_date=second_start, end_date=second_end)

    assert dag._time_restriction == restrict
| |
| |
@pytest.mark.parametrize(
    "tags, should_pass",
    [
        pytest.param([], True, id="empty tags"),
        pytest.param(["a normal tag"], True, id="one tag"),
        pytest.param(["a normal tag", "another normal tag"], True, id="two tags"),
        pytest.param(["a" * 100], True, id="a tag that's of just length 100"),
        pytest.param(["a normal tag", "a" * 101], False, id="two tags and one of them is of length > 100"),
    ],
)
def test__tags_length(tags: list[str], should_pass: bool):
    """Tags of up to 100 characters are accepted; a longer tag raises AirflowException."""
    if not should_pass:
        with pytest.raises(AirflowException):
            DAG("test-dag", tags=tags)
        return

    DAG("test-dag", tags=tags)
| |
| |
@pytest.mark.need_serialized_dag
def test_get_dataset_triggered_next_run_info(dag_maker, clear_datasets):
    """Reports ready/total dataset counts per DAG, filtered to the requested dag_ids."""
    ds1, ds2, ds3 = Dataset(uri="ds1"), Dataset(uri="ds2"), Dataset(uri="ds3")

    schedules = (
        ("datasets-1", [ds2]),
        ("datasets-2", [ds1, ds2]),
        ("datasets-3", [ds1, ds2, ds3]),
    )
    created = []
    for dag_id, schedule in schedules:
        with dag_maker(dag_id=dag_id, schedule=schedule):
            pass
        created.append(dag_maker.dag)
    dag1, dag2, dag3 = created

    session = dag_maker.session
    ds1_id = session.query(DatasetModel.id).filter_by(uri=ds1.uri).scalar()
    # Queue an event on ds1 for the two DAGs that depend on it.
    session.bulk_save_objects(
        [DatasetDagRunQueue(dataset_id=ds1_id, target_dag_id=d.dag_id) for d in (dag2, dag3)]
    )
    session.flush()

    datasets = session.query(DatasetModel.uri).order_by(DatasetModel.id).all()

    info = get_dataset_triggered_next_run_info([dag1.dag_id], session=session)
    assert info[dag1.dag_id] == {"ready": 0, "total": 1, "uri": datasets[0].uri}

    # This time, check both dag2 and dag3 at the same time (tests filtering)
    info = get_dataset_triggered_next_run_info([dag2.dag_id, dag3.dag_id], session=session)
    assert info[dag2.dag_id] == {"ready": 1, "total": 2, "uri": ""}
    assert info[dag3.dag_id] == {"ready": 1, "total": 3, "uri": ""}
| |
| |
def test_dag_uses_timetable_for_run_id(session):
    """A run created by ``create_dagrun`` takes its run_id from the timetable's ``generate_run_id``."""

    class CustomRunIdTimetable(Timetable):
        def generate_run_id(self, *, run_type, logical_date, data_interval, **extra) -> str:
            return "abc"

    timetable = CustomRunIdTimetable()
    dag = DAG(dag_id="test", start_date=DEFAULT_DATE, schedule=timetable)
    interval = (DEFAULT_DATE, DEFAULT_DATE)

    created_run = dag.create_dagrun(
        run_type=DagRunType.MANUAL,
        state=DagRunState.QUEUED,
        execution_date=DEFAULT_DATE,
        data_interval=interval,
    )

    assert created_run.run_id == "abc"
| |
| |
@pytest.mark.parametrize(
    "run_id_type",
    [DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED, DagRunType.DATASET_TRIGGERED],
)
def test_create_dagrun_disallow_manual_to_use_automated_run_id(run_id_type: DagRunType) -> None:
    """A MANUAL run must be refused a run_id generated for an automated run type."""
    dag = DAG(dag_id="test", start_date=DEFAULT_DATE, schedule="@daily")
    reserved_run_id = run_id_type.generate_run_id(DEFAULT_DATE)
    expected_message = (
        f"A manual DAG run cannot use ID {reserved_run_id!r} since it is reserved for {run_id_type.value} runs"
    )

    with pytest.raises(ValueError) as exc_info:
        dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            run_id=reserved_run_id,
            execution_date=DEFAULT_DATE,
            data_interval=(DEFAULT_DATE, DEFAULT_DATE),
            state=DagRunState.QUEUED,
        )

    assert str(exc_info.value) == expected_message
| |
| |
class TestTaskClearingSetupTeardownBehavior:
    """
    Task clearing behavior is mainly controlled by dag.partial_subset.
    Here we verify, primarily with regard to setups and teardowns, the
    behavior of dag.partial_subset but also the supporting methods defined
    on AbstractOperator.
    """

    @staticmethod
    def make_tasks(dag, input_str):
        """
        Helper for building setup and teardown tasks for testing.

        Given an input such as 's1, w1, t1, tf1', returns setup task "s1", normal task "w1"
        (the w means *work*), teardown task "t1", and teardown task "tf1" where the f means
        on_failure_fail_dagrun has been set to true.

        A task id already present in ``dag.task_dict`` is reused instead of being recreated.
        The return value is a generator: tasks are only created as it is consumed
        (e.g. when unpacked inside the ``with DAG(...)`` block).
        """

        def teardown_task(task_id):
            return BaseOperator(task_id=task_id).as_teardown()

        def teardown_task_f(task_id):
            return BaseOperator(task_id=task_id).as_teardown(on_failure_fail_dagrun=True)

        def work_task(task_id):
            return BaseOperator(task_id=task_id)

        def setup_task(task_id):
            return BaseOperator(task_id=task_id).as_setup()

        def make_task(task_id):
            """
            Task factory helper.

            Will give a setup, teardown, work, or teardown-with-dagrun-failure task depending on input.
            """
            if task_id.startswith("s"):
                factory = setup_task
            elif task_id.startswith("w"):
                factory = work_task
            # "tf" must be checked before the more general "t" prefix
            elif task_id.startswith("tf"):
                factory = teardown_task_f
            elif task_id.startswith("t"):
                factory = teardown_task
            else:
                raise ValueError("unexpected")
            return dag.task_dict.get(task_id) or factory(task_id=task_id)

        return (make_task(x) for x in input_str.split(", "))

    # All three helpers pass the task id wrapped in a list: a bare string would be
    # interpreted by ``partial_subset(task_ids_or_regex=...)`` as a regex pattern.

    @staticmethod
    def cleared_downstream(task):
        """Helper to return tasks that would be cleared if **downstream** selected."""
        return set(
            task.dag.partial_subset(
                task_ids_or_regex=[task.task_id],
                include_downstream=True,
                include_upstream=False,
            ).tasks
        )

    @staticmethod
    def cleared_upstream(task):
        """Helper to return tasks that would be cleared if **upstream** selected."""
        return set(
            task.dag.partial_subset(
                task_ids_or_regex=[task.task_id],
                include_downstream=False,
                include_upstream=True,
            ).tasks
        )

    @staticmethod
    def cleared_neither(task):
        """Helper to return tasks that would be cleared if **neither** upstream nor downstream selected."""
        return set(
            task.dag.partial_subset(
                task_ids_or_regex=[task.task_id],
                include_downstream=False,
                include_upstream=False,
            ).tasks
        )

    def test_get_flat_relative_ids_with_setup(self):
        with DAG(dag_id="test_dag", start_date=pendulum.now()) as dag:
            s1, w1, w2, w3, w4, t1 = self.make_tasks(dag, "s1, w1, w2, w3, w4, t1")

            s1 >> w1 >> w2 >> w3

            # w1 is downstream of s1, and s1 has no teardown, so clearing w1 clears s1
            assert set(w1.get_upstreams_only_setups_and_teardowns()) == {s1}
            # same with w2 and w3
            assert set(w2.get_upstreams_only_setups_and_teardowns()) == {s1}
            assert set(w3.get_upstreams_only_setups_and_teardowns()) == {s1}
            # so if we clear w2, we should also get s1, and w3, but not w1
            assert self.cleared_downstream(w2) == {s1, w2, w3}

            w3 >> t1

            # now, w2 has a downstream teardown, but it's not connected directly to s1
            assert set(w2.get_upstreams_only_setups_and_teardowns()) == {s1}
            # so if we clear downstream then s1 will be cleared, and t1 will be cleared but only by virtue of
            # being downstream of w2 -- not as a result of being the teardown for s1, which it isn't
            assert self.cleared_downstream(w2) == {s1, w2, w3, t1}
            # and, another consequence of not linking s1 and t1 is that when we clear upstream, note that
            # t1 doesn't get cleared -- because it's not upstream and it's not linked to s1
            assert self.cleared_upstream(w2) == {s1, w1, w2}
            # note also that if we add a 4th work task after t1, it will still be "in scope" for s1
            t1 >> w4
            assert self.cleared_downstream(w4) == {s1, w4}

            s1 >> t1

            # now, we know that t1 is the teardown for s1, so now we know that s1 will be "torn down"
            # by the time w4 runs, so we now know that w4 no longer requires s1, so when we clear w4,
            # s1 will not also be cleared
            assert self.cleared_downstream(w4) == {w4}
            assert set(w1.get_upstreams_only_setups_and_teardowns()) == {s1, t1}
            assert self.cleared_downstream(w1) == {s1, w1, w2, w3, t1, w4}
            assert self.cleared_upstream(w1) == {s1, w1, t1}
            assert set(w2.get_upstreams_only_setups_and_teardowns()) == {s1, t1}
            assert set(w2.get_upstreams_follow_setups()) == {s1, w1, t1}
            assert self.cleared_downstream(w2) == {s1, w2, w3, t1, w4}
            assert self.cleared_upstream(w2) == {s1, w1, w2, t1}
            assert self.cleared_downstream(w3) == {s1, w3, t1, w4}
            assert self.cleared_upstream(w3) == {s1, w1, w2, w3, t1}

    def test_get_flat_relative_ids_with_setup_nested_ctx_mgr(self):
        """Let's test some gnarlier cases here"""
        with DAG(dag_id="test_dag", start_date=pendulum.now()) as dag:
            s1, t1, s2, t2 = self.make_tasks(dag, "s1, t1, s2, t2")
            with s1 >> t1:
                BaseOperator(task_id="w1")
                with s2 >> t2:
                    BaseOperator(task_id="w2")
                    BaseOperator(task_id="w3")
        # TODO: add assertions for the nested context-manager case; this test currently only builds the DAG

    def test_get_flat_relative_ids_with_setup_nested_no_ctx_mgr(self):
        """Let's test some gnarlier cases here"""
        with DAG(dag_id="test_dag", start_date=pendulum.now()) as dag:
            s1, t1, s2, t2, w1, w2, w3 = self.make_tasks(dag, "s1, t1, s2, t2, w1, w2, w3")
            s1 >> t1
            s1 >> w1 >> t1
            s1 >> s2
            s2 >> t2
            s2 >> w2 >> w3 >> t2

            assert w1.get_flat_relative_ids(upstream=True) == {"s1"}
            assert w1.get_flat_relative_ids(upstream=False) == {"t1"}
            assert self.cleared_downstream(w1) == {s1, w1, t1}
            assert self.cleared_upstream(w1) == {s1, w1, t1}
            assert w3.get_flat_relative_ids(upstream=True) == {"s1", "s2", "w2"}
            assert w3.get_flat_relative_ids(upstream=False) == {"t2"}
            assert t1 not in w2.get_flat_relatives(upstream=False)  # t1 not required by w2
            # t1 only included because s1 is upstream
            assert self.cleared_upstream(w2) == {s1, t1, s2, w2, t2}
            # t1 not included because t1 is not downstream
            assert self.cleared_downstream(w2) == {s2, w2, w3, t2}
            # t1 only included because s1 is upstream
            assert self.cleared_upstream(w3) == {s1, t1, s2, w2, w3, t2}
            # t1 not included because t1 is not downstream
            assert self.cleared_downstream(w3) == {s2, w3, t2}

    def test_get_flat_relative_ids_follows_teardowns(self):
        with DAG(dag_id="test_dag", start_date=pendulum.now()) as dag:
            s1, w1, w2, t1 = self.make_tasks(dag, "s1, w1, w2, t1")
            s1 >> w1 >> [w2, t1]
            s1 >> t1
            # w2, we infer, does not require s1, since t1 does not come after it
            assert set(w2.get_upstreams_only_setups_and_teardowns()) == set()
            # w1, however, *does* require s1, since t1 is downstream of it
            assert set(w1.get_upstreams_only_setups_and_teardowns()) == {s1, t1}
            # downstream is just downstream and includes teardowns
            assert self.cleared_downstream(w1) == {s1, w1, w2, t1}
            assert self.cleared_downstream(w2) == {w2}
            # and if there's a downstream setup, it will be included as well
            s2 = BaseOperator(task_id="s2", dag=dag).as_setup()
            t1 >> s2
            assert w1.get_flat_relative_ids(upstream=False) == {"t1", "w2", "s2"}
            assert self.cleared_downstream(w1) == {s1, w1, w2, t1, s2}

    def test_get_flat_relative_ids_two_tasks_diff_setup_teardowns(self):
        with DAG(dag_id="test_dag", start_date=pendulum.now()) as dag:
            s1, t1, s2, t2, w1, w2 = self.make_tasks(dag, "s1, t1, s2, t2, w1, w2")
            s1 >> w1 >> [w2, t1]
            s1 >> t1
            s2 >> t2
            s2 >> w2 >> t2

            assert set(w1.get_upstreams_only_setups_and_teardowns()) == {s1, t1}
            # s2 is included because w2 is included
            assert self.cleared_downstream(w1) == {s1, w1, t1, s2, w2, t2}
            assert self.cleared_neither(w1) == {s1, w1, t1}
            assert set(w2.get_upstreams_only_setups_and_teardowns()) == {s2, t2}
            assert self.cleared_downstream(w2) == {s2, w2, t2}

    def test_get_flat_relative_ids_one_task_multiple_setup_teardowns(self):
        with DAG(dag_id="test_dag", start_date=pendulum.now()) as dag:
            s1a, s1b, t1, s2, t2, s3, t3a, t3b, w1, w2 = self.make_tasks(
                dag, "s1a, s1b, t1, s2, t2, s3, t3a, t3b, w1, w2"
            )
            # teardown t1 has two setups, s1a and s1b
            [s1a, s1b] >> t1
            # work 1 requires s1a and s1b, both of which are torn down by t1
            [s1a, s1b] >> w1 >> [w2, t1]

            # work 2 requires s2, and s3. s2 is torn down by t2. s3 is torn down by two teardowns, t3a and t3b.
            s2 >> t2
            s2 >> w2 >> t2
            s3 >> w2 >> [t3a, t3b]
            s3 >> [t3a, t3b]
            assert set(w1.get_upstreams_only_setups_and_teardowns()) == {s1a, s1b, t1}
            # since w2 is downstream of w1, w2 gets cleared.
            # and since w2 gets cleared, we should also see s2 and s3 in here
            assert self.cleared_downstream(w1) == {s1a, s1b, w1, t1, s3, t3a, t3b, w2, s2, t2}
            assert set(w2.get_upstreams_only_setups_and_teardowns()) == {s2, t2, s3, t3a, t3b}
            assert self.cleared_downstream(w2) == {s2, s3, w2, t2, t3a, t3b}

    def test_get_flat_relative_ids_with_setup_and_groups(self):
        """This is a dag with a setup / teardown at dag level and two task groups that have
        their own setups / teardowns.

        When we do tg >> dag_teardown, teardowns should be excluded from tg leaves.
        """
        dag = DAG(dag_id="test_dag", start_date=pendulum.now())
        with dag:
            dag_setup = BaseOperator(task_id="dag_setup").as_setup()
            dag_teardown = BaseOperator(task_id="dag_teardown").as_teardown()
            dag_setup >> dag_teardown
            for group_name in ("g1", "g2"):
                with TaskGroup(group_name) as tg:
                    group_setup = BaseOperator(task_id="group_setup").as_setup()
                    w1 = BaseOperator(task_id="w1")
                    w2 = BaseOperator(task_id="w2")
                    w3 = BaseOperator(task_id="w3")
                    group_teardown = BaseOperator(task_id="group_teardown").as_teardown()
                    group_setup >> w1 >> w2 >> w3 >> group_teardown
                    group_setup >> group_teardown
                dag_setup >> tg >> dag_teardown
        g2_w2 = dag.task_dict["g2.w2"]
        g2_w3 = dag.task_dict["g2.w3"]
        g2_group_teardown = dag.task_dict["g2.group_teardown"]

        # the line `dag_setup >> tg >> dag_teardown` should be equivalent to
        # dag_setup >> group_setup; w3 >> dag_teardown
        # i.e. not group_teardown >> dag_teardown
        # this way the two teardowns can run in parallel
        # so first, check that dag_teardown not downstream of group 2 teardown
        # this means they can run in parallel
        assert "dag_teardown" not in g2_group_teardown.downstream_task_ids
        # and just document that g2 teardown is in effect a dag leaf
        assert g2_group_teardown.downstream_task_ids == set()
        # group 2 task w3 is in the scope of 2 teardowns -- the dag teardown and the group teardown
        # it is arrowed to both of them
        assert g2_w3.downstream_task_ids == {"g2.group_teardown", "dag_teardown"}
        # dag teardown should have 3 upstreams: the last work task in groups 1 and 2, and its setup
        assert dag_teardown.upstream_task_ids == {"g1.w3", "g2.w3", "dag_setup"}

        assert {x.task_id for x in g2_w2.get_upstreams_only_setups_and_teardowns()} == {
            "dag_setup",
            "dag_teardown",
            "g2.group_setup",
            "g2.group_teardown",
        }

        # clearing g2.w2 clears all setups and teardowns and g2.w2 and g2.w3
        # but not anything from g1
        assert {x.task_id for x in self.cleared_downstream(g2_w2)} == {
            "dag_setup",
            "dag_teardown",
            "g2.group_setup",
            "g2.group_teardown",
            "g2.w3",
            "g2.w2",
        }
        assert {x.task_id for x in self.cleared_upstream(g2_w2)} == {
            "dag_setup",
            "dag_teardown",
            "g2.group_setup",
            "g2.group_teardown",
            "g2.w1",
            "g2.w2",
        }

    def test_clear_upstream_not_your_setup(self):
        """
        When you have a work task that comes after a setup, then if you clear upstream
        the setup (and its teardown) will be cleared even though strictly speaking you don't
        "require" it since, depending on speed of execution, it might be torn down by t1
        before / while w2 runs. It just gets cleared by virtue of it being upstream, and
        that's what you requested. And its teardown gets cleared too. But w1 doesn't.
        """
        with DAG(dag_id="test_dag", start_date=pendulum.now()) as dag:
            s1, w1, w2, t1 = self.make_tasks(dag, "s1, w1, w2, t1")
            s1 >> w1 >> t1.as_teardown(setups=s1)
            s1 >> w2
            # w2 is downstream of s1, so when clearing upstream, it should clear s1 (since it
            # is upstream of w2) and t1 since it's the teardown for s1 even though not downstream of w1
            assert self.cleared_upstream(w2) == {s1, w2, t1}

    def test_clearing_teardown_no_clear_setup(self):
        with DAG(dag_id="test_dag", start_date=pendulum.now()) as dag:
            s1, w1, t1 = self.make_tasks(dag, "s1, w1, t1")
            s1 >> t1
            # clearing t1 does not clear s1
            assert self.cleared_downstream(t1) == {t1}
            s1 >> w1 >> t1
            # that isn't changed with the introduction of w1
            assert self.cleared_downstream(t1) == {t1}
            # though, of course, clearing w1 clears them all
            assert self.cleared_downstream(w1) == {s1, w1, t1}

    def test_clearing_setup_clears_teardown(self):
        with DAG(dag_id="test_dag", start_date=pendulum.now()) as dag:
            s1, w1, t1 = self.make_tasks(dag, "s1, w1, t1")
            s1 >> t1
            s1 >> w1 >> t1
            # clearing w1 clears all always
            assert self.cleared_upstream(w1) == {s1, w1, t1}
            assert self.cleared_downstream(w1) == {s1, w1, t1}
            assert self.cleared_neither(w1) == {s1, w1, t1}
            # clearing s1 clears t1 always
            assert self.cleared_upstream(s1) == {s1, t1}
            assert self.cleared_downstream(s1) == {s1, w1, t1}
            assert self.cleared_neither(s1) == {s1, t1}

    @pytest.mark.parametrize(
        "upstream, downstream, expected",
        [
            (False, False, {"my_teardown", "my_setup"}),
            (False, True, {"my_setup", "my_work", "my_teardown"}),
            (True, False, {"my_teardown", "my_setup"}),
            (True, True, {"my_setup", "my_work", "my_teardown"}),
        ],
    )
    def test_clearing_setup_clears_teardown_taskflow(self, upstream, downstream, expected):
        with DAG(dag_id="test_dag", start_date=pendulum.now()) as dag:

            @setup
            def my_setup(): ...

            @task_decorator
            def my_work(): ...

            @teardown
            def my_teardown(): ...

            s1 = my_setup()
            w1 = my_work()
            t1 = my_teardown()
            s1 >> w1 >> t1
            s1 >> t1
        assert {
            x.task_id
            for x in dag.partial_subset(
                "my_setup", include_upstream=upstream, include_downstream=downstream
            ).tasks
        } == expected

    def test_get_flat_relative_ids_two_tasks_diff_setup_teardowns_deeper(self):
        with DAG(dag_id="test_dag", start_date=pendulum.now()) as dag:
            s1, t1, s2, t2, w1, w2, s3, w3, t3 = self.make_tasks(dag, "s1, t1, s2, t2, w1, w2, s3, w3, t3")
            s1 >> w1 >> t1
            s1 >> t1
            w1 >> w2

            # with the below, s2 is not downstream of w1, but it's the setup for w2
            # so it should be cleared when w1 is cleared
            s2 >> w2 >> t2
            s2 >> t2

            assert set(w1.get_upstreams_only_setups_and_teardowns()) == {s1, t1}
            assert set(w2.get_upstreams_only_setups_and_teardowns()) == {s2, t2}
            assert self.cleared_downstream(w1) == {s1, w1, t1, s2, w2, t2}
            assert self.cleared_downstream(w2) == {s2, w2, t2}

            # now, what if s2 itself has a setup and teardown?
            s3 >> s2 >> t3
            s3 >> t3
            # note that s3 is excluded because it's assumed that a setup won't have a setup
            # so, we don't continue to recurse for setups after reaching the setups for
            # the downstream work tasks
            # but, t3 is included since it's a teardown for s2
            assert self.cleared_downstream(w1) == {s1, w1, t1, s2, w2, t2, t3}

    def test_clearing_behavior_multiple_setups_for_work_task(self):
        with DAG(dag_id="test_dag", start_date=pendulum.now()) as dag:
            s1, t1, s2, t2, w1, w2, s3, w3, t3 = self.make_tasks(dag, "s1, t1, s2, t2, w1, w2, s3, w3, t3")
            s1 >> t1
            s2 >> t2
            s3 >> t3
            s1 >> s2 >> s3 >> w1 >> w2 >> [t1, t2, t3]

            assert self.cleared_downstream(w1) == {s1, s2, s3, w1, w2, t1, t2, t3}
            assert self.cleared_downstream(w2) == {s1, s2, s3, w2, t1, t2, t3}
            assert self.cleared_downstream(s3) == {s1, s2, s3, w1, w2, t1, t2, t3}
            # even if we don't include upstream / downstream, setups and teardowns are cleared
            assert self.cleared_neither(w2) == {s3, t3, s2, t2, s1, t1, w2}
            assert self.cleared_neither(w1) == {s3, t3, s2, t2, s1, t1, w1}
            # but, a setup doesn't formally have a setup, so if we only clear s3, say then its upstream setups
            # are not also cleared
            assert self.cleared_neither(s3) == {s3, t3}
            assert self.cleared_neither(s2) == {s2, t2}

    def test_clearing_behavior_multiple_setups_for_work_task2(self):
        with DAG(dag_id="test_dag", start_date=pendulum.now()) as dag:
            s1, t1, s2, t2, w1, w2, s3, w3, t3 = self.make_tasks(dag, "s1, t1, s2, t2, w1, w2, s3, w3, t3")
            s1 >> t1
            s2 >> t2
            s3 >> t3
            [s1, s2, s3] >> w1 >> w2 >> [t1, t2, t3]

            assert self.cleared_downstream(w1) == {s1, s2, s3, w1, w2, t1, t2, t3}
            assert self.cleared_downstream(w2) == {s1, s2, s3, w2, t1, t2, t3}

    def test_clearing_behavior_more_tertiary_weirdness(self):
        with DAG(dag_id="test_dag", start_date=pendulum.now()) as dag:
            s1, t1, s2, t2, w1, w2, s3, t3 = self.make_tasks(dag, "s1, t1, s2, t2, w1, w2, s3, t3")
            s1 >> t1
            s2 >> t2
            s1 >> w1 >> s2 >> w2 >> [t1, t2]
            s2 >> w2 >> t2
            s3 >> s2 >> t3
            s3 >> t3

            def sort(task_list):
                return sorted(x.task_id for x in task_list)

            assert set(w1.get_upstreams_only_setups_and_teardowns()) == {s1, t1}
            # s2 is included because w2 is included
            assert self.cleared_downstream(w1) == {s1, w1, t1, s2, w2, t2, t3}
            assert self.cleared_downstream(w2) == {s1, t1, s2, w2, t2, t3}
            # t3 is included since s2 is included and s2 >> t3
            # but s3 not included because it's assumed that a setup doesn't have a setup
            assert self.cleared_neither(w2) == {s1, w2, t1, s2, t2, t3}

            # since we're clearing upstream, s3 is upstream of w2, so s3 and t3 are included
            # even though w2 doesn't require them
            # s2 and t2 are included for obvious reasons, namely that w2 requires s2
            # and s1 and t1 are included for the same reason
            # w1 included since it is upstream of w2
            assert sort(self.cleared_upstream(w2)) == sort({s1, t1, s2, t2, s3, t3, w1, w2})

            # t3 is included here since it's a teardown for s2
            assert set(w2.get_upstreams_only_setups_and_teardowns()) == {s2, t2, s1, t1, t3}

    def test_clearing_behavior_more_tertiary_weirdness2(self):
        with DAG(dag_id="test_dag", start_date=pendulum.now()) as dag:
            s1, t1, s2, t2, w1, w2, s3, t3 = self.make_tasks(dag, "s1, t1, s2, t2, w1, w2, s3, t3")
            s1 >> t1
            s2 >> t2
            s1 >> w1 >> t1
            s2 >> t1 >> t2

            def sort(task_list):
                return sorted(x.task_id for x in task_list)

            # t2 included since downstream, but s2 not included since it's not required by t2
            # and clearing teardown does not clear the setup
            assert self.cleared_downstream(w1) == {s1, w1, t1, t2}

            # even though t1 is cleared here, s2 and t2 are not "setup and teardown" for t1
            # so they are not included
            assert self.cleared_neither(w1) == {s1, w1, t1}
            assert self.cleared_upstream(w1) == {s1, w1, t1}

            # t1 does not have a setup or teardown
            # but t2 is downstream so it's included
            # and s2 is not included since clearing teardown does not clear the setup
            assert self.cleared_downstream(t1) == {t1, t2}
            # t1 does not have a setup or teardown
            assert self.cleared_neither(t1) == {t1}
            # s2 included since upstream, and t2 included since s2 included
            assert self.cleared_upstream(t1) == {s1, t1, s2, t2, w1}

    def test_clearing_behavior_just_teardown(self):
        with DAG(dag_id="test_dag", start_date=pendulum.now()) as dag:
            s1, t1 = self.make_tasks(dag, "s1, t1")
            s1 >> t1
            assert set(t1.get_upstreams_only_setups_and_teardowns()) == set()
            assert self.cleared_upstream(t1) == {s1, t1}
            assert self.cleared_downstream(t1) == {t1}
            assert self.cleared_neither(t1) == {t1}
            assert set(s1.get_upstreams_only_setups_and_teardowns()) == set()
            assert self.cleared_upstream(s1) == {s1, t1}
            assert self.cleared_downstream(s1) == {s1, t1}
            assert self.cleared_neither(s1) == {s1, t1}

    def test_validate_setup_teardown_trigger_rule(self):
        with DAG(
            dag_id="direct_setup_trigger_rule", start_date=pendulum.now(), schedule=None, catchup=False
        ) as dag:
            s1, w1 = self.make_tasks(dag, "s1, w1")
            s1 >> w1
            dag.validate_setup_teardown()
            w1.trigger_rule = TriggerRule.ONE_FAILED
            with pytest.raises(
                Exception, match="Setup tasks must be followed with trigger rule ALL_SUCCESS."
            ):
                dag.validate_setup_teardown()
| |
| |
def test_statement_latest_runs_one_dag():
    """Check the SQL built by ``DAG._get_latest_runs_stmt`` for a single DAG id.

    The statement is built and compiled inside ``warnings.catch_warnings()`` with
    SAWarning promoted to an error. Any SQLAlchemy warning raised while building or
    compiling it therefore fails the test. The filter only applies inside the ``with``
    block, so the work must stay indented under it.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("error", category=SAWarning)

        stmt = DAG._get_latest_runs_stmt(dags=["fake-dag"])
        compiled_stmt = str(stmt.compile())

    actual = [x.strip() for x in compiled_stmt.splitlines()]
    expected = [
        "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date, dag_run.data_interval_start, dag_run.data_interval_end",
        "FROM dag_run",
        "WHERE dag_run.dag_id = :dag_id_1 AND dag_run.execution_date = (SELECT max(dag_run.execution_date) AS max_execution_date",
        "FROM dag_run",
        "WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]))",
    ]
    assert actual == expected, compiled_stmt
| |
| |
def test_statement_latest_runs_many_dag():
    """Check the SQL built by ``DAG._get_latest_runs_stmt`` for several DAG ids.

    The statement is built and compiled inside ``warnings.catch_warnings()`` with
    SAWarning promoted to an error. Any SQLAlchemy warning raised while building or
    compiling it therefore fails the test. The filter only applies inside the ``with``
    block, so the work must stay indented under it.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("error", category=SAWarning)

        stmt = DAG._get_latest_runs_stmt(dags=["fake-dag-1", "fake-dag-2"])
        compiled_stmt = str(stmt.compile())

    actual = [x.strip() for x in compiled_stmt.splitlines()]
    expected = [
        "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date, dag_run.data_interval_start, dag_run.data_interval_end",
        "FROM dag_run, (SELECT dag_run.dag_id AS dag_id, max(dag_run.execution_date) AS max_execution_date",
        "FROM dag_run",
        "WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS anon_1",
        "WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.execution_date = anon_1.max_execution_date",
    ]
    assert actual == expected, compiled_stmt