#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import datetime
import io
import logging
import os
import pickle
import re
import sys
import weakref
from contextlib import redirect_stdout
from datetime import timedelta
from pathlib import Path
from tempfile import NamedTemporaryFile
from unittest import mock
from unittest.mock import patch
import jinja2
import pendulum
import pytest
import time_machine
from dateutil.relativedelta import relativedelta
from sqlalchemy import inspect
import airflow
from airflow import models, settings
from airflow.configuration import conf
from airflow.datasets import Dataset
from airflow.decorators import setup, task as task_decorator, teardown
from airflow.exceptions import (
AirflowException,
DuplicateTaskIdFound,
ParamValidationError,
RemovedInAirflow3Warning,
)
from airflow.models import DAG, DagModel, DagRun, DagTag, TaskFail, TaskInstance as TI
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DagOwnerAttributes, dag as dag_decorator, get_dataset_triggered_next_run_info
from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel, TaskOutletDatasetReference
from airflow.models.param import DagParam, Param, ParamsDict
from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
from airflow.operators.subdag import SubDagOperator
from airflow.security import permissions
from airflow.templates import NativeEnvironment, SandboxedEnvironment
from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
from airflow.timetables.simple import (
ContinuousTimetable,
DatasetTriggeredTimetable,
NullTimetable,
OnceTimetable,
)
from airflow.utils import timezone
from airflow.utils.file import list_py_file_paths
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.task_group import TaskGroup, TaskGroupContext
from airflow.utils.timezone import datetime as datetime_tz
from airflow.utils.types import DagRunType
from airflow.utils.weight_rule import WeightRule
from tests.models import DEFAULT_DATE
from tests.test_utils.asserts import assert_queries_count
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_dags, clear_db_datasets, clear_db_runs, clear_db_serialized_dags
from tests.test_utils.mapping import expand_mapped_task
from tests.test_utils.timetables import cron_timetable, delta_timetable
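# Note: airflow.utils.timezone.datetime (imported as datetime_tz) returns a
# timezone-aware datetime (UTC by default), so TEST_DATE below is directly
# comparable with the aware execution dates Airflow stores.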
TEST_DATE = datetime_tz(2015, 1, 2, 0, 0)
repo_root = Path(airflow.__file__).parent.parent
@pytest.fixture
def clear_dags():
clear_db_dags()
clear_db_serialized_dags()
yield
clear_db_dags()
clear_db_serialized_dags()
@pytest.fixture
def clear_datasets():
clear_db_datasets()
yield
clear_db_datasets()
class TestDag:
def setup_method(self) -> None:
clear_db_runs()
clear_db_dags()
clear_db_datasets()
self.patcher_dag_code = mock.patch("airflow.models.dag.DagCode.bulk_sync_to_db")
self.patcher_dag_code.start()
def teardown_method(self) -> None:
clear_db_runs()
clear_db_dags()
clear_db_datasets()
self.patcher_dag_code.stop()
@staticmethod
def _clean_up(dag_id: str):
with create_session() as session:
session.query(DagRun).filter(DagRun.dag_id == dag_id).delete(synchronize_session=False)
session.query(TI).filter(TI.dag_id == dag_id).delete(synchronize_session=False)
session.query(TaskFail).filter(TaskFail.dag_id == dag_id).delete(synchronize_session=False)
@staticmethod
def _occur_before(a, b, list_):
"""
Assert that a occurs before b in the list.
"""
a_index = -1
b_index = -1
for i, e in enumerate(list_):
if e.task_id == a:
a_index = i
if e.task_id == b:
b_index = i
return 0 <= a_index < b_index
def test_params_not_passed_is_empty_dict(self):
"""
Test that when 'params' is _not_ passed to a new Dag, that the params
attribute is set to an empty dictionary.
"""
dag = models.DAG("test-dag")
assert isinstance(dag.params, ParamsDict)
assert 0 == len(dag.params)
def test_params_passed_and_params_in_default_args_no_override(self):
"""
Test that when 'params' exists as a key passed to the default_args dict
in addition to params being passed explicitly as an argument to the
dag, that the 'params' key of the default_args dict is merged with the
dict of the params argument.
"""
params1 = {"parameter1": 1}
params2 = {"parameter2": 2}
dag = models.DAG("test-dag", default_args={"params": params1}, params=params2)
assert params1["parameter1"] == dag.params["parameter1"]
assert params2["parameter2"] == dag.params["parameter2"]
def test_not_none_schedule_with_non_default_params(self):
"""
Test if there is a DAG with not None schedule_interval and have some params that
don't have a default value raise a error while DAG parsing
"""
params = {"param1": Param(type="string")}
with pytest.raises(AirflowException):
models.DAG("dummy-dag", params=params)
def test_dag_invalid_default_view(self):
"""
Test invalid `default_view` of DAG initialization
"""
with pytest.raises(AirflowException, match="Invalid values of dag.default_view: only support"):
models.DAG(dag_id="test-invalid-default_view", default_view="airflow")
def test_dag_default_view_default_value(self):
"""
Test `default_view` default value of DAG initialization
"""
dag = models.DAG(dag_id="test-default_default_view")
assert conf.get("webserver", "dag_default_view").lower() == dag.default_view
def test_dag_invalid_orientation(self):
"""
Test invalid `orientation` of DAG initialization
"""
with pytest.raises(AirflowException, match="Invalid values of dag.orientation: only support"):
models.DAG(dag_id="test-invalid-orientation", orientation="airflow")
def test_dag_orientation_default_value(self):
"""
Test `orientation` default value of DAG initialization
"""
dag = models.DAG(dag_id="test-default_orientation")
assert conf.get("webserver", "dag_orientation") == dag.orientation
def test_dag_as_context_manager(self):
"""
Test DAG as a context manager.
When used as a context manager, Operators are automatically added to
        the DAG (unless they specify a different DAG).
"""
dag = DAG("dag", start_date=DEFAULT_DATE, default_args={"owner": "owner1"})
dag2 = DAG("dag2", start_date=DEFAULT_DATE, default_args={"owner": "owner2"})
with dag:
op1 = EmptyOperator(task_id="op1")
op2 = EmptyOperator(task_id="op2", dag=dag2)
assert op1.dag is dag
assert op1.owner == "owner1"
assert op2.dag is dag2
assert op2.owner == "owner2"
with dag2:
op3 = EmptyOperator(task_id="op3")
assert op3.dag is dag2
assert op3.owner == "owner2"
with dag:
with dag2:
op4 = EmptyOperator(task_id="op4")
op5 = EmptyOperator(task_id="op5")
assert op4.dag is dag2
assert op5.dag is dag
assert op4.owner == "owner2"
assert op5.owner == "owner1"
with DAG("creating_dag_in_cm", start_date=DEFAULT_DATE) as dag:
EmptyOperator(task_id="op6")
assert dag.dag_id == "creating_dag_in_cm"
assert dag.tasks[0].task_id == "op6"
with dag:
with dag:
op7 = EmptyOperator(task_id="op7")
op8 = EmptyOperator(task_id="op8")
op9 = EmptyOperator(task_id="op8")
op9.dag = dag2
assert op7.dag == dag
assert op8.dag == dag
assert op9.dag == dag2
def test_dag_topological_sort_include_subdag_tasks(self):
child_dag = DAG(
"parent_dag.child_dag",
schedule="@daily",
start_date=DEFAULT_DATE,
)
with child_dag:
EmptyOperator(task_id="a_child")
EmptyOperator(task_id="b_child")
parent_dag = DAG(
"parent_dag",
schedule="@daily",
start_date=DEFAULT_DATE,
)
# a_parent -> child_dag -> (a_child | b_child) -> b_parent
with parent_dag:
op1 = EmptyOperator(task_id="a_parent")
op2 = SubDagOperator(task_id="child_dag", subdag=child_dag)
op3 = EmptyOperator(task_id="b_parent")
op1 >> op2 >> op3
topological_list = parent_dag.topological_sort(include_subdag_tasks=True)
assert self._occur_before("a_parent", "child_dag", topological_list)
assert self._occur_before("child_dag", "a_child", topological_list)
assert self._occur_before("child_dag", "b_child", topological_list)
assert self._occur_before("a_child", "b_parent", topological_list)
assert self._occur_before("b_child", "b_parent", topological_list)
def test_dag_topological_sort_dag_without_tasks(self):
dag = DAG("dag", start_date=DEFAULT_DATE, default_args={"owner": "owner1"})
assert () == dag.topological_sort()
def test_dag_naive_start_date_string(self):
DAG("DAG", default_args={"start_date": "2019-06-01"})
def test_dag_naive_start_end_dates_strings(self):
DAG("DAG", default_args={"start_date": "2019-06-01", "end_date": "2019-06-05"})
def test_dag_start_date_propagates_to_end_date(self):
"""
Tests that a start_date string with a timezone and an end_date string without a timezone
are accepted and that the timezone from the start carries over the end
This test is a little indirect, it works by setting start and end equal except for the
timezone and then testing for equality after the DAG construction. They'll be equal
only if the same timezone was applied to both.
An explicit check the `tzinfo` attributes for both are the same is an extra check.
"""
dag = DAG(
"DAG", default_args={"start_date": "2019-06-05T00:00:00+05:00", "end_date": "2019-06-05T00:00:00"}
)
assert dag.default_args["start_date"] == dag.default_args["end_date"]
assert dag.default_args["start_date"].tzinfo == dag.default_args["end_date"].tzinfo
def test_dag_naive_default_args_start_date(self):
dag = DAG("DAG", default_args={"start_date": datetime.datetime(2018, 1, 1)})
assert dag.timezone == settings.TIMEZONE
dag = DAG("DAG", start_date=datetime.datetime(2018, 1, 1))
assert dag.timezone == settings.TIMEZONE
def test_dag_none_default_args_start_date(self):
"""
Tests if a start_date of None in default_args
works.
"""
dag = DAG("DAG", default_args={"start_date": None})
assert dag.timezone == settings.TIMEZONE
def test_dag_task_priority_weight_total(self):
width = 5
depth = 5
weight = 5
        pattern = re.compile(r"stage(\d+)\.(\d+)")
        # Fully connected parallel tasks, i.e. every task at each stage
        # depends on every task in the previous stage. The default weight
        # rule sums a task's downstream descendants.
with DAG("dag", start_date=DEFAULT_DATE, default_args={"owner": "owner1"}) as dag:
pipeline = [
[EmptyOperator(task_id=f"stage{i}.{j}", priority_weight=weight) for j in range(0, width)]
for i in range(0, depth)
]
for i, stage in enumerate(pipeline):
if i == 0:
continue
for current_task in stage:
for prev_task in pipeline[i - 1]:
current_task.set_upstream(prev_task)
for task in dag.task_dict.values():
match = pattern.match(task.task_id)
task_depth = int(match.group(1))
            # the sum of the weights of all downstream stages, plus this task itself
correct_weight = ((depth - (task_depth + 1)) * width + 1) * weight
calculated_weight = task.priority_weight_total
assert calculated_weight == correct_weight
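        # Worked example for the downstream rule, using this test's numbers
        # (width=5, depth=5, weight=5): a stage-0 task has 4 full stages of 5
        # tasks downstream, so ((5 - 1) * 5 + 1) * 5 == 105, while a stage-4
        # leaf keeps only its own weight: ((5 - 5) * 5 + 1) * 5 == 5.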
def test_dag_task_priority_weight_total_using_upstream(self):
# Same test as above except use 'upstream' for weight calculation
weight = 3
width = 5
depth = 5
        pattern = re.compile(r"stage(\d+)\.(\d+)")
with DAG("dag", start_date=DEFAULT_DATE, default_args={"owner": "owner1"}) as dag:
pipeline = [
[
EmptyOperator(
task_id=f"stage{i}.{j}",
priority_weight=weight,
weight_rule=WeightRule.UPSTREAM,
)
for j in range(0, width)
]
for i in range(0, depth)
]
for i, stage in enumerate(pipeline):
if i == 0:
continue
for current_task in stage:
for prev_task in pipeline[i - 1]:
current_task.set_upstream(prev_task)
for task in dag.task_dict.values():
match = pattern.match(task.task_id)
task_depth = int(match.group(1))
            # the sum of the weights of all upstream stages, plus this task itself
correct_weight = (task_depth * width + 1) * weight
calculated_weight = task.priority_weight_total
assert calculated_weight == correct_weight
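        # Worked example for the upstream rule (width=5, depth=5, weight=3):
        # a stage-0 task has no upstream tasks, so (0 * 5 + 1) * 3 == 3,
        # while a stage-4 task counts 20 upstream tasks plus itself:
        # (4 * 5 + 1) * 3 == 63.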
def test_dag_task_priority_weight_total_using_absolute(self):
# Same test as above except use 'absolute' for weight calculation
weight = 10
width = 5
depth = 5
with DAG("dag", start_date=DEFAULT_DATE, default_args={"owner": "owner1"}) as dag:
pipeline = [
[
EmptyOperator(
task_id=f"stage{i}.{j}",
priority_weight=weight,
weight_rule=WeightRule.ABSOLUTE,
)
for j in range(0, width)
]
for i in range(0, depth)
]
for i, stage in enumerate(pipeline):
if i == 0:
continue
for current_task in stage:
for prev_task in pipeline[i - 1]:
current_task.set_upstream(prev_task)
for task in dag.task_dict.values():
            # absolute rule: each task keeps only its own priority_weight
correct_weight = weight
calculated_weight = task.priority_weight_total
assert calculated_weight == correct_weight
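        # Under the absolute rule nothing is summed: each task's total is just
        # its own priority_weight. Per the Airflow docs this also spares the
        # scheduler from traversing the DAG when computing weights, which can
        # matter for very large DAGs.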
def test_dag_task_invalid_weight_rule(self):
        # Test that an invalid weight rule raises an exception
with DAG("dag", start_date=DEFAULT_DATE, default_args={"owner": "owner1"}):
with pytest.raises(AirflowException):
EmptyOperator(task_id="should_fail", weight_rule="no rule")
def test_get_num_task_instances(self):
test_dag_id = "test_get_num_task_instances_dag"
test_task_id = "task_1"
test_dag = DAG(dag_id=test_dag_id, start_date=DEFAULT_DATE)
test_task = EmptyOperator(task_id=test_task_id, dag=test_dag)
dr1 = test_dag.create_dagrun(state=None, run_id="test1", execution_date=DEFAULT_DATE)
dr2 = test_dag.create_dagrun(
state=None, run_id="test2", execution_date=DEFAULT_DATE + datetime.timedelta(days=1)
)
dr3 = test_dag.create_dagrun(
state=None, run_id="test3", execution_date=DEFAULT_DATE + datetime.timedelta(days=2)
)
dr4 = test_dag.create_dagrun(
state=None, run_id="test4", execution_date=DEFAULT_DATE + datetime.timedelta(days=3)
)
ti1 = TI(task=test_task, run_id=dr1.run_id)
ti1.state = None
ti2 = TI(task=test_task, run_id=dr2.run_id)
ti2.state = State.RUNNING
ti3 = TI(task=test_task, run_id=dr3.run_id)
ti3.state = State.QUEUED
ti4 = TI(task=test_task, run_id=dr4.run_id)
ti4.state = State.RUNNING
session = settings.Session()
session.merge(ti1)
session.merge(ti2)
session.merge(ti3)
session.merge(ti4)
session.commit()
assert 0 == DAG.get_num_task_instances(test_dag_id, task_ids=["fakename"], session=session)
assert 4 == DAG.get_num_task_instances(test_dag_id, task_ids=[test_task_id], session=session)
assert 4 == DAG.get_num_task_instances(
test_dag_id, task_ids=["fakename", test_task_id], session=session
)
assert 1 == DAG.get_num_task_instances(
test_dag_id, task_ids=[test_task_id], states=[None], session=session
)
assert 2 == DAG.get_num_task_instances(
test_dag_id, task_ids=[test_task_id], states=[State.RUNNING], session=session
)
assert 3 == DAG.get_num_task_instances(
test_dag_id, task_ids=[test_task_id], states=[None, State.RUNNING], session=session
)
assert 4 == DAG.get_num_task_instances(
test_dag_id, task_ids=[test_task_id], states=[None, State.QUEUED, State.RUNNING], session=session
)
session.close()
def test_get_task_instances_before(self):
BASE_DATE = timezone.datetime(2022, 7, 20, 20)
test_dag_id = "test_get_task_instances_before"
test_task_id = "the_task"
test_dag = DAG(dag_id=test_dag_id, start_date=BASE_DATE)
EmptyOperator(task_id=test_task_id, dag=test_dag)
session = settings.Session()
        def dag_run_before(delta_h=0, run_type=DagRunType.SCHEDULED):
            dagrun = test_dag.create_dagrun(
                state=State.SUCCESS, run_type=run_type, run_id=f"test_{delta_h}", session=session
)
dagrun.start_date = BASE_DATE + timedelta(hours=delta_h)
dagrun.execution_date = BASE_DATE + timedelta(hours=delta_h)
return dagrun
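        # Each helper call below creates a SUCCESS run whose start/execution
        # date is BASE_DATE (2022-07-20 20:00) shifted by delta_h hours; the
        # H19..H12 comments track the resulting hour of day.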
        dr1 = dag_run_before(delta_h=-1, run_type=DagRunType.MANUAL)  # H19
        dr2 = dag_run_before(delta_h=-2, run_type=DagRunType.MANUAL)  # H18
        dr3 = dag_run_before(delta_h=-3, run_type=DagRunType.MANUAL)  # H17
        dr4 = dag_run_before(delta_h=-4, run_type=DagRunType.MANUAL)  # H16
dr5 = dag_run_before(delta_h=-5) # H15
dr6 = dag_run_before(delta_h=-6) # H14
dr7 = dag_run_before(delta_h=-7) # H13
dr8 = dag_run_before(delta_h=-8) # H12
session.commit()
REF_DATE = BASE_DATE
assert set([dr.run_id for dr in [dr1]]) == set(
[
ti.run_id
for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=1, session=session)
]
)
assert set([dr.run_id for dr in [dr1, dr2, dr3]]) == set(
[
ti.run_id
for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=3, session=session)
]
)
assert set([dr.run_id for dr in [dr1, dr2, dr3, dr4, dr5]]) == set(
[
ti.run_id
for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=5, session=session)
]
)
assert set([dr.run_id for dr in [dr1, dr2, dr3, dr4, dr5, dr6, dr7]]) == set(
[
ti.run_id
for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=7, session=session)
]
)
assert set([dr.run_id for dr in [dr1, dr2, dr3, dr4, dr5, dr6, dr7, dr8]]) == set(
[
ti.run_id
for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=9, session=session)
]
)
assert set([dr.run_id for dr in [dr1, dr2, dr3, dr4, dr5, dr6, dr7, dr8]]) == set(
[
ti.run_id
for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=10, session=session)
]
) # stays constrained to available ones
REF_DATE = BASE_DATE + timedelta(hours=-3.5)
assert set([dr.run_id for dr in [dr4]]) == set(
[
ti.run_id
for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=1, session=session)
]
)
assert set([dr.run_id for dr in [dr4, dr5, dr6]]) == set(
[
ti.run_id
for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=3, session=session)
]
)
assert set([dr.run_id for dr in [dr4, dr5, dr6, dr7, dr8]]) == set(
[
ti.run_id
for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=5, session=session)
]
)
assert set([dr.run_id for dr in [dr4, dr5, dr6, dr7, dr8]]) == set(
[
ti.run_id
for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=6, session=session)
]
) # stays constrained to available ones
REF_DATE = BASE_DATE + timedelta(hours=-8)
assert set([dr.run_id for dr in [dr8]]) == set(
[
ti.run_id
for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=0, session=session)
]
)
assert set([dr.run_id for dr in [dr8]]) == set(
[
ti.run_id
for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=1, session=session)
]
)
assert set([dr.run_id for dr in [dr8]]) == set(
[
ti.run_id
for ti in test_dag.get_task_instances_before(base_date=REF_DATE, num=10, session=session)
]
)
session.close()
def test_user_defined_filters_macros(self):
def jinja_udf(name):
return f"Hello {name}"
dag = models.DAG(
"test-dag",
start_date=DEFAULT_DATE,
user_defined_filters={"hello": jinja_udf},
user_defined_macros={"foo": "bar"},
)
jinja_env = dag.get_template_env()
assert "hello" in jinja_env.filters
assert jinja_env.filters["hello"] == jinja_udf
assert jinja_env.globals["foo"] == "bar"
def test_set_jinja_env_additional_option(self):
dag = DAG("test-dag", jinja_environment_kwargs={"keep_trailing_newline": True, "cache_size": 50})
jinja_env = dag.get_template_env()
assert jinja_env.keep_trailing_newline is True
assert jinja_env.cache.capacity == 50
assert jinja_env.undefined is jinja2.StrictUndefined
def test_template_undefined(self):
dag = DAG("test-dag", template_undefined=jinja2.Undefined)
jinja_env = dag.get_template_env()
assert jinja_env.undefined is jinja2.Undefined
@pytest.mark.parametrize(
"use_native_obj, force_sandboxed, expected_env",
[
(False, True, SandboxedEnvironment),
(False, False, SandboxedEnvironment),
(True, False, NativeEnvironment),
(True, True, SandboxedEnvironment),
],
)
def test_template_env(self, use_native_obj, force_sandboxed, expected_env):
dag = DAG("test-dag", render_template_as_native_obj=use_native_obj)
jinja_env = dag.get_template_env(force_sandboxed=force_sandboxed)
assert isinstance(jinja_env, expected_env)
def test_resolve_template_files_value(self):
with NamedTemporaryFile(suffix=".template") as f:
f.write(b"{{ ds }}")
f.flush()
template_dir = os.path.dirname(f.name)
template_file = os.path.basename(f.name)
with DAG("test-dag", start_date=DEFAULT_DATE, template_searchpath=template_dir):
task = EmptyOperator(task_id="op1")
task.test_field = template_file
task.template_fields = ("test_field",)
task.template_ext = (".template",)
task.resolve_template_files()
assert task.test_field == "{{ ds }}"
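        # Note: resolve_template_files only swaps in the file *contents* here.
        # Because the field value ends with an extension in template_ext, the
        # file is read verbatim; "{{ ds }}" itself is rendered later, at task
        # render time.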
def test_resolve_template_files_list(self):
with NamedTemporaryFile(suffix=".template") as f:
f.write(b"{{ ds }}")
f.flush()
template_dir = os.path.dirname(f.name)
template_file = os.path.basename(f.name)
with DAG("test-dag", start_date=DEFAULT_DATE, template_searchpath=template_dir):
task = EmptyOperator(task_id="op1")
task.test_field = [template_file, "some_string"]
task.template_fields = ("test_field",)
task.template_ext = (".template",)
task.resolve_template_files()
assert task.test_field == ["{{ ds }}", "some_string"]
def test_following_previous_schedule(self):
"""
Make sure DST transitions are properly observed
"""
local_tz = pendulum.timezone("Europe/Zurich")
start = local_tz.convert(datetime.datetime(2018, 10, 28, 2, 55), dst_rule=pendulum.PRE_TRANSITION)
assert start.isoformat() == "2018-10-28T02:55:00+02:00", "Pre-condition: start date is in DST"
utc = timezone.convert_to_utc(start)
assert utc.isoformat() == "2018-10-28T00:55:00+00:00", "Pre-condition: correct DST->UTC conversion"
dag = DAG("tz_dag", start_date=start, schedule="*/5 * * * *")
_next = dag.following_schedule(utc)
next_local = local_tz.convert(_next)
assert _next.isoformat() == "2018-10-28T01:00:00+00:00"
assert next_local.isoformat() == "2018-10-28T02:00:00+01:00"
prev = dag.previous_schedule(utc)
prev_local = local_tz.convert(prev)
assert prev_local.isoformat() == "2018-10-28T02:50:00+02:00"
prev = dag.previous_schedule(_next)
prev_local = local_tz.convert(prev)
assert prev_local.isoformat() == "2018-10-28T02:55:00+02:00"
assert prev == utc
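        # Background for the assertions above: on 2018-10-28 Europe/Zurich
        # falls back from 03:00 CEST (+02:00) to 02:00 CET (+01:00), so the
        # local 02:00-02:59 hour occurs twice; the test pins down which of the
        # two instants each schedule boundary resolves to.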
def test_following_previous_schedule_daily_dag_cest_to_cet(self):
"""
Make sure DST transitions are properly observed
"""
local_tz = pendulum.timezone("Europe/Zurich")
start = local_tz.convert(datetime.datetime(2018, 10, 27, 3), dst_rule=pendulum.PRE_TRANSITION)
utc = timezone.convert_to_utc(start)
dag = DAG("tz_dag", start_date=start, schedule="0 3 * * *")
prev = dag.previous_schedule(utc)
prev_local = local_tz.convert(prev)
assert prev_local.isoformat() == "2018-10-26T03:00:00+02:00"
assert prev.isoformat() == "2018-10-26T01:00:00+00:00"
_next = dag.following_schedule(utc)
next_local = local_tz.convert(_next)
assert next_local.isoformat() == "2018-10-28T03:00:00+01:00"
assert _next.isoformat() == "2018-10-28T02:00:00+00:00"
prev = dag.previous_schedule(_next)
prev_local = local_tz.convert(prev)
assert prev_local.isoformat() == "2018-10-27T03:00:00+02:00"
assert prev.isoformat() == "2018-10-27T01:00:00+00:00"
def test_following_previous_schedule_daily_dag_cet_to_cest(self):
"""
Make sure DST transitions are properly observed
"""
local_tz = pendulum.timezone("Europe/Zurich")
start = local_tz.convert(datetime.datetime(2018, 3, 25, 2), dst_rule=pendulum.PRE_TRANSITION)
utc = timezone.convert_to_utc(start)
dag = DAG("tz_dag", start_date=start, schedule="0 3 * * *")
prev = dag.previous_schedule(utc)
prev_local = local_tz.convert(prev)
assert prev_local.isoformat() == "2018-03-24T03:00:00+01:00"
assert prev.isoformat() == "2018-03-24T02:00:00+00:00"
_next = dag.following_schedule(utc)
next_local = local_tz.convert(_next)
assert next_local.isoformat() == "2018-03-25T03:00:00+02:00"
assert _next.isoformat() == "2018-03-25T01:00:00+00:00"
prev = dag.previous_schedule(_next)
prev_local = local_tz.convert(prev)
assert prev_local.isoformat() == "2018-03-24T03:00:00+01:00"
assert prev.isoformat() == "2018-03-24T02:00:00+00:00"
def test_following_schedule_relativedelta(self):
"""
Tests following_schedule a dag with a relativedelta schedule
"""
dag_id = "test_schedule_dag_relativedelta"
delta = relativedelta(hours=+1)
dag = DAG(dag_id=dag_id, schedule=delta)
dag.add_task(BaseOperator(task_id="faketastic", owner="Also fake", start_date=TEST_DATE))
_next = dag.following_schedule(TEST_DATE)
assert _next.isoformat() == "2015-01-02T01:00:00+00:00"
_next = dag.following_schedule(_next)
assert _next.isoformat() == "2015-01-02T02:00:00+00:00"
def test_following_schedule_relativedelta_with_deprecated_schedule_interval(self):
"""
Tests following_schedule a dag with a relativedelta schedule_interval
"""
dag_id = "test_schedule_dag_relativedelta"
delta = relativedelta(hours=+1)
dag = DAG(dag_id=dag_id, schedule_interval=delta)
dag.add_task(BaseOperator(task_id="faketastic", owner="Also fake", start_date=TEST_DATE))
_next = dag.following_schedule(TEST_DATE)
assert _next.isoformat() == "2015-01-02T01:00:00+00:00"
_next = dag.following_schedule(_next)
assert _next.isoformat() == "2015-01-02T02:00:00+00:00"
def test_following_schedule_relativedelta_with_depr_schedule_interval_decorated_dag(self):
"""
Tests following_schedule a dag with a relativedelta schedule_interval
using decorated dag
"""
from airflow.decorators import dag
dag_id = "test_schedule_dag_relativedelta"
delta = relativedelta(hours=+1)
@dag(dag_id=dag_id, schedule_interval=delta)
def mydag():
BaseOperator(task_id="faketastic", owner="Also fake", start_date=TEST_DATE)
_dag = mydag()
_next = _dag.following_schedule(TEST_DATE)
assert _next.isoformat() == "2015-01-02T01:00:00+00:00"
_next = _dag.following_schedule(_next)
assert _next.isoformat() == "2015-01-02T02:00:00+00:00"
def test_previous_schedule_datetime_timezone(self):
# Check that we don't get an AttributeError 'name' for self.timezone
start = datetime.datetime(2018, 3, 25, 2, tzinfo=datetime.timezone.utc)
dag = DAG("tz_dag", start_date=start, schedule="@hourly")
when = dag.previous_schedule(start)
assert when.isoformat() == "2018-03-25T01:00:00+00:00"
def test_following_schedule_datetime_timezone(self):
# Check that we don't get an AttributeError 'name' for self.timezone
start = datetime.datetime(2018, 3, 25, 2, tzinfo=datetime.timezone.utc)
dag = DAG("tz_dag", start_date=start, schedule="@hourly")
when = dag.following_schedule(start)
assert when.isoformat() == "2018-03-25T03:00:00+00:00"
def test_following_schedule_datetime_timezone_utc0530(self):
# Check that we don't get an AttributeError 'name' for self.timezone
class UTC0530(datetime.tzinfo):
"""tzinfo derived concrete class named "+0530" with offset of 19800"""
# can be configured here
_offset = datetime.timedelta(seconds=19800)
_dst = datetime.timedelta(0)
_name = "+0530"
def utcoffset(self, dt):
return self.__class__._offset
def dst(self, dt):
return self.__class__._dst
def tzname(self, dt):
return self.__class__._name
start = datetime.datetime(2018, 3, 25, 10, tzinfo=UTC0530())
dag = DAG("tz_dag", start_date=start, schedule="@hourly")
when = dag.following_schedule(start)
assert when.isoformat() == "2018-03-25T05:30:00+00:00"
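        # Sanity math: start is 10:00 at a fixed +05:30 offset, i.e. 04:30
        # UTC. The next top-of-hour in the DAG's +05:30 timezone is 11:00
        # local, which converts back to 05:30 UTC, matching the assertion.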
def test_dagtag_repr(self):
clear_db_dags()
dag = DAG("dag-test-dagtag", start_date=DEFAULT_DATE, tags=["tag-1", "tag-2"])
dag.sync_to_db()
with create_session() as session:
assert {"tag-1", "tag-2"} == {
repr(t) for t in session.query(DagTag).filter(DagTag.dag_id == "dag-test-dagtag").all()
}
def test_bulk_write_to_db(self):
clear_db_dags()
dags = [DAG(f"dag-bulk-sync-{i}", start_date=DEFAULT_DATE, tags=["test-dag"]) for i in range(0, 4)]
with assert_queries_count(5):
DAG.bulk_write_to_db(dags)
with create_session() as session:
assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
row[0] for row in session.query(DagModel.dag_id).all()
}
assert {
("dag-bulk-sync-0", "test-dag"),
("dag-bulk-sync-1", "test-dag"),
("dag-bulk-sync-2", "test-dag"),
("dag-bulk-sync-3", "test-dag"),
} == set(session.query(DagTag.dag_id, DagTag.name).all())
for row in session.query(DagModel.last_parsed_time).all():
assert row[0] is not None
        # Re-syncing unchanged dags should hit a stable query count
with assert_queries_count(8):
DAG.bulk_write_to_db(dags)
with assert_queries_count(8):
DAG.bulk_write_to_db(dags)
# Adding tags
for dag in dags:
dag.tags.append("test-dag2")
with assert_queries_count(9):
DAG.bulk_write_to_db(dags)
with create_session() as session:
assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
row[0] for row in session.query(DagModel.dag_id).all()
}
assert {
("dag-bulk-sync-0", "test-dag"),
("dag-bulk-sync-0", "test-dag2"),
("dag-bulk-sync-1", "test-dag"),
("dag-bulk-sync-1", "test-dag2"),
("dag-bulk-sync-2", "test-dag"),
("dag-bulk-sync-2", "test-dag2"),
("dag-bulk-sync-3", "test-dag"),
("dag-bulk-sync-3", "test-dag2"),
} == set(session.query(DagTag.dag_id, DagTag.name).all())
# Removing tags
for dag in dags:
dag.tags.remove("test-dag")
with assert_queries_count(9):
DAG.bulk_write_to_db(dags)
with create_session() as session:
assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
row[0] for row in session.query(DagModel.dag_id).all()
}
assert {
("dag-bulk-sync-0", "test-dag2"),
("dag-bulk-sync-1", "test-dag2"),
("dag-bulk-sync-2", "test-dag2"),
("dag-bulk-sync-3", "test-dag2"),
} == set(session.query(DagTag.dag_id, DagTag.name).all())
for row in session.query(DagModel.last_parsed_time).all():
assert row[0] is not None
# Removing all tags
for dag in dags:
dag.tags = None
with assert_queries_count(9):
DAG.bulk_write_to_db(dags)
with create_session() as session:
assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
row[0] for row in session.query(DagModel.dag_id).all()
}
assert not set(session.query(DagTag.dag_id, DagTag.name).all())
for row in session.query(DagModel.last_parsed_time).all():
assert row[0] is not None
@pytest.mark.parametrize("state", [DagRunState.RUNNING, DagRunState.QUEUED])
def test_bulk_write_to_db_max_active_runs(self, state):
"""
Test that DagModel.next_dagrun_create_after is set to NULL when the dag cannot be created due to max
active runs being hit.
"""
dag = DAG(dag_id="test_scheduler_verify_max_active_runs", start_date=DEFAULT_DATE)
dag.max_active_runs = 1
EmptyOperator(task_id="dummy", dag=dag, owner="airflow")
session = settings.Session()
dag.clear()
DAG.bulk_write_to_db([dag], session=session)
model = session.get(DagModel, dag.dag_id)
assert model.next_dagrun == DEFAULT_DATE
assert model.next_dagrun_create_after == DEFAULT_DATE + timedelta(days=1)
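        # With no explicit schedule the DAG defaults to a daily interval, so
        # the first data interval spans DEFAULT_DATE to DEFAULT_DATE + 1 day
        # and the run only becomes creatable once that interval has passed.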
dr = dag.create_dagrun(
state=state,
execution_date=model.next_dagrun,
run_type=DagRunType.SCHEDULED,
session=session,
)
assert dr is not None
DAG.bulk_write_to_db([dag])
model = session.get(DagModel, dag.dag_id)
# We signal "at max active runs" by saying this run is never eligible to be created
assert model.next_dagrun_create_after is None
# test that bulk_write_to_db again doesn't update next_dagrun_create_after
DAG.bulk_write_to_db([dag])
model = session.get(DagModel, dag.dag_id)
assert model.next_dagrun_create_after is None
def test_bulk_write_to_db_has_import_error(self):
"""
Test that DagModel.has_import_error is set to false if no import errors.
"""
dag = DAG(dag_id="test_has_import_error", start_date=DEFAULT_DATE)
EmptyOperator(task_id="dummy", dag=dag, owner="airflow")
session = settings.Session()
dag.clear()
DAG.bulk_write_to_db([dag], session=session)
model = session.get(DagModel, dag.dag_id)
assert not model.has_import_errors
        # Simulate the DagFileProcessor setting the import-error flag to True
model.has_import_errors = True
session.merge(model)
session.flush()
model = session.get(DagModel, dag.dag_id)
        # the flag should persist until the dag is parsed again
        assert model.has_import_errors
        # re-parse (bulk write) the dag
        DAG.bulk_write_to_db([dag])
        model = session.get(DagModel, dag.dag_id)
        # a clean parse resets has_import_errors to False
        assert not model.has_import_errors
session.close()
def test_bulk_write_to_db_datasets(self):
"""
Ensure that datasets referenced in a dag are correctly loaded into the database.
"""
dag_id1 = "test_dataset_dag1"
dag_id2 = "test_dataset_dag2"
task_id = "test_dataset_task"
uri1 = "s3://dataset1"
d1 = Dataset(uri1, extra={"not": "used"})
d2 = Dataset("s3://dataset2")
d3 = Dataset("s3://dataset3")
dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[d1])
EmptyOperator(task_id=task_id, dag=dag1, outlets=[d2, d3])
dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE)
EmptyOperator(task_id=task_id, dag=dag2, outlets=[Dataset(uri1, extra={"should": "be used"})])
session = settings.Session()
dag1.clear()
DAG.bulk_write_to_db([dag1, dag2], session=session)
session.commit()
stored_datasets = {x.uri: x for x in session.query(DatasetModel).all()}
d1_orm = stored_datasets[d1.uri]
d2_orm = stored_datasets[d2.uri]
d3_orm = stored_datasets[d3.uri]
assert stored_datasets[uri1].extra == {"should": "be used"}
assert [x.dag_id for x in d1_orm.consuming_dags] == [dag_id1]
assert [(x.task_id, x.dag_id) for x in d1_orm.producing_tasks] == [(task_id, dag_id2)]
assert set(
session.query(
TaskOutletDatasetReference.task_id,
TaskOutletDatasetReference.dag_id,
TaskOutletDatasetReference.dataset_id,
)
.filter(TaskOutletDatasetReference.dag_id.in_((dag_id1, dag_id2)))
.all()
) == {
(task_id, dag_id1, d2_orm.id),
(task_id, dag_id1, d3_orm.id),
(task_id, dag_id2, d1_orm.id),
}
        # Now that we have verified that a new dag has its dataset references
        # recorded properly, verify that *changes* are recorded too: any
        # references that are removed should also be deleted from the DB.
dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=None)
EmptyOperator(task_id=task_id, dag=dag1, outlets=[d2])
dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE)
EmptyOperator(task_id=task_id, dag=dag2)
DAG.bulk_write_to_db([dag1, dag2], session=session)
session.commit()
session.expunge_all()
stored_datasets = {x.uri: x for x in session.query(DatasetModel).all()}
d1_orm = stored_datasets[d1.uri]
d2_orm = stored_datasets[d2.uri]
assert [x.dag_id for x in d1_orm.consuming_dags] == []
assert set(
session.query(
TaskOutletDatasetReference.task_id,
TaskOutletDatasetReference.dag_id,
TaskOutletDatasetReference.dataset_id,
)
.filter(TaskOutletDatasetReference.dag_id.in_((dag_id1, dag_id2)))
.all()
) == {(task_id, dag_id1, d2_orm.id)}
def test_bulk_write_to_db_unorphan_datasets(self):
"""
Datasets can lose their last reference and be orphaned, but then if a reference to them reappears, we
need to un-orphan those datasets
"""
with create_session() as session:
# Create four datasets - two that have references and two that are unreferenced and marked as
# orphans
dataset1 = Dataset(uri="ds1")
dataset2 = Dataset(uri="ds2")
session.add(DatasetModel(uri=dataset2.uri, is_orphaned=True))
dataset3 = Dataset(uri="ds3")
dataset4 = Dataset(uri="ds4")
session.add(DatasetModel(uri=dataset4.uri, is_orphaned=True))
session.flush()
dag1 = DAG(dag_id="datasets-1", start_date=DEFAULT_DATE, schedule=[dataset1])
BashOperator(dag=dag1, task_id="task", bash_command="echo 1", outlets=[dataset3])
DAG.bulk_write_to_db([dag1], session=session)
            # Double-check which datasets are currently marked as orphans
non_orphaned_datasets = [
dataset.uri
for dataset in session.query(DatasetModel.uri)
.filter(~DatasetModel.is_orphaned)
.order_by(DatasetModel.uri)
]
assert non_orphaned_datasets == ["ds1", "ds3"]
orphaned_datasets = [
dataset.uri
for dataset in session.query(DatasetModel.uri)
.filter(DatasetModel.is_orphaned)
.order_by(DatasetModel.uri)
]
assert orphaned_datasets == ["ds2", "ds4"]
# Now add references to the two unreferenced datasets
dag1 = DAG(dag_id="datasets-1", start_date=DEFAULT_DATE, schedule=[dataset1, dataset2])
BashOperator(dag=dag1, task_id="task", bash_command="echo 1", outlets=[dataset3, dataset4])
DAG.bulk_write_to_db([dag1], session=session)
# and count the orphans and non-orphans
non_orphaned_dataset_count = session.query(DatasetModel).filter(~DatasetModel.is_orphaned).count()
assert non_orphaned_dataset_count == 4
orphaned_dataset_count = session.query(DatasetModel).filter(DatasetModel.is_orphaned).count()
assert orphaned_dataset_count == 0
def test_sync_to_db(self):
dag = DAG(
"dag",
start_date=DEFAULT_DATE,
)
with dag:
EmptyOperator(task_id="task", owner="owner1")
subdag = DAG(
"dag.subtask",
start_date=DEFAULT_DATE,
)
            # parent_dag and is_subdag are normally set by DagBag. We don't use DagBag here, so set them manually.
subdag.parent_dag = dag
SubDagOperator(task_id="subtask", owner="owner2", subdag=subdag)
session = settings.Session()
dag.sync_to_db(session=session)
orm_dag = session.query(DagModel).filter(DagModel.dag_id == "dag").one()
assert set(orm_dag.owners.split(", ")) == {"owner1", "owner2"}
assert orm_dag.is_active
assert orm_dag.default_view is not None
assert orm_dag.default_view == conf.get("webserver", "dag_default_view").lower()
assert orm_dag.safe_dag_id == "dag"
orm_subdag = session.query(DagModel).filter(DagModel.dag_id == "dag.subtask").one()
assert set(orm_subdag.owners.split(", ")) == {"owner1", "owner2"}
assert orm_subdag.is_active
assert orm_subdag.safe_dag_id == "dag__dot__subtask"
assert orm_subdag.fileloc == orm_dag.fileloc
session.close()
def test_sync_to_db_default_view(self):
dag = DAG(
"dag",
start_date=DEFAULT_DATE,
default_view="graph",
)
with dag:
EmptyOperator(task_id="task", owner="owner1")
SubDagOperator(
task_id="subtask",
owner="owner2",
subdag=DAG(
"dag.subtask",
start_date=DEFAULT_DATE,
),
)
session = settings.Session()
dag.sync_to_db(session=session)
orm_dag = session.query(DagModel).filter(DagModel.dag_id == "dag").one()
assert orm_dag.default_view is not None
assert orm_dag.default_view == "graph"
session.close()
@provide_session
def test_is_paused_subdag(self, session):
subdag_id = "dag.subdag"
subdag = DAG(
subdag_id,
start_date=DEFAULT_DATE,
)
with subdag:
EmptyOperator(
task_id="dummy_task",
)
dag_id = "dag"
dag = DAG(
dag_id,
start_date=DEFAULT_DATE,
)
with dag:
SubDagOperator(task_id="subdag", subdag=subdag)
        # parent_dag and is_subdag are normally set by DagBag. We don't use DagBag here, so set them manually.
subdag.parent_dag = dag
session.query(DagModel).filter(DagModel.dag_id.in_([subdag_id, dag_id])).delete(
synchronize_session=False
)
dag.sync_to_db(session=session)
unpaused_dags = (
session.query(DagModel.dag_id, DagModel.is_paused)
.filter(
DagModel.dag_id.in_([subdag_id, dag_id]),
)
.all()
)
assert {
(dag_id, False),
(subdag_id, False),
} == set(unpaused_dags)
DagModel.get_dagmodel(dag.dag_id).set_is_paused(is_paused=True, including_subdags=False)
paused_dags = (
session.query(DagModel.dag_id, DagModel.is_paused)
.filter(
DagModel.dag_id.in_([subdag_id, dag_id]),
)
.all()
)
assert {
(dag_id, True),
(subdag_id, False),
} == set(paused_dags)
DagModel.get_dagmodel(dag.dag_id).set_is_paused(is_paused=True)
paused_dags = (
session.query(DagModel.dag_id, DagModel.is_paused)
.filter(
DagModel.dag_id.in_([subdag_id, dag_id]),
)
.all()
)
assert {
(dag_id, True),
(subdag_id, True),
} == set(paused_dags)
def test_existing_dag_is_paused_upon_creation(self):
dag = DAG("dag_paused")
dag.sync_to_db()
assert not dag.get_is_paused()
dag = DAG("dag_paused", is_paused_upon_creation=True)
dag.sync_to_db()
# Since the dag existed before, it should not follow the pause flag upon creation
assert not dag.get_is_paused()
def test_new_dag_is_paused_upon_creation(self):
dag = DAG("new_nonexisting_dag", is_paused_upon_creation=True)
session = settings.Session()
dag.sync_to_db(session=session)
orm_dag = session.query(DagModel).filter(DagModel.dag_id == "new_nonexisting_dag").one()
# Since the dag didn't exist before, it should follow the pause flag upon creation
assert orm_dag.is_paused
session.close()
def test_existing_dag_default_view(self):
with create_session() as session:
session.add(DagModel(dag_id="dag_default_view_old", default_view=None))
session.commit()
orm_dag = session.query(DagModel).filter(DagModel.dag_id == "dag_default_view_old").one()
assert orm_dag.default_view is None
assert orm_dag.get_default_view() == conf.get("webserver", "dag_default_view").lower()
def test_dag_is_deactivated_upon_dagfile_deletion(self):
dag_id = "old_existing_dag"
dag_fileloc = "/usr/local/airflow/dags/non_existing_path.py"
dag = DAG(
dag_id,
is_paused_upon_creation=True,
)
dag.fileloc = dag_fileloc
session = settings.Session()
with mock.patch("airflow.models.dag.DagCode.bulk_sync_to_db"):
dag.sync_to_db(session=session)
orm_dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one()
assert orm_dag.is_active
assert orm_dag.fileloc == dag_fileloc
DagModel.deactivate_deleted_dags(list_py_file_paths(settings.DAGS_FOLDER))
orm_dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one()
assert not orm_dag.is_active
session.execute(DagModel.__table__.delete().where(DagModel.dag_id == dag_id))
session.close()
def test_dag_naive_default_args_start_date_with_timezone(self):
local_tz = pendulum.timezone("Europe/Zurich")
default_args = {"start_date": datetime.datetime(2018, 1, 1, tzinfo=local_tz)}
dag = DAG("DAG", default_args=default_args)
assert dag.timezone.name == local_tz.name
dag = DAG("DAG", default_args=default_args)
assert dag.timezone.name == local_tz.name
def test_roots(self):
"""Verify if dag.roots returns the root tasks of a DAG."""
with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
op1 = EmptyOperator(task_id="t1")
op2 = EmptyOperator(task_id="t2")
op3 = EmptyOperator(task_id="t3")
op4 = EmptyOperator(task_id="t4")
op5 = EmptyOperator(task_id="t5")
[op1, op2] >> op3 >> [op4, op5]
assert set(dag.roots) == {op1, op2}
def test_leaves(self):
"""Verify if dag.leaves returns the leaf tasks of a DAG."""
with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
op1 = EmptyOperator(task_id="t1")
op2 = EmptyOperator(task_id="t2")
op3 = EmptyOperator(task_id="t3")
op4 = EmptyOperator(task_id="t4")
op5 = EmptyOperator(task_id="t5")
[op1, op2] >> op3 >> [op4, op5]
assert set(dag.leaves) == {op4, op5}
def test_tree_view(self):
"""Verify correctness of dag.tree_view()."""
with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
op1 = EmptyOperator(task_id="t1")
op2 = EmptyOperator(task_id="t2")
op3 = EmptyOperator(task_id="t3")
op1 >> op2 >> op3
with redirect_stdout(io.StringIO()) as stdout:
dag.tree_view()
stdout = stdout.getvalue()
stdout_lines = stdout.split("\n")
assert "t1" in stdout_lines[0]
assert "t2" in stdout_lines[1]
assert "t3" in stdout_lines[2]
def test_duplicate_task_ids_not_allowed_with_dag_context_manager(self):
"""Verify tasks with Duplicate task_id raises error"""
with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has already been added to the DAG"):
with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
op1 = EmptyOperator(task_id="t1")
op2 = BashOperator(task_id="t1", bash_command="sleep 1")
op1 >> op2
assert dag.task_dict == {op1.task_id: op1}
def test_duplicate_task_ids_not_allowed_without_dag_context_manager(self):
"""Verify tasks with Duplicate task_id raises error"""
with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has already been added to the DAG"):
dag = DAG("test_dag", start_date=DEFAULT_DATE)
op1 = EmptyOperator(task_id="t1", dag=dag)
op2 = EmptyOperator(task_id="t1", dag=dag)
op1 >> op2
assert dag.task_dict == {op1.task_id: op1}
def test_duplicate_task_ids_for_same_task_is_allowed(self):
"""Verify that same tasks with Duplicate task_id do not raise error"""
with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
op1 = op2 = EmptyOperator(task_id="t1")
op3 = EmptyOperator(task_id="t3")
op1 >> op3
op2 >> op3
assert op1 == op2
assert dag.task_dict == {op1.task_id: op1, op3.task_id: op3}
assert dag.task_dict == {op2.task_id: op2, op3.task_id: op3}
def test_partial_subset_updates_all_references_while_deepcopy(self):
with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
op1 = EmptyOperator(task_id="t1")
op2 = EmptyOperator(task_id="t2")
op3 = EmptyOperator(task_id="t3")
op1 >> op2
op2 >> op3
partial = dag.partial_subset("t2", include_upstream=True, include_downstream=False)
assert id(partial.task_dict["t1"].downstream_list[0].dag) == id(partial)
# Copied DAG should not include unused task IDs in used_group_ids
assert "t3" not in partial.task_group.used_group_ids
def test_partial_subset_taskgroup_join_ids(self):
with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
start = EmptyOperator(task_id="start")
with TaskGroup(group_id="outer", prefix_group_id=False) as outer_group:
with TaskGroup(group_id="tg1", prefix_group_id=False) as tg1:
EmptyOperator(task_id="t1")
with TaskGroup(group_id="tg2", prefix_group_id=False) as tg2:
EmptyOperator(task_id="t2")
start >> tg1 >> tg2
# Pre-condition checks
task = dag.get_task("t2")
assert task.task_group.upstream_group_ids == {"tg1"}
assert isinstance(task.task_group.parent_group, weakref.ProxyType)
assert task.task_group.parent_group == outer_group
partial = dag.partial_subset(["t2"], include_upstream=True, include_downstream=False)
copied_task = partial.get_task("t2")
assert copied_task.task_group.upstream_group_ids == {"tg1"}
assert isinstance(copied_task.task_group.parent_group, weakref.ProxyType)
assert copied_task.task_group.parent_group
# Make sure we don't affect the original!
assert task.task_group.upstream_group_ids is not copied_task.task_group.upstream_group_ids
def test_schedule_dag_no_previous_runs(self):
"""
Tests scheduling a dag with no previous runs
"""
dag_id = "test_schedule_dag_no_previous_runs"
dag = DAG(dag_id=dag_id)
dag.add_task(BaseOperator(task_id="faketastic", owner="Also fake", start_date=TEST_DATE))
dag_run = dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
execution_date=TEST_DATE,
state=State.RUNNING,
)
assert dag_run is not None
assert dag.dag_id == dag_run.dag_id
assert dag_run.run_id is not None
assert "" != dag_run.run_id
assert (
TEST_DATE == dag_run.execution_date
), f"dag_run.execution_date did not match expectation: {dag_run.execution_date}"
assert State.RUNNING == dag_run.state
assert not dag_run.external_trigger
dag.clear()
self._clean_up(dag_id)
@patch("airflow.models.dag.Stats")
def test_dag_handle_callback_crash(self, mock_stats):
"""
Tests avoid crashes from calling dag callbacks exceptions
"""
dag_id = "test_dag_callback_crash"
mock_callback_with_exception = mock.MagicMock()
mock_callback_with_exception.side_effect = Exception
dag = DAG(
dag_id=dag_id,
# callback with invalid signature should not cause crashes
on_success_callback=lambda: 1,
on_failure_callback=mock_callback_with_exception,
)
when = TEST_DATE
dag.add_task(BaseOperator(task_id="faketastic", owner="Also fake", start_date=when))
with create_session() as session:
dag_run = dag.create_dagrun(State.RUNNING, when, run_type=DagRunType.MANUAL, session=session)
# should not raise any exception
dag.handle_callback(dag_run, success=False)
dag.handle_callback(dag_run, success=True)
mock_stats.incr.assert_called_with(
"dag.callback_exceptions",
tags={"dag_id": "test_dag_callback_crash"},
)
dag.clear()
self._clean_up(dag_id)
def test_next_dagrun_after_fake_scheduled_previous(self):
"""
Test scheduling a dag where there is a prior DagRun
which has the same run_id as the next run should have
"""
delta = datetime.timedelta(hours=1)
dag_id = "test_schedule_dag_fake_scheduled_previous"
dag = DAG(dag_id=dag_id, schedule=delta, start_date=DEFAULT_DATE)
dag.add_task(BaseOperator(task_id="faketastic", owner="Also fake", start_date=DEFAULT_DATE))
dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
execution_date=DEFAULT_DATE,
state=State.SUCCESS,
external_trigger=True,
)
dag.sync_to_db()
with create_session() as session:
model = session.get(DagModel, dag.dag_id)
# Even though there is a run for this date already, it is marked as manual/external, so we should
# create a scheduled one anyway!
assert model.next_dagrun == DEFAULT_DATE
assert model.next_dagrun_create_after == DEFAULT_DATE + delta
self._clean_up(dag_id)
def test_schedule_dag_once(self):
"""
Tests scheduling a dag scheduled for @once - should be scheduled the first time
it is called, and not scheduled the second.
"""
dag_id = "test_schedule_dag_once"
dag = DAG(dag_id=dag_id, schedule="@once")
assert isinstance(dag.timetable, OnceTimetable)
dag.add_task(BaseOperator(task_id="faketastic", owner="Also fake", start_date=TEST_DATE))
# Sync once to create the DagModel
dag.sync_to_db()
dag.create_dagrun(run_type=DagRunType.SCHEDULED, execution_date=TEST_DATE, state=State.SUCCESS)
# Then sync again after creating the dag run -- this should update next_dagrun
dag.sync_to_db()
with create_session() as session:
model = session.get(DagModel, dag.dag_id)
assert model.next_dagrun is None
assert model.next_dagrun_create_after is None
self._clean_up(dag_id)
def test_fractional_seconds(self):
"""
Tests if fractional seconds are stored in the database
"""
dag_id = "test_fractional_seconds"
dag = DAG(dag_id=dag_id, schedule="@once")
dag.add_task(BaseOperator(task_id="faketastic", owner="Also fake", start_date=TEST_DATE))
start_date = timezone.utcnow()
run = dag.create_dagrun(
run_id="test_" + start_date.isoformat(),
execution_date=start_date,
start_date=start_date,
state=State.RUNNING,
external_trigger=False,
)
run.refresh_from_db()
assert start_date == run.execution_date, "dag run execution_date loses precision"
        assert start_date == run.start_date, "dag run start_date loses precision"
self._clean_up(dag_id)
def test_pickling(self):
test_dag_id = "test_pickling"
args = {"owner": "airflow", "start_date": DEFAULT_DATE}
dag = DAG(test_dag_id, default_args=args)
dag_pickle = dag.pickle()
assert dag_pickle.pickle.dag_id == dag.dag_id
def test_rich_comparison_ops(self):
test_dag_id = "test_rich_comparison_ops"
class DAGsubclass(DAG):
pass
args = {"owner": "airflow", "start_date": DEFAULT_DATE}
dag = DAG(test_dag_id, default_args=args)
dag_eq = DAG(test_dag_id, default_args=args)
dag_diff_load_time = DAG(test_dag_id, default_args=args)
dag_diff_name = DAG(test_dag_id + "_neq", default_args=args)
dag_subclass = DAGsubclass(test_dag_id, default_args=args)
dag_subclass_diff_name = DAGsubclass(test_dag_id + "2", default_args=args)
for dag_ in [dag_eq, dag_diff_name, dag_subclass, dag_subclass_diff_name]:
dag_.last_loaded = dag.last_loaded
# test identity equality
assert dag == dag
# test dag (in)equality based on _comps
assert dag_eq == dag
assert dag_diff_name != dag
assert dag_diff_load_time != dag
# test dag inequality based on type even if _comps happen to match
assert dag_subclass != dag
# a dag should equal an unpickled version of itself
dump = pickle.dumps(dag)
assert pickle.loads(dump) == dag
# dags are ordered based on dag_id no matter what the type is
assert dag < dag_diff_name
assert dag > dag_diff_load_time
assert dag < dag_subclass_diff_name
# greater than should have been created automatically by functools
assert dag_diff_name > dag
# hashes are non-random and match equality
assert hash(dag) == hash(dag)
assert hash(dag_eq) == hash(dag)
assert hash(dag_diff_name) != hash(dag)
assert hash(dag_subclass) != hash(dag)
def test_get_paused_dag_ids(self):
dag_id = "test_get_paused_dag_ids"
dag = DAG(dag_id, is_paused_upon_creation=True)
dag.sync_to_db()
assert DagModel.get_dagmodel(dag_id) is not None
paused_dag_ids = DagModel.get_paused_dag_ids([dag_id])
assert paused_dag_ids == {dag_id}
with create_session() as session:
session.query(DagModel).filter(DagModel.dag_id == dag_id).delete(synchronize_session=False)
@pytest.mark.parametrize(
"schedule_interval_arg, expected_timetable, interval_description",
[
(None, NullTimetable(), "Never, external triggers only"),
("@daily", cron_timetable("0 0 * * *"), "At 00:00"),
("@weekly", cron_timetable("0 0 * * 0"), "At 00:00, only on Sunday"),
("@monthly", cron_timetable("0 0 1 * *"), "At 00:00, on day 1 of the month"),
("@quarterly", cron_timetable("0 0 1 */3 *"), "At 00:00, on day 1 of the month, every 3 months"),
("@yearly", cron_timetable("0 0 1 1 *"), "At 00:00, on day 1 of the month, only in January"),
("5 0 * 8 *", cron_timetable("5 0 * 8 *"), "At 00:05, only in August"),
("@once", OnceTimetable(), "Once, as soon as possible"),
(datetime.timedelta(days=1), delta_timetable(datetime.timedelta(days=1)), ""),
("30 21 * * 5 1", cron_timetable("30 21 * * 5 1"), ""),
],
)
def test_timetable_and_description_from_schedule_interval_arg(
self, schedule_interval_arg, expected_timetable, interval_description
):
dag = DAG("test_schedule_interval_arg", schedule=schedule_interval_arg)
assert dag.timetable == expected_timetable
assert dag.schedule_interval == schedule_interval_arg
assert dag.timetable.description == interval_description
def test_timetable_and_description_from_dataset(self):
dag = DAG("test_schedule_interval_arg", schedule=[Dataset(uri="hello")])
assert dag.timetable == DatasetTriggeredTimetable()
assert dag.schedule_interval == "Dataset"
assert dag.timetable.description == "Triggered by datasets"
def test_schedule_interval_still_works(self):
dag = DAG("test_schedule_interval_arg", schedule_interval="*/5 * * * *")
assert dag.timetable == cron_timetable("*/5 * * * *")
assert dag.schedule_interval == "*/5 * * * *"
assert dag.timetable.description == "Every 5 minutes"
def test_timetable_still_works(self):
dag = DAG("test_schedule_interval_arg", timetable=cron_timetable("*/6 * * * *"))
assert dag.timetable == cron_timetable("*/6 * * * *")
assert dag.schedule_interval == "*/6 * * * *"
assert dag.timetable.description == "Every 6 minutes"
@pytest.mark.parametrize(
"timetable, expected_description",
[
(NullTimetable(), "Never, external triggers only"),
(cron_timetable("0 0 * * *"), "At 00:00"),
(cron_timetable("@daily"), "At 00:00"),
(cron_timetable("0 0 * * 0"), "At 00:00, only on Sunday"),
(cron_timetable("@weekly"), "At 00:00, only on Sunday"),
(cron_timetable("0 0 1 * *"), "At 00:00, on day 1 of the month"),
(cron_timetable("@monthly"), "At 00:00, on day 1 of the month"),
(cron_timetable("0 0 1 */3 *"), "At 00:00, on day 1 of the month, every 3 months"),
(cron_timetable("@quarterly"), "At 00:00, on day 1 of the month, every 3 months"),
(cron_timetable("0 0 1 1 *"), "At 00:00, on day 1 of the month, only in January"),
(cron_timetable("@yearly"), "At 00:00, on day 1 of the month, only in January"),
(cron_timetable("5 0 * 8 *"), "At 00:05, only in August"),
(OnceTimetable(), "Once, as soon as possible"),
(delta_timetable(datetime.timedelta(days=1)), ""),
(cron_timetable("30 21 * * 5 1"), ""),
],
)
def test_description_from_timetable(self, timetable, expected_description):
dag = DAG("test_schedule_interval_description", timetable=timetable)
assert dag.timetable == timetable
assert dag.timetable.description == expected_description
def test_create_dagrun_run_id_is_generated(self):
dag = DAG(dag_id="run_id_is_generated")
dr = dag.create_dagrun(run_type=DagRunType.MANUAL, execution_date=DEFAULT_DATE, state=State.NONE)
assert dr.run_id == f"manual__{DEFAULT_DATE.isoformat()}"
def test_create_dagrun_run_type_is_obtained_from_run_id(self):
dag = DAG(dag_id="run_type_is_obtained_from_run_id")
dr = dag.create_dagrun(run_id="scheduled__", state=State.NONE)
assert dr.run_type == DagRunType.SCHEDULED
dr = dag.create_dagrun(run_id="custom_is_set_to_manual", state=State.NONE)
assert dr.run_type == DagRunType.MANUAL
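        # The run type is inferred from the run_id prefix (e.g. "scheduled__");
        # a run_id with no recognized prefix falls back to MANUAL, which is
        # what the second assertion exercises.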
def test_create_dagrun_job_id_is_set(self):
job_id = 42
dag = DAG(dag_id="test_create_dagrun_job_id_is_set")
dr = dag.create_dagrun(
run_id="test_create_dagrun_job_id_is_set", state=State.NONE, creating_job_id=job_id
)
assert dr.creating_job_id == job_id
def test_dag_add_task_checks_trigger_rule(self):
        # A non-fail-stop dag should allow any trigger rule
from airflow.exceptions import DagInvalidTriggerRule
from airflow.utils.trigger_rule import TriggerRule
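        # Background: a fail-stop DAG halts remaining tasks as soon as one
        # task fails, which only works if every task uses the default trigger
        # rule (all_success); that constraint is what this test exercises.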
task_with_non_default_trigger_rule = EmptyOperator(
task_id="task_with_non_default_trigger_rule", trigger_rule=TriggerRule.DUMMY
)
non_fail_stop_dag = DAG(
dag_id="test_dag_add_task_checks_trigger_rule", start_date=DEFAULT_DATE, fail_stop=False
)
try:
non_fail_stop_dag.add_task(task_with_non_default_trigger_rule)
except DagInvalidTriggerRule as exception:
assert False, f"dag add_task() raises DagInvalidTriggerRule for non fail stop dag: {exception}"
        # A fail-stop dag should allow the default trigger rule
from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE
fail_stop_dag = DAG(
dag_id="test_dag_add_task_checks_trigger_rule", start_date=DEFAULT_DATE, fail_stop=True
)
task_with_default_trigger_rule = EmptyOperator(
task_id="task_with_default_trigger_rule", trigger_rule=DEFAULT_TRIGGER_RULE
)
try:
fail_stop_dag.add_task(task_with_default_trigger_rule)
except DagInvalidTriggerRule as exception:
assert (
False
), f"dag.add_task() raises exception for fail-stop dag & default trigger rule: {exception}"
        # A fail-stop dag should not allow a non-default trigger rule
with pytest.raises(DagInvalidTriggerRule):
fail_stop_dag.add_task(task_with_non_default_trigger_rule)
def test_dag_add_task_sets_default_task_group(self):
dag = DAG(dag_id="test_dag_add_task_sets_default_task_group", start_date=DEFAULT_DATE)
task_without_task_group = EmptyOperator(task_id="task_without_group_id")
default_task_group = TaskGroupContext.get_current_task_group(dag)
dag.add_task(task_without_task_group)
assert default_task_group.get_child_by_label("task_without_group_id") == task_without_task_group
task_group = TaskGroup(group_id="task_group", dag=dag)
task_with_task_group = EmptyOperator(task_id="task_with_task_group", task_group=task_group)
dag.add_task(task_with_task_group)
assert task_group.get_child_by_label("task_with_task_group") == task_with_task_group
assert dag.get_task("task_group.task_with_task_group") == task_with_task_group
@pytest.mark.parametrize("dag_run_state", [DagRunState.QUEUED, DagRunState.RUNNING])
def test_clear_set_dagrun_state(self, dag_run_state):
dag_id = "test_clear_set_dagrun_state"
self._clean_up(dag_id)
task_id = "t1"
dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1)
t_1 = EmptyOperator(task_id=task_id, dag=dag)
session = settings.Session()
dagrun_1 = dag.create_dagrun(
run_type=DagRunType.BACKFILL_JOB,
state=State.FAILED,
start_date=DEFAULT_DATE,
execution_date=DEFAULT_DATE,
)
session.merge(dagrun_1)
task_instance_1 = TI(t_1, execution_date=DEFAULT_DATE, state=State.RUNNING)
session.merge(task_instance_1)
session.commit()
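# Clearing the window containing the failed backfill run should reset the run to the requested state.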
dag.clear(
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=1),
dag_run_state=dag_run_state,
include_subdags=False,
include_parentdag=False,
session=session,
)
dagruns = (
session.query(
DagRun,
)
.filter(
DagRun.dag_id == dag_id,
)
.all()
)
assert len(dagruns) == 1
dagrun: DagRun = dagruns[0]
assert dagrun.state == dag_run_state
@pytest.mark.parametrize("dag_run_state", [DagRunState.QUEUED, DagRunState.RUNNING])
def test_clear_set_dagrun_state_for_mapped_task(self, dag_run_state):
dag_id = "test_clear_set_dagrun_state"
self._clean_up(dag_id)
task_id = "t1"
dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1)
@dag.task
def make_arg_lists():
return [[1], [2], [{"a": "b"}]]
def consumer(value):
print(value)
mapped = PythonOperator.partial(task_id=task_id, dag=dag, python_callable=consumer).expand(
op_args=make_arg_lists()
)
session = settings.Session()
dagrun_1 = dag.create_dagrun(
run_type=DagRunType.BACKFILL_JOB,
state=State.FAILED,
start_date=DEFAULT_DATE,
execution_date=DEFAULT_DATE,
session=session,
)
expand_mapped_task(mapped, dagrun_1.run_id, "make_arg_lists", length=2, session=session)
upstream_ti = dagrun_1.get_task_instance("make_arg_lists", session=session)
ti = dagrun_1.get_task_instance(task_id, map_index=0, session=session)
ti2 = dagrun_1.get_task_instance(task_id, map_index=1, session=session)
upstream_ti.state = State.SUCCESS
ti.state = State.SUCCESS
ti2.state = State.SUCCESS
session.flush()
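# Clear only (t1, map_index=0) plus the whole upstream task; (t1, map_index=1) should keep its SUCCESS state.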
dag.clear(
task_ids=[(task_id, 0), ("make_arg_lists")],
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=1),
dag_run_state=dag_run_state,
include_subdags=False,
include_parentdag=False,
session=session,
)
session.refresh(upstream_ti)
session.refresh(ti)
session.refresh(ti2)
assert upstream_ti.state is None # cleared
assert ti.state is None # cleared
assert ti2.state == State.SUCCESS # not cleared
dagruns = (
session.query(
DagRun,
)
.filter(
DagRun.dag_id == dag_id,
)
.all()
)
assert len(dagruns) == 1
dagrun: DagRun = dagruns[0]
assert dagrun.state == dag_run_state
def test_dag_test_basic(self):
dag = DAG(dag_id="test_local_testing_conn_file", start_date=DEFAULT_DATE)
mock_object = mock.MagicMock()
@task_decorator
def check_task():
# we call a mock object to ensure that this task actually ran.
mock_object()
with dag:
check_task()
dag.test()
mock_object.assert_called_once()
def test_dag_test_with_dependencies(self):
dag = DAG(dag_id="test_local_testing_conn_file", start_date=DEFAULT_DATE)
mock_object = mock.MagicMock()
@task_decorator
def check_task():
return "output of first task"
@task_decorator
def check_task_2(my_input):
# we call a mock object to ensure that this task actually ran.
mock_object(my_input)
with dag:
check_task_2(check_task())
dag.test()
mock_object.assert_called_with("output of first task")
def test_dag_test_with_task_mapping(self):
dag = DAG(dag_id="test_local_testing_conn_file", start_date=DEFAULT_DATE)
mock_object = mock.MagicMock()
@task_decorator()
def get_index(current_val, ti=None):
return ti.map_index
@task_decorator
def check_task(my_input):
# we call a mock object with the combined map to ensure all expected indexes are called
mock_object(list(my_input))
with dag:
mapped_task = get_index.expand(current_val=[1, 1, 1, 1, 1])
check_task(mapped_task)
dag.test()
mock_object.assert_called_with([0, 1, 2, 3, 4])
def test_dag_connection_file(self):
test_connections_string = """
---
my_postgres_conn:
- conn_id: my_postgres_conn
conn_type: postgres
"""
dag = DAG(dag_id="test_local_testing_conn_file", start_date=DEFAULT_DATE)
@task_decorator
def check_task():
from airflow.configuration import secrets_backend_list
from airflow.secrets.local_filesystem import LocalFilesystemBackend
assert isinstance(secrets_backend_list[0], LocalFilesystemBackend)
local_secrets: LocalFilesystemBackend = secrets_backend_list[0]
assert local_secrets.get_connection("my_postgres_conn").conn_id == "my_postgres_conn"
with dag:
check_task()
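# dag.test(conn_file_path=...) should install a LocalFilesystemBackend as the first secrets backend, which the task above asserts.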
with NamedTemporaryFile(suffix=".yaml") as tmp:
with open(tmp.name, "w") as f:
f.write(test_connections_string)
dag.test(conn_file_path=tmp.name)
def _make_test_subdag(self, session):
dag_id = "test_subdag"
self._clean_up(dag_id)
task_id = "t1"
dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1)
t_1 = EmptyOperator(task_id=task_id, dag=dag)
subdag = DAG(dag_id + ".test", start_date=DEFAULT_DATE, max_active_runs=1)
SubDagOperator(task_id="test", subdag=subdag, dag=dag)
t_2 = EmptyOperator(task_id="task", dag=subdag)
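# parent_dag is normally wired up by DagBag; set it manually since the subdag is built inline.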
subdag.parent_dag = dag
dag.sync_to_db()
session = settings.Session()
dag.create_dagrun(
run_type=DagRunType.MANUAL,
state=State.FAILED,
start_date=DEFAULT_DATE,
execution_date=DEFAULT_DATE,
session=session,
)
subdag.create_dagrun(
run_type=DagRunType.MANUAL,
state=State.FAILED,
start_date=DEFAULT_DATE,
execution_date=DEFAULT_DATE,
session=session,
)
task_instance_1 = TI(t_1, execution_date=DEFAULT_DATE, state=State.RUNNING)
task_instance_2 = TI(t_2, execution_date=DEFAULT_DATE, state=State.RUNNING)
session.merge(task_instance_1)
session.merge(task_instance_2)
return dag, subdag
@pytest.mark.parametrize("dag_run_state", [DagRunState.QUEUED, DagRunState.RUNNING])
def test_clear_set_dagrun_state_for_subdag(self, dag_run_state):
session = settings.Session()
dag, subdag = self._make_test_subdag(session)
session.flush()
dag.clear(
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=1),
dag_run_state=dag_run_state,
include_subdags=True,
include_parentdag=False,
session=session,
)
dagrun = (
session.query(
DagRun,
)
.filter(DagRun.dag_id == subdag.dag_id)
.one()
)
assert dagrun.state == dag_run_state
session.rollback()
@pytest.mark.parametrize("dag_run_state", [DagRunState.QUEUED, DagRunState.RUNNING])
def test_clear_set_dagrun_state_for_parent_dag(self, dag_run_state):
session = settings.Session()
dag, subdag = self._make_test_subdag(session)
session.flush()
subdag.clear(
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=1),
dag_run_state=dag_run_state,
include_subdags=True,
include_parentdag=True,
session=session,
)
dagrun = (
session.query(
DagRun,
)
.filter(DagRun.dag_id == dag.dag_id)
.one()
)
assert dagrun.state == dag_run_state
@pytest.mark.parametrize(
"ti_state_begin, ti_state_end",
[
*((state, None) for state in State.task_states if state != TaskInstanceState.RUNNING),
(TaskInstanceState.RUNNING, TaskInstanceState.RESTARTING),
],
)
def test_clear_dag(
self,
ti_state_begin: TaskInstanceState | None,
ti_state_end: TaskInstanceState | None,
):
dag_id = "test_clear_dag"
self._clean_up(dag_id)
task_id = "t1"
dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1)
t_1 = EmptyOperator(task_id=task_id, dag=dag)
session = settings.Session() # type: ignore
dagrun_1 = dag.create_dagrun(
run_type=DagRunType.BACKFILL_JOB,
state=DagRunState.RUNNING,
start_date=DEFAULT_DATE,
execution_date=DEFAULT_DATE,
)
session.merge(dagrun_1)
task_instance_1 = TI(t_1, execution_date=DEFAULT_DATE, state=ti_state_begin)
task_instance_1.job_id = 123
session.merge(task_instance_1)
session.commit()
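# clear() resets finished TIs to no state, while a RUNNING TI is moved to RESTARTING (see parametrization above).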
dag.clear(
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=1),
session=session,
)
task_instances = (
session.query(
TI,
)
.filter(
TI.dag_id == dag_id,
)
.all()
)
assert len(task_instances) == 1
task_instance: TI = task_instances[0]
assert task_instance.state == ti_state_end
self._clean_up(dag_id)
def test_next_dagrun_info_once(self):
dag = DAG("test_scheduler_dagrun_once", start_date=timezone.datetime(2015, 1, 1), schedule="@once")
next_info = dag.next_dagrun_info(None)
assert next_info and next_info.logical_date == timezone.datetime(2015, 1, 1)
next_info = dag.next_dagrun_info(next_info.data_interval)
assert next_info is None
def test_next_dagrun_info_start_end_dates(self):
"""
Tests that an attempt to schedule a task after the Dag's end_date
does not succeed.
"""
delta = datetime.timedelta(hours=1)
runs = 3
start_date = DEFAULT_DATE
end_date = start_date + (runs - 1) * delta
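# Three runs spaced one hour apart: the last data interval starts exactly at end_date.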
dag_id = "test_schedule_dag_start_end_dates"
dag = DAG(dag_id=dag_id, start_date=start_date, end_date=end_date, schedule=delta)
dag.add_task(BaseOperator(task_id="faketastic", owner="Also fake"))
# Create and schedule the dag runs
dates = []
interval = None
for _ in range(runs):
next_info = dag.next_dagrun_info(interval)
if next_info is None:
dates.append(None)
else:
interval = next_info.data_interval
dates.append(interval.start)
assert all(date is not None for date in dates)
assert dates[-1] == end_date
assert dag.next_dagrun_info(interval.start) is None
def test_next_dagrun_info_catchup(self):
"""
Test that a DAG with catchup=False schedules runs beginning now, not back to the start_date
"""
def make_dag(dag_id, schedule, start_date, catchup):
default_args = {
"owner": "airflow",
"depends_on_past": False,
}
dag = DAG(
dag_id,
schedule=schedule,
start_date=start_date,
catchup=catchup,
default_args=default_args,
)
op1 = EmptyOperator(task_id="t1", dag=dag)
op2 = EmptyOperator(task_id="t2", dag=dag)
op3 = EmptyOperator(task_id="t3", dag=dag)
op1 >> op2 >> op3
return dag
now = timezone.utcnow()
six_hours_ago_to_the_hour = (now - datetime.timedelta(hours=6)).replace(
minute=0, second=0, microsecond=0
)
half_an_hour_ago = now - datetime.timedelta(minutes=30)
two_hours_ago = now - datetime.timedelta(hours=2)
dag1 = make_dag(
dag_id="dag_without_catchup_ten_minute",
schedule="*/10 * * * *",
start_date=six_hours_ago_to_the_hour,
catchup=False,
)
next_date, _ = dag1.next_dagrun_info(None)
# The DR should be scheduled within the last half hour, not 6 hours ago
assert next_date > half_an_hour_ago
assert next_date < timezone.utcnow()
dag2 = make_dag(
dag_id="dag_without_catchup_hourly",
schedule="@hourly",
start_date=six_hours_ago_to_the_hour,
catchup=False,
)
next_date, _ = dag2.next_dagrun_info(None)
# The DR should be scheduled within the last 2 hours, not 6 hours ago
assert next_date > two_hours_ago
# The DR should be scheduled BEFORE now
assert next_date < timezone.utcnow()
dag3 = make_dag(
dag_id="dag_without_catchup_once",
schedule="@once",
start_date=six_hours_ago_to_the_hour,
catchup=False,
)
next_date, _ = dag3.next_dagrun_info(None)
# With @once, the single DR is scheduled at start_date even with catchup=False
assert next_date == six_hours_ago_to_the_hour
@time_machine.travel(timezone.datetime(2020, 1, 5), tick=False)
def test_next_dagrun_info_timedelta_schedule_and_catchup_false(self):
"""
Test that the dag file processor does not create multiple dagruns
if a dag is scheduled with 'timedelta' and catchup=False
"""
dag = DAG(
"test_scheduler_dagrun_once_with_timedelta_and_catchup_false",
start_date=timezone.datetime(2015, 1, 1),
schedule=timedelta(days=1),
catchup=False,
)
next_info = dag.next_dagrun_info(None)
assert next_info and next_info.logical_date == timezone.datetime(2020, 1, 4)
# The next run date is in the future; that case is handled by "DagModel.dags_needing_dagruns"
next_info = dag.next_dagrun_info(next_info.data_interval)
assert next_info and next_info.logical_date == timezone.datetime(2020, 1, 5)
@time_machine.travel(timezone.datetime(2020, 5, 4))
def test_next_dagrun_info_timedelta_schedule_and_catchup_true(self):
"""
Test that the dag file processor creates multiple dagruns
if a dag is scheduled with 'timedelta' and catchup=True
"""
dag = DAG(
"test_scheduler_dagrun_once_with_timedelta_and_catchup_true",
start_date=timezone.datetime(2020, 5, 1),
schedule=timedelta(days=1),
catchup=True,
)
next_info = dag.next_dagrun_info(None)
assert next_info and next_info.logical_date == timezone.datetime(2020, 5, 1)
next_info = dag.next_dagrun_info(next_info.data_interval)
assert next_info and next_info.logical_date == timezone.datetime(2020, 5, 2)
next_info = dag.next_dagrun_info(next_info.data_interval)
assert next_info and next_info.logical_date == timezone.datetime(2020, 5, 3)
# The next run date is in the future; that case is handled by "DagModel.dags_needing_dagruns"
next_info = dag.next_dagrun_info(next_info.data_interval)
assert next_info and next_info.logical_date == timezone.datetime(2020, 5, 4)
def test_next_dagrun_info_timetable_exception(self, caplog):
"""Test the DAG does not crash the scheduler if the timetable raises an exception."""
class FailingTimetable(Timetable):
def next_dagrun_info(self, last_automated_data_interval, restriction):
raise RuntimeError("this fails")
dag = DAG(
"test_next_dagrun_info_timetable_exception",
start_date=timezone.datetime(2020, 5, 1),
timetable=FailingTimetable(),
catchup=True,
)
def _check_logs(records: list[logging.LogRecord], data_interval: DataInterval) -> None:
assert len(records) == 1
record = records[0]
assert record.exc_info is not None, "Should contain exception"
assert record.getMessage() == (
f"Failed to fetch run info after data interval {data_interval} "
f"for DAG 'test_next_dagrun_info_timetable_exception'"
)
with caplog.at_level(level=logging.ERROR):
next_info = dag.next_dagrun_info(None)
assert next_info is None, "failed next_dagrun_info should return None"
_check_logs(caplog.records, data_interval=None)
caplog.clear()
data_interval = DataInterval(timezone.datetime(2020, 5, 1), timezone.datetime(2020, 5, 2))
with caplog.at_level(level=logging.ERROR):
next_info = dag.next_dagrun_info(data_interval)
assert next_info is None, "failed next_dagrun_info should return None"
_check_logs(caplog.records, data_interval)
def test_next_dagrun_after_auto_align(self):
"""
Test that the schedule is auto-aligned with start_date: if start_date
coincides with the schedule, the first logical date is start_date itself;
otherwise it is the first schedule point after start_date.
"""
dag = DAG(
dag_id="test_scheduler_auto_align_1",
start_date=timezone.datetime(2016, 1, 1, 10, 10, 0),
schedule="4 5 * * *",
)
EmptyOperator(task_id="dummy", dag=dag, owner="airflow")
next_info = dag.next_dagrun_info(None)
assert next_info and next_info.logical_date == timezone.datetime(2016, 1, 2, 5, 4)
dag = DAG(
dag_id="test_scheduler_auto_align_2",
start_date=timezone.datetime(2016, 1, 1, 10, 10, 0),
schedule="10 10 * * *",
)
EmptyOperator(task_id="dummy", dag=dag, owner="airflow")
next_info = dag.next_dagrun_info(None)
assert next_info and next_info.logical_date == timezone.datetime(2016, 1, 1, 10, 10)
def test_next_dagrun_after_not_for_subdags(self):
"""
Test that subdags are never marked to have dagruns created, as they are
handled by the SubDagOperator, not the scheduler
"""
def subdag(parent_dag_name, child_dag_name, args):
"""
Create a subdag.
"""
dag_subdag = DAG(
dag_id=f"{parent_dag_name}.{child_dag_name}",
schedule="@daily",
default_args=args,
)
for i in range(2):
EmptyOperator(task_id=f"{child_dag_name}-task-{i + 1}", dag=dag_subdag)
return dag_subdag
with DAG(
dag_id="test_subdag_operator",
start_date=datetime.datetime(2019, 1, 1),
max_active_runs=1,
schedule=timedelta(minutes=1),
) as dag:
section_1 = SubDagOperator(
task_id="section-1",
subdag=subdag(dag.dag_id, "section-1", {"start_date": dag.start_date}),
)
subdag = section_1.subdag
# parent_dag and is_subdag are normally set by DagBag; we don't use DagBag here, so set parent_dag manually.
subdag.parent_dag = dag
next_parent_info = dag.next_dagrun_info(None)
assert next_parent_info.logical_date == timezone.datetime(2019, 1, 1, 0, 0)
next_subdag_info = subdag.next_dagrun_info(None)
assert next_subdag_info is None, "SubDags should never have DagRuns created by the scheduler"
def test_replace_outdated_access_control_actions(self):
outdated_permissions = {
"role1": {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT},
"role2": {permissions.DEPRECATED_ACTION_CAN_DAG_READ, permissions.DEPRECATED_ACTION_CAN_DAG_EDIT},
}
updated_permissions = {
"role1": {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT},
"role2": {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT},
}
with pytest.warns(DeprecationWarning):
dag = DAG(dag_id="dag_with_outdated_perms", access_control=outdated_permissions)
assert dag.access_control == updated_permissions
with pytest.warns(DeprecationWarning):
dag.access_control = outdated_permissions
assert dag.access_control == updated_permissions
def test_validate_params_on_trigger_dag(self):
dag = models.DAG("dummy-dag", schedule=None, params={"param1": Param(type="string")})
with pytest.raises(ParamValidationError, match="No value passed and Param has no default value"):
dag.create_dagrun(
run_id="test_dagrun_missing_param",
state=State.RUNNING,
execution_date=TEST_DATE,
)
dag = models.DAG("dummy-dag", schedule=None, params={"param1": Param(type="string")})
with pytest.raises(
ParamValidationError, match="Invalid input for param param1: None is not of type 'string'"
):
dag.create_dagrun(
run_id="test_dagrun_missing_param",
state=State.RUNNING,
execution_date=TEST_DATE,
conf={"param1": None},
)
dag = models.DAG("dummy-dag", schedule=None, params={"param1": Param(type="string")})
dag.create_dagrun(
run_id="test_dagrun_missing_param",
state=State.RUNNING,
execution_date=TEST_DATE,
conf={"param1": "hello"},
)
def test_return_date_range_with_num_method(self):
start_date = TEST_DATE
delta = timedelta(days=1)
dag = models.DAG("dummy-dag", schedule=delta)
dag_dates = dag.date_range(start_date=start_date, num=3)
assert dag_dates == [
start_date,
start_date + delta,
start_date + 2 * delta,
]
def test_dag_owner_links(self):
dag = DAG(
"dag",
start_date=DEFAULT_DATE,
owner_links={"owner1": "https://mylink.com", "owner2": "mailto:someone@yoursite.com"},
)
assert dag.owner_links == {"owner1": "https://mylink.com", "owner2": "mailto:someone@yoursite.com"}
session = settings.Session()
dag.sync_to_db(session=session)
expected_owners = {"dag": {"owner1": "https://mylink.com", "owner2": "mailto:someone@yoursite.com"}}
orm_dag_owners = DagOwnerAttributes.get_all(session)
assert orm_dag_owners == expected_owners
# Test dag owner links are removed completely
dag = DAG(
"dag",
start_date=DEFAULT_DATE,
)
dag.sync_to_db(session=session)
orm_dag_owners = session.query(DagOwnerAttributes).all()
assert not orm_dag_owners
# Check wrong formatted owner link
with pytest.raises(AirflowException):
DAG("dag", start_date=DEFAULT_DATE, owner_links={"owner1": "my-bad-link"})
@pytest.mark.parametrize(
"kwargs",
[
{"schedule_interval": "@daily", "schedule": "@weekly"},
{"timetable": NullTimetable(), "schedule": "@weekly"},
{"timetable": NullTimetable(), "schedule_interval": "@daily"},
],
ids=[
"schedule_interval+schedule",
"timetable+schedule",
"timetable+schedule_interval",
],
)
def test_schedule_dag_param(self, kwargs):
with pytest.raises(ValueError, match="At most one"):
with DAG(dag_id="hello", **kwargs):
pass
def test_continuous_schedule_interval_limits_max_active_runs(self):
dag = DAG("continuous", start_date=DEFAULT_DATE, schedule_interval="@continuous", max_active_runs=1)
assert isinstance(dag.timetable, ContinuousTimetable)
assert dag.max_active_runs == 1
dag = DAG("continuous", start_date=DEFAULT_DATE, schedule_interval="@continuous", max_active_runs=0)
assert isinstance(dag.timetable, ContinuousTimetable)
assert dag.max_active_runs == 0
with pytest.raises(AirflowException):
dag = DAG(
"continuous", start_date=DEFAULT_DATE, schedule_interval="@continuous", max_active_runs=25
)
class TestDagModel:
def _clean(self):
clear_db_dags()
clear_db_datasets()
clear_db_runs()
def setup_method(self):
self._clean()
def teardown_method(self):
self._clean()
def test_dags_needing_dagruns_not_too_early(self):
dag = DAG(dag_id="far_future_dag", start_date=timezone.datetime(2038, 1, 1))
EmptyOperator(task_id="dummy", dag=dag, owner="airflow")
session = settings.Session()
orm_dag = DagModel(
dag_id=dag.dag_id,
max_active_tasks=1,
has_task_concurrency_limits=False,
next_dagrun=dag.start_date,
next_dagrun_create_after=timezone.datetime(2038, 1, 2),
is_active=True,
)
session.add(orm_dag)
session.flush()
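# next_dagrun_create_after is still in the future (2038), so the DAG must not be returned yet.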
query, _ = DagModel.dags_needing_dagruns(session)
dag_models = query.all()
assert dag_models == []
session.rollback()
session.close()
def test_dags_needing_dagruns_datasets(self, dag_maker, session):
dataset = Dataset(uri="hello")
with dag_maker(
session=session,
dag_id="my_dag",
max_active_runs=1,
schedule=[dataset],
start_date=pendulum.now().add(days=-2),
) as dag:
EmptyOperator(task_id="dummy")
# there's no queue record yet, so no runs needed at this time.
query, _ = DagModel.dags_needing_dagruns(session)
dag_models = query.all()
assert dag_models == []
# add queue records so we'll need a run
dag_model = session.query(DagModel).filter(DagModel.dag_id == dag.dag_id).one()
dataset_model: DatasetModel = dag_model.schedule_datasets[0]
session.add(DatasetDagRunQueue(dataset_id=dataset_model.id, target_dag_id=dag_model.dag_id))
session.flush()
query, _ = DagModel.dags_needing_dagruns(session)
dag_models = query.all()
assert dag_models == [dag_model]
# create run so we don't need a run anymore (due to max active runs)
dag_maker.create_dagrun(
run_type=DagRunType.DATASET_TRIGGERED,
state=DagRunState.QUEUED,
execution_date=pendulum.now("UTC"),
)
query, _ = DagModel.dags_needing_dagruns(session)
dag_models = query.all()
assert dag_models == []
# increase max active runs and we should now need another run
dag_maker.dag_model.max_active_runs = 2
session.flush()
query, _ = DagModel.dags_needing_dagruns(session)
dag_models = query.all()
assert dag_models == [dag_model]
def test_max_active_runs_not_none(self):
dag = DAG(dag_id="test_max_active_runs_not_none", start_date=timezone.datetime(2038, 1, 1))
EmptyOperator(task_id="dummy", dag=dag, owner="airflow")
session = settings.Session()
orm_dag = DagModel(
dag_id=dag.dag_id,
has_task_concurrency_limits=False,
next_dagrun=None,
next_dagrun_create_after=None,
is_active=True,
)
# max_active_runs should fall back to the configured default (16) when not passed
assert orm_dag.max_active_runs == 16
session.add(orm_dag)
session.flush()
assert orm_dag.max_active_runs is not None
session.rollback()
session.close()
def test_dags_needing_dagruns_only_unpaused(self):
"""
We should never create dagruns for paused DAGs
"""
dag = DAG(dag_id="test_dags", start_date=DEFAULT_DATE)
EmptyOperator(task_id="dummy", dag=dag, owner="airflow")
session = settings.Session()
orm_dag = DagModel(
dag_id=dag.dag_id,
has_task_concurrency_limits=False,
next_dagrun=DEFAULT_DATE,
next_dagrun_create_after=DEFAULT_DATE + timedelta(days=1),
is_active=True,
)
session.add(orm_dag)
session.flush()
query, _ = DagModel.dags_needing_dagruns(session)
needed = query.all()
assert needed == [orm_dag]
orm_dag.is_paused = True
session.flush()
query, _ = DagModel.dags_needing_dagruns(session)
dag_models = query.all()
assert dag_models == []
session.rollback()
session.close()
def test_dags_needing_dagruns_doesnot_send_dagmodel_with_import_errors(self, session):
"""
Check that DAGs with import errors are not handed to the scheduler
for dagrun creation
"""
dag = DAG(dag_id="test_dags", start_date=DEFAULT_DATE)
EmptyOperator(task_id="dummy", dag=dag, owner="airflow")
orm_dag = DagModel(
dag_id=dag.dag_id,
has_task_concurrency_limits=False,
next_dagrun=DEFAULT_DATE,
next_dagrun_create_after=DEFAULT_DATE + timedelta(days=1),
is_active=True,
)
assert not orm_dag.has_import_errors
session.add(orm_dag)
session.flush()
query, _ = DagModel.dags_needing_dagruns(session)
needed = query.all()
assert needed == [orm_dag]
orm_dag.has_import_errors = True
session.merge(orm_dag)
session.flush()
query, _ = DagModel.dags_needing_dagruns(session)
needed = query.all()
assert needed == []
@pytest.mark.parametrize(
("fileloc", "expected_relative"),
[
(os.path.join(settings.DAGS_FOLDER, "a.py"), Path("a.py")),
("/tmp/foo.py", Path("/tmp/foo.py")),
],
)
def test_relative_fileloc(self, fileloc, expected_relative):
dag = DAG(dag_id="test")
dag.fileloc = fileloc
assert dag.relative_fileloc == expected_relative
@pytest.mark.parametrize(
"reader_dags_folder", [settings.DAGS_FOLDER, str(repo_root / "airflow/example_dags")]
)
@pytest.mark.parametrize(
("fileloc", "expected_relative"),
[
(str(Path(settings.DAGS_FOLDER, "a.py")), Path("a.py")),
("/tmp/foo.py", Path("/tmp/foo.py")),
],
)
def test_relative_fileloc_serialized(
self, fileloc, expected_relative, session, clear_dags, reader_dags_folder
):
"""
The serialized DAG model records the dags folder configured on the process
that serialized the DAG. When deserializing, the relative fileloc should be
computed against that serializer-side dags folder, not the reader's. So even
if the reader's dags folder differs (meaning the full path is no longer
relative to it), we should still get the relative fileloc as it existed on
the serializer process. When the full path was never relative to the
configured dags folder, relative fileloc is simply the full path.
"""
dag = DAG(dag_id="test")
dag.fileloc = fileloc
sdm = SerializedDagModel(dag)
session.add(sdm)
session.commit()
session.expunge_all()
sdm = SerializedDagModel.get(dag.dag_id, session)
dag = sdm.dag
with conf_vars({("core", "dags_folder"): reader_dags_folder}):
assert dag.relative_fileloc == expected_relative
def test__processor_dags_folder(self, session):
"""Only populated after deserializtion"""
dag = DAG(dag_id="test")
dag.fileloc = "/abc/test.py"
assert dag._processor_dags_folder is None
sdm = SerializedDagModel(dag)
assert sdm.dag._processor_dags_folder == settings.DAGS_FOLDER
@pytest.mark.need_serialized_dag
def test_dags_needing_dagruns_dataset_triggered_dag_info_queued_times(self, session, dag_maker):
dataset1 = Dataset(uri="ds1")
dataset2 = Dataset(uri="ds2")
for dag_id, dataset in [("datasets-1", dataset1), ("datasets-2", dataset2)]:
with dag_maker(dag_id=dag_id, start_date=timezone.utcnow(), session=session):
EmptyOperator(task_id="task", outlets=[dataset])
dr = dag_maker.create_dagrun()
ds_id = session.query(DatasetModel.id).filter_by(uri=dataset.uri).scalar()
session.add(
DatasetEvent(
dataset_id=ds_id,
source_task_id="task",
source_dag_id=dr.dag_id,
source_run_id=dr.run_id,
source_map_index=-1,
)
)
ds1_id = session.query(DatasetModel.id).filter_by(uri=dataset1.uri).scalar()
ds2_id = session.query(DatasetModel.id).filter_by(uri=dataset2.uri).scalar()
with dag_maker(dag_id="datasets-consumer-multiple", schedule=[dataset1, dataset2]) as dag:
pass
session.flush()
session.add_all(
[
DatasetDagRunQueue(dataset_id=ds1_id, target_dag_id=dag.dag_id, created_at=DEFAULT_DATE),
DatasetDagRunQueue(
dataset_id=ds2_id, target_dag_id=dag.dag_id, created_at=DEFAULT_DATE + timedelta(hours=1)
),
]
)
session.flush()
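# Besides the query, dags_needing_dagruns returns the first/last queued times per dataset-triggered DAG.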
query, dataset_triggered_dag_info = DagModel.dags_needing_dagruns(session)
assert 1 == len(dataset_triggered_dag_info)
assert dag.dag_id in dataset_triggered_dag_info
first_queued_time, last_queued_time = dataset_triggered_dag_info[dag.dag_id]
assert first_queued_time == DEFAULT_DATE
assert last_queued_time == DEFAULT_DATE + timedelta(hours=1)
class TestQueries:
def setup_method(self) -> None:
clear_db_runs()
def teardown_method(self) -> None:
clear_db_runs()
@pytest.mark.parametrize("tasks_count", [3, 12])
def test_count_number_queries(self, tasks_count):
dag = DAG("test_dagrun_query_count", start_date=DEFAULT_DATE)
for i in range(tasks_count):
EmptyOperator(task_id=f"dummy_task_{i}", owner="test", dag=dag)
with assert_queries_count(2):
dag.create_dagrun(
run_id="test_dagrun_query_count",
state=State.RUNNING,
execution_date=TEST_DATE,
)
class TestDagDecorator:
DEFAULT_ARGS = {
"owner": "test",
"depends_on_past": True,
"start_date": timezone.utcnow(),
"retries": 1,
"retry_delay": timedelta(minutes=1),
}
DEFAULT_DATE = timezone.datetime(2016, 1, 1)
VALUE = 42
def setup_method(self):
self.operator = None
def teardown_method(self):
clear_db_runs()
def test_fileloc(self):
@dag_decorator(default_args=self.DEFAULT_ARGS)
def noop_pipeline():
...
dag = noop_pipeline()
assert isinstance(dag, DAG)
assert dag.dag_id == "noop_pipeline"
assert dag.fileloc == __file__
def test_set_dag_id(self):
"""Test that checks you can set dag_id from decorator."""
@dag_decorator("test", default_args=self.DEFAULT_ARGS)
def noop_pipeline():
...
dag = noop_pipeline()
assert isinstance(dag, DAG)
assert dag.dag_id == "test"
def test_default_dag_id(self):
"""Test that @dag uses function name as default dag id."""
@dag_decorator(default_args=self.DEFAULT_ARGS)
def noop_pipeline():
...
dag = noop_pipeline()
assert isinstance(dag, DAG)
assert dag.dag_id == "noop_pipeline"
@pytest.mark.parametrize(
argnames=["dag_doc_md", "expected_doc_md"],
argvalues=[
pytest.param("dag docs.", "dag docs.", id="use_dag_doc_md"),
pytest.param(None, "Regular DAG documentation", id="use_dag_docstring"),
],
)
def test_documentation_added(self, dag_doc_md, expected_doc_md):
"""Test that @dag uses function docs as doc_md for DAG object if doc_md is not explicitly set."""
@dag_decorator(default_args=self.DEFAULT_ARGS, doc_md=dag_doc_md)
def noop_pipeline():
"""Regular DAG documentation"""
dag = noop_pipeline()
assert isinstance(dag, DAG)
assert dag.dag_id == "noop_pipeline"
assert dag.doc_md == expected_doc_md
def test_documentation_template_rendered(self):
"""Test that @dag uses function docs as doc_md for DAG object"""
@dag_decorator(default_args=self.DEFAULT_ARGS)
def noop_pipeline():
"""
{% if True %}
Regular DAG documentation
{% endif %}
"""
dag = noop_pipeline()
assert isinstance(dag, DAG)
assert dag.dag_id == "noop_pipeline"
assert "Regular DAG documentation" in dag.doc_md
def test_resolve_documentation_template_file_rendered(self):
"""Test that @dag uses function docs as doc_md for DAG object"""
with NamedTemporaryFile(suffix=".md") as f:
f.write(
b"""
{% if True %}
External Markdown DAG documentation
{% endif %}
"""
)
f.flush()
template_dir = os.path.dirname(f.name)
template_file = os.path.basename(f.name)
@dag_decorator(
"test-dag", start_date=DEFAULT_DATE, template_searchpath=template_dir, doc_md=template_file
)
def markdown_docs():
...
dag = markdown_docs()
assert isinstance(dag, DAG)
assert dag.dag_id == "test-dag"
assert dag.doc_md.strip() == "External Markdown DAG documentation"
def test_fails_if_arg_not_set(self):
"""Test that @dag decorated function fails if positional argument is not set"""
@dag_decorator(default_args=self.DEFAULT_ARGS)
def noop_pipeline(value):
@task_decorator
def return_num(num):
return num
return_num(value)
# Test that if arg is not passed it raises a type error as expected.
with pytest.raises(TypeError):
noop_pipeline()
def test_dag_param_resolves(self):
"""Test that dag param is correctly resolved by operator"""
@dag_decorator(default_args=self.DEFAULT_ARGS)
def xcom_pass_to_op(value=self.VALUE):
@task_decorator
def return_num(num):
return num
xcom_arg = return_num(value)
self.operator = xcom_arg.operator
dag = xcom_pass_to_op()
dr = dag.create_dagrun(
run_id=DagRunType.MANUAL.value,
start_date=timezone.utcnow(),
execution_date=self.DEFAULT_DATE,
data_interval=(self.DEFAULT_DATE, self.DEFAULT_DATE),
state=State.RUNNING,
)
self.operator.run(start_date=self.DEFAULT_DATE, end_date=self.DEFAULT_DATE)
ti = dr.get_task_instances()[0]
assert ti.xcom_pull() == self.VALUE
def test_dag_param_dagrun_parameterized(self):
"""Test that dag param is correctly overwritten when set in dag run"""
@dag_decorator(default_args=self.DEFAULT_ARGS)
def xcom_pass_to_op(value=self.VALUE):
@task_decorator
def return_num(num):
return num
assert isinstance(value, DagParam)
xcom_arg = return_num(value)
self.operator = xcom_arg.operator
dag = xcom_pass_to_op()
new_value = 52
dr = dag.create_dagrun(
run_id=DagRunType.MANUAL.value,
start_date=timezone.utcnow(),
execution_date=self.DEFAULT_DATE,
data_interval=(self.DEFAULT_DATE, self.DEFAULT_DATE),
state=State.RUNNING,
conf={"value": new_value},
)
self.operator.run(start_date=self.DEFAULT_DATE, end_date=self.DEFAULT_DATE)
ti = dr.get_task_instances()[0]
assert ti.xcom_pull() == new_value
@pytest.mark.parametrize("value", [VALUE, 0])
def test_set_params_for_dag(self, value):
"""Test that dag param is correctly set when using dag decorator"""
@dag_decorator(default_args=self.DEFAULT_ARGS)
def xcom_pass_to_op(value=value):
@task_decorator
def return_num(num):
return num
xcom_arg = return_num(value)
self.operator = xcom_arg.operator
dag = xcom_pass_to_op()
assert dag.params["value"] == value
def test_warning_location(self):
# NOTE: This only works as long as there is some warning we can emit from `DAG()`
@dag_decorator(schedule_interval=None)
def mydag():
...
with pytest.warns(RemovedInAirflow3Warning) as warnings:
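# Capture the line of the mydag() call below; the warning should be attributed to the caller, not DAG internals.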
line = sys._getframe().f_lineno + 1
mydag()
w = warnings.pop(RemovedInAirflow3Warning)
assert w.filename == __file__
assert w.lineno == line
@pytest.mark.parametrize("timetable", [NullTimetable(), OnceTimetable()])
def test_dag_timetable_match_schedule_interval(timetable):
dag = DAG("my-dag", timetable=timetable)
assert dag._check_schedule_interval_matches_timetable()
@pytest.mark.parametrize("schedule_interval", [None, "@once", "@daily", timedelta(days=1)])
def test_dag_schedule_interval_match_timetable(schedule_interval):
dag = DAG("my-dag", schedule=schedule_interval)
assert dag._check_schedule_interval_matches_timetable()
@pytest.mark.parametrize("schedule_interval", [None, "@daily", timedelta(days=1)])
def test_dag_schedule_interval_change_after_init(schedule_interval):
dag = DAG("my-dag", timetable=OnceTimetable())
dag.schedule_interval = schedule_interval
assert not dag._check_schedule_interval_matches_timetable()
@pytest.mark.parametrize("timetable", [NullTimetable(), OnceTimetable()])
def test_dag_timetable_change_after_init(timetable):
dag = DAG("my-dag") # Default is timedelta(days=1).
dag.timetable = timetable
assert not dag._check_schedule_interval_matches_timetable()
@pytest.mark.parametrize("run_id, execution_date", [(None, datetime_tz(2020, 1, 1)), ("test-run-id", None)])
def test_set_task_instance_state(run_id, execution_date, session, dag_maker):
"""Test that set_task_instance_state updates the TaskInstance state and clear downstream failed"""
start_date = datetime_tz(2020, 1, 1)
with dag_maker("test_set_task_instance_state", start_date=start_date, session=session) as dag:
task_1 = EmptyOperator(task_id="task_1")
task_2 = EmptyOperator(task_id="task_2")
task_3 = EmptyOperator(task_id="task_3")
task_4 = EmptyOperator(task_id="task_4")
task_5 = EmptyOperator(task_id="task_5")
task_1 >> [task_2, task_3, task_4, task_5]
dagrun = dag_maker.create_dagrun(
run_id=run_id,
execution_date=execution_date,
state=State.FAILED,
run_type=DagRunType.SCHEDULED,
)
def get_ti_from_db(task):
return (
session.query(TI)
.filter(
TI.dag_id == dag.dag_id,
TI.task_id == task.task_id,
TI.run_id == dagrun.run_id,
)
.one()
)
get_ti_from_db(task_1).state = State.FAILED
get_ti_from_db(task_2).state = State.SUCCESS
get_ti_from_db(task_3).state = State.UPSTREAM_FAILED
get_ti_from_db(task_4).state = State.FAILED
get_ti_from_db(task_5).state = State.SKIPPED
session.flush()
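# Marking task_1 SUCCESS should also clear its FAILED/UPSTREAM_FAILED downstreams so they can re-run.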
altered = dag.set_task_instance_state(
task_id=task_1.task_id,
run_id=run_id,
execution_date=execution_date,
state=State.SUCCESS,
session=session,
)
# After set_task_instance_state, task_1 is marked as SUCCESS
ti1 = get_ti_from_db(task_1)
assert ti1.state == State.SUCCESS
# TIs should have DagRun pre-loaded
assert isinstance(inspect(ti1).attrs.dag_run.loaded_value, DagRun)
# task_2 remains as SUCCESS
assert get_ti_from_db(task_2).state == State.SUCCESS
# task_3 and task_4 are cleared because they were in FAILED/UPSTREAM_FAILED state
assert get_ti_from_db(task_3).state == State.NONE
assert get_ti_from_db(task_4).state == State.NONE
# task_5 remains as SKIPPED
assert get_ti_from_db(task_5).state == State.SKIPPED
dagrun.refresh_from_db(session=session)
# dagrun should be set to QUEUED
assert dagrun.get_state() == State.QUEUED
assert {t.key for t in altered} == {("test_set_task_instance_state", "task_1", dagrun.run_id, 1, -1)}
def test_set_task_instance_state_mapped(dag_maker, session):
"""Test that when setting an individual mapped TI that the other TIs are not affected"""
task_id = "t1"
with dag_maker(session=session) as dag:
@dag.task
def make_arg_lists():
return [[1], [2], [{"a": "b"}]]
def consumer(value):
print(value)
mapped = PythonOperator.partial(task_id=task_id, dag=dag, python_callable=consumer).expand(
op_args=make_arg_lists()
)
mapped >> BaseOperator(task_id="downstream")
dr1 = dag_maker.create_dagrun(
run_type=DagRunType.SCHEDULED,
state=DagRunState.FAILED,
)
expand_mapped_task(mapped, dr1.run_id, "make_arg_lists", length=2, session=session)
# set_state(future=True) only applies to scheduled runs
dr2 = dag_maker.create_dagrun(
run_type=DagRunType.SCHEDULED,
state=DagRunState.FAILED,
execution_date=DEFAULT_DATE + datetime.timedelta(days=1),
)
expand_mapped_task(mapped, dr2.run_id, "make_arg_lists", length=2, session=session)
session.query(TI).filter_by(dag_id=dag.dag_id).update({"state": TaskInstanceState.FAILED})
ti_query = (
session.query(TI.task_id, TI.map_index, TI.run_id, TI.state)
.filter(TI.dag_id == dag.dag_id, TI.task_id.in_([task_id, "downstream"]))
.order_by(TI.run_id, TI.task_id, TI.map_index)
)
# Check pre-conditions
assert ti_query.all() == [
("downstream", -1, dr1.run_id, TaskInstanceState.FAILED),
(task_id, 0, dr1.run_id, TaskInstanceState.FAILED),
(task_id, 1, dr1.run_id, TaskInstanceState.FAILED),
("downstream", -1, dr2.run_id, TaskInstanceState.FAILED),
(task_id, 0, dr2.run_id, TaskInstanceState.FAILED),
(task_id, 1, dr2.run_id, TaskInstanceState.FAILED),
]
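# Set only map index 1 to SUCCESS with future=True; map index 0 stays FAILED and downstreams are cleared in both runs.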
dag.set_task_instance_state(
task_id=task_id,
map_indexes=[1],
future=True,
run_id=dr1.run_id,
state=TaskInstanceState.SUCCESS,
session=session,
)
assert dr1 in session, "Check session is passed down all the way"
assert ti_query.all() == [
("downstream", -1, dr1.run_id, None),
(task_id, 0, dr1.run_id, TaskInstanceState.FAILED),
(task_id, 1, dr1.run_id, TaskInstanceState.SUCCESS),
("downstream", -1, dr2.run_id, None),
(task_id, 0, dr2.run_id, TaskInstanceState.FAILED),
(task_id, 1, dr2.run_id, TaskInstanceState.SUCCESS),
]
@pytest.mark.parametrize("run_id, execution_date", [(None, datetime_tz(2020, 1, 1)), ("test-run-id", None)])
def test_set_task_group_state(run_id, execution_date, session, dag_maker):
"""Test that set_task_group_state updates the TaskGroup state and clear downstream failed"""
start_date = datetime_tz(2020, 1, 1)
with dag_maker("test_set_task_group_state", start_date=start_date, session=session) as dag:
start = EmptyOperator(task_id="start")
with TaskGroup("section_1", tooltip="Tasks for section_1") as section_1:
task_1 = EmptyOperator(task_id="task_1")
task_2 = EmptyOperator(task_id="task_2")
task_3 = EmptyOperator(task_id="task_3")
task_1 >> [task_2, task_3]
task_4 = EmptyOperator(task_id="task_4")
task_5 = EmptyOperator(task_id="task_5")
task_6 = EmptyOperator(task_id="task_6")
task_7 = EmptyOperator(task_id="task_7")
task_8 = EmptyOperator(task_id="task_8")
start >> section_1 >> [task_4, task_5, task_6, task_7, task_8]
dagrun = dag_maker.create_dagrun(
run_id=run_id,
execution_date=execution_date,
state=State.FAILED,
run_type=DagRunType.SCHEDULED,
)
def get_ti_from_db(task):
return (
session.query(TI)
.filter(
TI.dag_id == dag.dag_id,
TI.task_id == task.task_id,
TI.run_id == dagrun.run_id,
)
.one()
)
get_ti_from_db(task_1).state = State.FAILED
get_ti_from_db(task_2).state = State.SUCCESS
get_ti_from_db(task_3).state = State.UPSTREAM_FAILED
get_ti_from_db(task_4).state = State.SUCCESS
get_ti_from_db(task_5).state = State.UPSTREAM_FAILED
get_ti_from_db(task_6).state = State.FAILED
get_ti_from_db(task_7).state = State.SKIPPED
session.flush()
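# Marking the group SUCCESS should mark its failed members SUCCESS and clear failed tasks downstream of the group.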
altered = dag.set_task_group_state(
group_id=section_1.group_id,
run_id=run_id,
execution_date=execution_date,
state=State.SUCCESS,
session=session,
)
# After set_task_group_state, task_1 is marked as SUCCESS
assert get_ti_from_db(task_1).state == State.SUCCESS
# task_2 remains as SUCCESS
assert get_ti_from_db(task_2).state == State.SUCCESS
# task_3 should be marked as SUCCESS
assert get_ti_from_db(task_3).state == State.SUCCESS
# task_4 should remain as SUCCESS
assert get_ti_from_db(task_4).state == State.SUCCESS
# task_5 and task_6 are cleared because they were in FAILED/UPSTREAM_FAILED state
assert get_ti_from_db(task_5).state == State.NONE
assert get_ti_from_db(task_6).state == State.NONE
# task_7 remains as SKIPPED
assert get_ti_from_db(task_7).state == State.SKIPPED
dagrun.refresh_from_db(session=session)
# dagrun should be set to QUEUED
assert dagrun.get_state() == State.QUEUED
assert {t.key for t in altered} == {
("test_set_task_group_state", "section_1.task_1", dagrun.run_id, 1, -1),
("test_set_task_group_state", "section_1.task_3", dagrun.run_id, 1, -1),
}
def test_dag_teardowns_property_lists_all_teardown_tasks(dag_maker):
@setup
def setup_task():
return 1
@teardown
def teardown_task():
return 1
@teardown
def teardown_task2():
return 1
@teardown
def teardown_task3():
return 1
@task_decorator
def mytask():
return 1
with dag_maker() as dag:
t1 = setup_task()
t2 = teardown_task()
t3 = teardown_task2()
t4 = teardown_task3()
with t1 >> t2:
with t3:
with t4:
mytask()
assert {t.task_id for t in dag.teardowns} == {"teardown_task", "teardown_task2", "teardown_task3"}
assert {t.task_id for t in dag.tasks_upstream_of_teardowns} == {"setup_task", "mytask"}
@pytest.mark.parametrize(
"start_date, expected_infos",
[
(
DEFAULT_DATE,
[DagRunInfo.interval(DEFAULT_DATE, DEFAULT_DATE + datetime.timedelta(hours=1))],
),
(
DEFAULT_DATE - datetime.timedelta(hours=3),
[
DagRunInfo.interval(
DEFAULT_DATE - datetime.timedelta(hours=3),
DEFAULT_DATE - datetime.timedelta(hours=2),
),
DagRunInfo.interval(
DEFAULT_DATE - datetime.timedelta(hours=2),
DEFAULT_DATE - datetime.timedelta(hours=1),
),
DagRunInfo.interval(
DEFAULT_DATE - datetime.timedelta(hours=1),
DEFAULT_DATE,
),
DagRunInfo.interval(
DEFAULT_DATE,
DEFAULT_DATE + datetime.timedelta(hours=1),
),
],
),
],
ids=["in-dag-restriction", "out-of-dag-restriction"],
)
def test_iter_dagrun_infos_between(start_date, expected_infos):
dag = DAG(dag_id="test_get_dates", start_date=DEFAULT_DATE, schedule="@hourly")
EmptyOperator(task_id="dummy", dag=dag)
iterator = dag.iter_dagrun_infos_between(
earliest=pendulum.instance(start_date),
latest=pendulum.instance(DEFAULT_DATE),
align=True,
)
assert expected_infos == list(iterator)
def test_iter_dagrun_infos_between_error(caplog):
start = pendulum.instance(DEFAULT_DATE - datetime.timedelta(hours=1))
end = pendulum.instance(DEFAULT_DATE)
class FailingAfterOneTimetable(Timetable):
def next_dagrun_info(self, last_automated_data_interval, restriction):
if last_automated_data_interval is None:
return DagRunInfo.interval(start, end)
raise RuntimeError("this fails")
dag = DAG(
dag_id="test_iter_dagrun_infos_between_error",
start_date=DEFAULT_DATE,
timetable=FailingAfterOneTimetable(),
)
iterator = dag.iter_dagrun_infos_between(earliest=start, latest=end, align=True)
with caplog.at_level(logging.ERROR):
infos = list(iterator)
# The second timetable.next_dagrun_info() call raises an exception, so only the first result is returned.
assert infos == [DagRunInfo.interval(start, end)]
assert caplog.record_tuples == [
(
"airflow.models.dag.DAG",
logging.ERROR,
f"Failed to fetch run info after data interval {DataInterval(start, end)} for DAG {dag.dag_id!r}",
),
]
assert caplog.records[0].exc_info is not None, "should contain exception context"
@pytest.mark.parametrize(
"logical_date, data_interval_start, data_interval_end, expected_data_interval",
[
pytest.param(None, None, None, None, id="no-next-run"),
pytest.param(
DEFAULT_DATE,
DEFAULT_DATE,
DEFAULT_DATE + timedelta(days=2),
DataInterval(DEFAULT_DATE, DEFAULT_DATE + timedelta(days=2)),
id="modern",
),
pytest.param(
DEFAULT_DATE,
None,
None,
DataInterval(DEFAULT_DATE, DEFAULT_DATE + timedelta(days=1)),
id="legacy",
),
],
)
def test_get_next_data_interval(
logical_date,
data_interval_start,
data_interval_end,
expected_data_interval,
):
dag = DAG(dag_id="test_get_next_data_interval", schedule="@daily")
dag_model = DagModel(
dag_id="test_get_next_data_interval",
next_dagrun=logical_date,
next_dagrun_data_interval_start=data_interval_start,
next_dagrun_data_interval_end=data_interval_end,
)
assert dag.get_next_data_interval(dag_model) == expected_data_interval
@pytest.mark.parametrize(
("dag_date", "tasks_date", "restrict"),
[
[
(DEFAULT_DATE, None),
[
(DEFAULT_DATE + timedelta(days=1), DEFAULT_DATE + timedelta(days=2)),
(DEFAULT_DATE + timedelta(days=3), DEFAULT_DATE + timedelta(days=4)),
],
TimeRestriction(DEFAULT_DATE, DEFAULT_DATE + timedelta(days=4), True),
],
[
(DEFAULT_DATE, None),
[(DEFAULT_DATE, DEFAULT_DATE + timedelta(days=1)), (DEFAULT_DATE, None)],
TimeRestriction(DEFAULT_DATE, None, True),
],
],
)
def test__time_restriction(dag_maker, dag_date, tasks_date, restrict):
with dag_maker("test__time_restriction", start_date=dag_date[0], end_date=dag_date[1]) as dag:
EmptyOperator(task_id="do1", start_date=tasks_date[0][0], end_date=tasks_date[0][1])
EmptyOperator(task_id="do2", start_date=tasks_date[1][0], end_date=tasks_date[1][1])
assert dag._time_restriction == restrict
@pytest.mark.parametrize(
"tags, should_pass",
[
pytest.param([], True, id="empty tags"),
pytest.param(["a normal tag"], True, id="one tag"),
pytest.param(["a normal tag", "another normal tag"], True, id="two tags"),
pytest.param(["a" * 100], True, id="a tag that's of just length 100"),
pytest.param(["a normal tag", "a" * 101], False, id="two tags and one of them is of length > 100"),
],
)
def test__tags_length(tags: list[str], should_pass: bool):
if should_pass:
models.DAG("test-dag", tags=tags)
else:
with pytest.raises(AirflowException):
models.DAG("test-dag", tags=tags)
@pytest.mark.need_serialized_dag
def test_get_dataset_triggered_next_run_info(dag_maker, clear_datasets):
dataset1 = Dataset(uri="ds1")
dataset2 = Dataset(uri="ds2")
dataset3 = Dataset(uri="ds3")
with dag_maker(dag_id="datasets-1", schedule=[dataset2]):
pass
dag1 = dag_maker.dag
with dag_maker(dag_id="datasets-2", schedule=[dataset1, dataset2]):
pass
dag2 = dag_maker.dag
with dag_maker(dag_id="datasets-3", schedule=[dataset1, dataset2, dataset3]):
pass
dag3 = dag_maker.dag
session = dag_maker.session
ds1_id = session.query(DatasetModel.id).filter_by(uri=dataset1.uri).scalar()
session.bulk_save_objects(
[
DatasetDagRunQueue(dataset_id=ds1_id, target_dag_id=dag2.dag_id),
DatasetDagRunQueue(dataset_id=ds1_id, target_dag_id=dag3.dag_id),
]
)
session.flush()
datasets = session.query(DatasetModel.uri).order_by(DatasetModel.id).all()
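# For a single-dataset schedule the dataset URI is reported; multi-dataset DAGs report only ready/total counts.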
info = get_dataset_triggered_next_run_info([dag1.dag_id], session=session)
assert info[dag1.dag_id] == {
"ready": 0,
"total": 1,
"uri": datasets[0].uri,
}
# This time, check both dag2 and dag3 at the same time (tests filtering)
info = get_dataset_triggered_next_run_info([dag2.dag_id, dag3.dag_id], session=session)
assert info[dag2.dag_id] == {
"ready": 1,
"total": 2,
"uri": "",
}
assert info[dag3.dag_id] == {
"ready": 1,
"total": 3,
"uri": "",
}
def test_dag_uses_timetable_for_run_id(session):
class CustomRunIdTimetable(Timetable):
def generate_run_id(self, *, run_type, logical_date, data_interval, **extra) -> str:
return "abc"
dag = DAG(dag_id="test", start_date=DEFAULT_DATE, schedule=CustomRunIdTimetable())
dag_run = dag.create_dagrun(
run_type=DagRunType.MANUAL,
state=DagRunState.QUEUED,
execution_date=DEFAULT_DATE,
data_interval=(DEFAULT_DATE, DEFAULT_DATE),
)
assert dag_run.run_id == "abc"
@pytest.mark.parametrize(
"run_id_type",
[DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED, DagRunType.DATASET_TRIGGERED],
)
def test_create_dagrun_disallow_manual_to_use_automated_run_id(run_id_type: DagRunType) -> None:
dag = DAG(dag_id="test", start_date=DEFAULT_DATE, schedule="@daily")
run_id = run_id_type.generate_run_id(DEFAULT_DATE)
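# e.g. "backfill__<logical date>"; such prefixes are reserved for automated runs, so MANUAL may not use them.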
with pytest.raises(ValueError) as ctx:
dag.create_dagrun(
run_type=DagRunType.MANUAL,
run_id=run_id,
execution_date=DEFAULT_DATE,
data_interval=(DEFAULT_DATE, DEFAULT_DATE),
state=DagRunState.QUEUED,
)
assert str(ctx.value) == (
f"A manual DAG run cannot use ID {run_id!r} since it is reserved for {run_id_type.value} runs"
)