#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import io
import logging
import os
import pickle
import re
import unittest
from contextlib import redirect_stdout
from datetime import timedelta
from tempfile import NamedTemporaryFile
from typing import Optional
from unittest import mock
from unittest.mock import patch
import pendulum
import pytest
from dateutil.relativedelta import relativedelta
from freezegun import freeze_time
from parameterized import parameterized
from airflow import models, settings
from airflow.configuration import conf
from airflow.decorators import task as task_decorator
from airflow.exceptions import AirflowException, DuplicateTaskIdFound
from airflow.models import DAG, DagModel, DagRun, DagTag, TaskFail, TaskInstance as TI
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import dag as dag_decorator
from airflow.models.dagparam import DagParam
from airflow.operators.bash import BashOperator
from airflow.operators.dummy import DummyOperator
from airflow.operators.subdag import SubDagOperator
from airflow.security import permissions
from airflow.utils import timezone
from airflow.utils.file import list_py_file_paths
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import State
from airflow.utils.timezone import datetime as datetime_tz
from airflow.utils.types import DagRunType
from airflow.utils.weight_rule import WeightRule
from tests.models import DEFAULT_DATE
from tests.test_utils.asserts import assert_queries_count
from tests.test_utils.db import clear_db_dags, clear_db_runs
TEST_DATE = datetime_tz(2015, 1, 2, 0, 0)
class TestDag(unittest.TestCase):
def setUp(self) -> None:
clear_db_runs()
clear_db_dags()
self.patcher_dag_code = patch.object(settings, "STORE_DAG_CODE", False)
self.patcher_dag_code.start()
def tearDown(self) -> None:
clear_db_runs()
clear_db_dags()
self.patcher_dag_code.stop()
@staticmethod
def _clean_up(dag_id: str):
with create_session() as session:
session.query(DagRun).filter(DagRun.dag_id == dag_id).delete(synchronize_session=False)
session.query(TI).filter(TI.dag_id == dag_id).delete(synchronize_session=False)
session.query(TaskFail).filter(TaskFail.dag_id == dag_id).delete(synchronize_session=False)
@staticmethod
def _occur_before(a, b, list_):
"""
Assert that a occurs before b in the list.
"""
a_index = -1
b_index = -1
for i, e in enumerate(list_):
if e.task_id == a:
a_index = i
if e.task_id == b:
b_index = i
return 0 <= a_index < b_index
def test_params_not_passed_is_empty_dict(self):
"""
Test that when 'params' is _not_ passed to a new Dag, that the params
attribute is set to an empty dictionary.
"""
dag = models.DAG('test-dag')
assert isinstance(dag.params, dict)
assert 0 == len(dag.params)
def test_params_passed_and_params_in_default_args_no_override(self):
"""
Test that when 'params' exists as a key passed to the default_args dict
in addition to params being passed explicitly as an argument to the
dag, that the 'params' key of the default_args dict is merged with the
dict of the params argument.
"""
params1 = {'parameter1': 1}
params2 = {'parameter2': 2}
dag = models.DAG('test-dag', default_args={'params': params1}, params=params2)
params_combined = params1.copy()
params_combined.update(params2)
assert params_combined == dag.params
def test_dag_invalid_default_view(self):
"""
Test invalid `default_view` of DAG initialization
"""
with pytest.raises(AirflowException, match='Invalid values of dag.default_view: only support'):
models.DAG(dag_id='test-invalid-default_view', default_view='airflow')
def test_dag_default_view_default_value(self):
"""
Test `default_view` default value of DAG initialization
"""
dag = models.DAG(dag_id='test-default_default_view')
assert conf.get('webserver', 'dag_default_view').lower() == dag.default_view
def test_dag_invalid_orientation(self):
"""
Test invalid `orientation` of DAG initialization
"""
with pytest.raises(AirflowException, match='Invalid values of dag.orientation: only support'):
models.DAG(dag_id='test-invalid-orientation', orientation='airflow')
def test_dag_orientation_default_value(self):
"""
Test `orientation` default value of DAG initialization
"""
dag = models.DAG(dag_id='test-default_orientation')
assert conf.get('webserver', 'dag_orientation') == dag.orientation
def test_dag_as_context_manager(self):
"""
Test DAG as a context manager.
When used as a context manager, Operators are automatically added to
the DAG (unless they specify a different DAG)
"""
dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
dag2 = DAG('dag2', start_date=DEFAULT_DATE, default_args={'owner': 'owner2'})
with dag:
op1 = DummyOperator(task_id='op1')
op2 = DummyOperator(task_id='op2', dag=dag2)
assert op1.dag is dag
assert op1.owner == 'owner1'
assert op2.dag is dag2
assert op2.owner == 'owner2'
with dag2:
op3 = DummyOperator(task_id='op3')
assert op3.dag is dag2
assert op3.owner == 'owner2'
with dag:
with dag2:
op4 = DummyOperator(task_id='op4')
op5 = DummyOperator(task_id='op5')
assert op4.dag is dag2
assert op5.dag is dag
assert op4.owner == 'owner2'
assert op5.owner == 'owner1'
with DAG('creating_dag_in_cm', start_date=DEFAULT_DATE) as dag:
DummyOperator(task_id='op6')
assert dag.dag_id == 'creating_dag_in_cm'
assert dag.tasks[0].task_id == 'op6'
with dag:
with dag:
op7 = DummyOperator(task_id='op7')
op8 = DummyOperator(task_id='op8')
op9 = DummyOperator(task_id='op8')
op9.dag = dag2
assert op7.dag == dag
assert op8.dag == dag
assert op9.dag == dag2
def test_dag_topological_sort_include_subdag_tasks(self):
child_dag = DAG(
'parent_dag.child_dag',
schedule_interval='@daily',
start_date=DEFAULT_DATE,
)
with child_dag:
DummyOperator(task_id='a_child')
DummyOperator(task_id='b_child')
parent_dag = DAG(
'parent_dag',
schedule_interval='@daily',
start_date=DEFAULT_DATE,
)
# a_parent -> child_dag -> (a_child | b_child) -> b_parent
with parent_dag:
op1 = DummyOperator(task_id='a_parent')
op2 = SubDagOperator(task_id='child_dag', subdag=child_dag)
op3 = DummyOperator(task_id='b_parent')
op1 >> op2 >> op3
topological_list = parent_dag.topological_sort(include_subdag_tasks=True)
assert self._occur_before('a_parent', 'child_dag', topological_list)
assert self._occur_before('child_dag', 'a_child', topological_list)
assert self._occur_before('child_dag', 'b_child', topological_list)
assert self._occur_before('a_child', 'b_parent', topological_list)
assert self._occur_before('b_child', 'b_parent', topological_list)
def test_dag_topological_sort1(self):
dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
# A -> B
# A -> C -> D
# ordered: B, D, C, A or D, B, C, A or D, C, B, A
with dag:
op1 = DummyOperator(task_id='A')
op2 = DummyOperator(task_id='B')
op3 = DummyOperator(task_id='C')
op4 = DummyOperator(task_id='D')
op1.set_upstream([op2, op3])
op3.set_upstream(op4)
topological_list = dag.topological_sort()
logging.info(topological_list)
tasks = [op2, op3, op4]
assert topological_list[0] in tasks
tasks.remove(topological_list[0])
assert topological_list[1] in tasks
tasks.remove(topological_list[1])
assert topological_list[2] in tasks
tasks.remove(topological_list[2])
assert topological_list[3] == op1
def test_dag_topological_sort2(self):
dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
# C -> (A u B) -> D
# C -> E
# ordered: E | D, A | B, C
with dag:
op1 = DummyOperator(task_id='A')
op2 = DummyOperator(task_id='B')
op3 = DummyOperator(task_id='C')
op4 = DummyOperator(task_id='D')
op5 = DummyOperator(task_id='E')
op1.set_downstream(op3)
op2.set_downstream(op3)
op1.set_upstream(op4)
op2.set_upstream(op4)
op5.set_downstream(op3)
topological_list = dag.topological_sort()
logging.info(topological_list)
set1 = [op4, op5]
assert topological_list[0] in set1
set1.remove(topological_list[0])
set2 = [op1, op2]
set2.extend(set1)
assert topological_list[1] in set2
set2.remove(topological_list[1])
assert topological_list[2] in set2
set2.remove(topological_list[2])
assert topological_list[3] in set2
assert topological_list[4] == op3
def test_dag_topological_sort_dag_without_tasks(self):
dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
assert () == dag.topological_sort()
def test_dag_naive_start_date_string(self):
DAG('DAG', default_args={'start_date': '2019-06-01'})
def test_dag_naive_start_end_dates_strings(self):
DAG('DAG', default_args={'start_date': '2019-06-01', 'end_date': '2019-06-05'})
def test_dag_start_date_propagates_to_end_date(self):
"""
Tests that a start_date string with a timezone and an end_date string without a timezone
are accepted and that the timezone from the start carries over the end
This test is a little indirect, it works by setting start and end equal except for the
timezone and then testing for equality after the DAG construction. They'll be equal
only if the same timezone was applied to both.
An explicit check the `tzinfo` attributes for both are the same is an extra check.
"""
dag = DAG(
'DAG', default_args={'start_date': '2019-06-05T00:00:00+05:00', 'end_date': '2019-06-05T00:00:00'}
)
assert dag.default_args['start_date'] == dag.default_args['end_date']
assert dag.default_args['start_date'].tzinfo == dag.default_args['end_date'].tzinfo
def test_dag_naive_default_args_start_date(self):
dag = DAG('DAG', default_args={'start_date': datetime.datetime(2018, 1, 1)})
assert dag.timezone == settings.TIMEZONE
dag = DAG('DAG', start_date=datetime.datetime(2018, 1, 1))
assert dag.timezone == settings.TIMEZONE
def test_dag_none_default_args_start_date(self):
"""
Tests if a start_date of None in default_args
works.
"""
dag = DAG('DAG', default_args={'start_date': None})
assert dag.timezone == settings.TIMEZONE
def test_dag_task_priority_weight_total(self):
width = 5
depth = 5
weight = 5
        pattern = re.compile(r'stage(\d*).(\d*)')
# Fully connected parallel tasks. i.e. every task at each parallel
# stage is dependent on every task in the previous stage.
# Default weight should be calculated using downstream descendants
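        # E.g. with width=5, depth=5, weight=5: a stage-0 task has 4 later
        # stages of 5 tasks each, so its total is ((5 - 1) * 5 + 1) * 5 = 105,
        # while a task in the final stage keeps only its own weight of 5.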
with DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) as dag:
pipeline = [
[DummyOperator(task_id=f'stage{i}.{j}', priority_weight=weight) for j in range(0, width)]
for i in range(0, depth)
]
for i, stage in enumerate(pipeline):
if i == 0:
continue
for current_task in stage:
for prev_task in pipeline[i - 1]:
current_task.set_upstream(prev_task)
for task in dag.task_dict.values():
match = pattern.match(task.task_id)
task_depth = int(match.group(1))
            # total = this task's own weight plus the weights of every task in later (downstream) stages
correct_weight = ((depth - (task_depth + 1)) * width + 1) * weight
calculated_weight = task.priority_weight_total
assert calculated_weight == correct_weight
def test_dag_task_priority_weight_total_using_upstream(self):
# Same test as above except use 'upstream' for weight calculation
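        # With the upstream rule a stage-0 task has no ancestors, so its total
        # is just its own weight (3), while a stage-4 task sums the 20 tasks
        # upstream of it plus itself: (4 * 5 + 1) * 3 = 63.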
weight = 3
width = 5
depth = 5
        pattern = re.compile(r'stage(\d*).(\d*)')
with DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) as dag:
pipeline = [
[
DummyOperator(
task_id=f'stage{i}.{j}',
priority_weight=weight,
weight_rule=WeightRule.UPSTREAM,
)
for j in range(0, width)
]
for i in range(0, depth)
]
for i, stage in enumerate(pipeline):
if i == 0:
continue
for current_task in stage:
for prev_task in pipeline[i - 1]:
current_task.set_upstream(prev_task)
for task in dag.task_dict.values():
match = pattern.match(task.task_id)
task_depth = int(match.group(1))
            # total = this task's own weight plus the weights of every task in earlier (upstream) stages
correct_weight = (task_depth * width + 1) * weight
calculated_weight = task.priority_weight_total
assert calculated_weight == correct_weight
def test_dag_task_priority_weight_total_using_absolute(self):
# Same test as above except use 'absolute' for weight calculation
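        # With the absolute rule no aggregation happens: every task's
        # priority_weight_total is simply its own priority_weight (10 here).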
weight = 10
width = 5
depth = 5
with DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) as dag:
pipeline = [
[
DummyOperator(
task_id=f'stage{i}.{j}',
priority_weight=weight,
weight_rule=WeightRule.ABSOLUTE,
)
for j in range(0, width)
]
for i in range(0, depth)
]
for i, stage in enumerate(pipeline):
if i == 0:
continue
for current_task in stage:
for prev_task in pipeline[i - 1]:
current_task.set_upstream(prev_task)
for task in dag.task_dict.values():
            # with the absolute rule, each task keeps its own priority_weight
            correct_weight = weight
calculated_weight = task.priority_weight_total
assert calculated_weight == correct_weight
def test_dag_task_invalid_weight_rule(self):
# Test if we enter an invalid weight rule
with DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}):
with pytest.raises(AirflowException):
DummyOperator(task_id='should_fail', weight_rule='no rule')
def test_get_num_task_instances(self):
test_dag_id = 'test_get_num_task_instances_dag'
test_task_id = 'task_1'
test_dag = DAG(dag_id=test_dag_id, start_date=DEFAULT_DATE)
test_task = DummyOperator(task_id=test_task_id, dag=test_dag)
ti1 = TI(task=test_task, execution_date=DEFAULT_DATE)
ti1.state = None
ti2 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1))
ti2.state = State.RUNNING
ti3 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=2))
ti3.state = State.QUEUED
ti4 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=3))
ti4.state = State.RUNNING
session = settings.Session()
session.merge(ti1)
session.merge(ti2)
session.merge(ti3)
session.merge(ti4)
session.commit()
assert 0 == DAG.get_num_task_instances(test_dag_id, ['fakename'], session=session)
assert 4 == DAG.get_num_task_instances(test_dag_id, [test_task_id], session=session)
assert 4 == DAG.get_num_task_instances(test_dag_id, ['fakename', test_task_id], session=session)
assert 1 == DAG.get_num_task_instances(test_dag_id, [test_task_id], states=[None], session=session)
assert 2 == DAG.get_num_task_instances(
test_dag_id, [test_task_id], states=[State.RUNNING], session=session
)
assert 3 == DAG.get_num_task_instances(
test_dag_id, [test_task_id], states=[None, State.RUNNING], session=session
)
assert 4 == DAG.get_num_task_instances(
test_dag_id, [test_task_id], states=[None, State.QUEUED, State.RUNNING], session=session
)
session.close()
def test_user_defined_filters(self):
def jinja_udf(name):
return f'Hello {name}'
dag = models.DAG('test-dag', start_date=DEFAULT_DATE, user_defined_filters={"hello": jinja_udf})
jinja_env = dag.get_template_env()
assert 'hello' in jinja_env.filters
assert jinja_env.filters['hello'] == jinja_udf
def test_resolve_template_files_value(self):
with NamedTemporaryFile(suffix='.template') as f:
f.write(b'{{ ds }}')
f.flush()
template_dir = os.path.dirname(f.name)
template_file = os.path.basename(f.name)
with DAG('test-dag', start_date=DEFAULT_DATE, template_searchpath=template_dir):
task = DummyOperator(task_id='op1')
task.test_field = template_file
task.template_fields = ('test_field',)
task.template_ext = ('.template',)
task.resolve_template_files()
assert task.test_field == '{{ ds }}'
def test_resolve_template_files_list(self):
with NamedTemporaryFile(suffix='.template') as f:
f.write(b'{{ ds }}')
f.flush()
template_dir = os.path.dirname(f.name)
template_file = os.path.basename(f.name)
with DAG('test-dag', start_date=DEFAULT_DATE, template_searchpath=template_dir):
task = DummyOperator(task_id='op1')
task.test_field = [template_file, 'some_string']
task.template_fields = ('test_field',)
task.template_ext = ('.template',)
task.resolve_template_files()
assert task.test_field == ['{{ ds }}', 'some_string']
def test_following_previous_schedule(self):
"""
Make sure DST transitions are properly observed
"""
local_tz = pendulum.timezone('Europe/Zurich')
start = local_tz.convert(datetime.datetime(2018, 10, 28, 2, 55), dst_rule=pendulum.PRE_TRANSITION)
assert start.isoformat() == "2018-10-28T02:55:00+02:00", "Pre-condition: start date is in DST"
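        # Europe/Zurich leaves DST at 03:00 local on 2018-10-28 (clocks fall back
        # to 02:00), so the next 5-minute slot, 01:00 UTC, lands in the repeated
        # hour as 02:00+01:00 local.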
utc = timezone.convert_to_utc(start)
dag = DAG('tz_dag', start_date=start, schedule_interval='*/5 * * * *')
_next = dag.following_schedule(utc)
next_local = local_tz.convert(_next)
assert _next.isoformat() == "2018-10-28T01:00:00+00:00"
assert next_local.isoformat() == "2018-10-28T02:00:00+01:00"
prev = dag.previous_schedule(utc)
prev_local = local_tz.convert(prev)
assert prev_local.isoformat() == "2018-10-28T02:50:00+02:00"
prev = dag.previous_schedule(_next)
prev_local = local_tz.convert(prev)
assert prev_local.isoformat() == "2018-10-28T02:55:00+02:00"
assert prev == utc
def test_following_previous_schedule_daily_dag_cest_to_cet(self):
"""
Make sure DST transitions are properly observed
"""
local_tz = pendulum.timezone('Europe/Zurich')
start = local_tz.convert(datetime.datetime(2018, 10, 27, 3), dst_rule=pendulum.PRE_TRANSITION)
utc = timezone.convert_to_utc(start)
dag = DAG('tz_dag', start_date=start, schedule_interval='0 3 * * *')
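        # Across the fall-back the daily 03:00 local runs are 25 UTC hours apart:
        # 2018-10-27T01:00Z is followed by 2018-10-28T02:00Z.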
prev = dag.previous_schedule(utc)
prev_local = local_tz.convert(prev)
assert prev_local.isoformat() == "2018-10-26T03:00:00+02:00"
assert prev.isoformat() == "2018-10-26T01:00:00+00:00"
_next = dag.following_schedule(utc)
next_local = local_tz.convert(_next)
assert next_local.isoformat() == "2018-10-28T03:00:00+01:00"
assert _next.isoformat() == "2018-10-28T02:00:00+00:00"
prev = dag.previous_schedule(_next)
prev_local = local_tz.convert(prev)
assert prev_local.isoformat() == "2018-10-27T03:00:00+02:00"
assert prev.isoformat() == "2018-10-27T01:00:00+00:00"
def test_following_previous_schedule_daily_dag_cet_to_cest(self):
"""
Make sure DST transitions are properly observed
"""
local_tz = pendulum.timezone('Europe/Zurich')
start = local_tz.convert(datetime.datetime(2018, 3, 25, 2), dst_rule=pendulum.PRE_TRANSITION)
utc = timezone.convert_to_utc(start)
dag = DAG('tz_dag', start_date=start, schedule_interval='0 3 * * *')
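        # 2018-03-25 is the spring-forward date (02:00 -> 03:00 local), so the
        # daily 03:00 runs straddling it are only 23 UTC hours apart:
        # 2018-03-24T02:00Z is followed by 2018-03-25T01:00Z.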
prev = dag.previous_schedule(utc)
prev_local = local_tz.convert(prev)
assert prev_local.isoformat() == "2018-03-24T03:00:00+01:00"
assert prev.isoformat() == "2018-03-24T02:00:00+00:00"
_next = dag.following_schedule(utc)
next_local = local_tz.convert(_next)
assert next_local.isoformat() == "2018-03-25T03:00:00+02:00"
assert _next.isoformat() == "2018-03-25T01:00:00+00:00"
prev = dag.previous_schedule(_next)
prev_local = local_tz.convert(prev)
assert prev_local.isoformat() == "2018-03-24T03:00:00+01:00"
assert prev.isoformat() == "2018-03-24T02:00:00+00:00"
def test_following_schedule_relativedelta(self):
"""
Tests following_schedule a dag with a relativedelta schedule_interval
"""
dag_id = "test_schedule_dag_relativedelta"
delta = relativedelta(hours=+1)
dag = DAG(dag_id=dag_id, schedule_interval=delta)
dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=TEST_DATE))
_next = dag.following_schedule(TEST_DATE)
assert _next.isoformat() == "2015-01-02T01:00:00+00:00"
_next = dag.following_schedule(_next)
assert _next.isoformat() == "2015-01-02T02:00:00+00:00"
def test_previous_schedule_datetime_timezone(self):
# Check that we don't get an AttributeError 'name' for self.timezone
start = datetime.datetime(2018, 3, 25, 2, tzinfo=datetime.timezone.utc)
dag = DAG('tz_dag', start_date=start, schedule_interval='@hourly')
when = dag.previous_schedule(start)
assert when.isoformat() == "2018-03-25T01:00:00+00:00"
def test_following_schedule_datetime_timezone(self):
# Check that we don't get an AttributeError 'name' for self.timezone
start = datetime.datetime(2018, 3, 25, 2, tzinfo=datetime.timezone.utc)
dag = DAG('tz_dag', start_date=start, schedule_interval='@hourly')
when = dag.following_schedule(start)
assert when.isoformat() == "2018-03-25T03:00:00+00:00"
def test_following_schedule_datetime_timezone_utc0530(self):
# Check that we don't get an AttributeError 'name' for self.timezone
class UTC0530(datetime.tzinfo):
"""tzinfo derived concrete class named "+0530" with offset of 19800"""
# can be configured here
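            # 19800 seconds = 5 hours 30 minutes, i.e. UTC+05:30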
_offset = datetime.timedelta(seconds=19800)
_dst = datetime.timedelta(0)
_name = "+0530"
def utcoffset(self, dt):
return self.__class__._offset
def dst(self, dt):
return self.__class__._dst
def tzname(self, dt):
return self.__class__._name
start = datetime.datetime(2018, 3, 25, 10, tzinfo=UTC0530())
dag = DAG('tz_dag', start_date=start, schedule_interval='@hourly')
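        # The hourly schedule is evaluated in the DAG's timezone: start is
        # 10:00 +05:30 (04:30 UTC), so the next run is 11:00 local, i.e.
        # 05:30 UTC rather than the next whole UTC hour.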
when = dag.following_schedule(start)
assert when.isoformat() == "2018-03-25T05:30:00+00:00"
def test_dagtag_repr(self):
clear_db_dags()
dag = DAG('dag-test-dagtag', start_date=DEFAULT_DATE, tags=['tag-1', 'tag-2'])
dag.sync_to_db()
with create_session() as session:
assert {'tag-1', 'tag-2'} == {
repr(t) for t in session.query(DagTag).filter(DagTag.dag_id == 'dag-test-dagtag').all()
}
def test_bulk_write_to_db(self):
clear_db_dags()
dags = [DAG(f'dag-bulk-sync-{i}', start_date=DEFAULT_DATE, tags=["test-dag"]) for i in range(0, 4)]
with assert_queries_count(4):
DAG.bulk_write_to_db(dags)
with create_session() as session:
assert {'dag-bulk-sync-0', 'dag-bulk-sync-1', 'dag-bulk-sync-2', 'dag-bulk-sync-3'} == {
row[0] for row in session.query(DagModel.dag_id).all()
}
assert {
('dag-bulk-sync-0', 'test-dag'),
('dag-bulk-sync-1', 'test-dag'),
('dag-bulk-sync-2', 'test-dag'),
('dag-bulk-sync-3', 'test-dag'),
} == set(session.query(DagTag.dag_id, DagTag.name).all())
for row in session.query(DagModel.last_parsed_time).all():
assert row[0] is not None
# Re-sync should do fewer queries
with assert_queries_count(3):
DAG.bulk_write_to_db(dags)
with assert_queries_count(3):
DAG.bulk_write_to_db(dags)
# Adding tags
for dag in dags:
dag.tags.append("test-dag2")
with assert_queries_count(4):
DAG.bulk_write_to_db(dags)
with create_session() as session:
assert {'dag-bulk-sync-0', 'dag-bulk-sync-1', 'dag-bulk-sync-2', 'dag-bulk-sync-3'} == {
row[0] for row in session.query(DagModel.dag_id).all()
}
assert {
('dag-bulk-sync-0', 'test-dag'),
('dag-bulk-sync-0', 'test-dag2'),
('dag-bulk-sync-1', 'test-dag'),
('dag-bulk-sync-1', 'test-dag2'),
('dag-bulk-sync-2', 'test-dag'),
('dag-bulk-sync-2', 'test-dag2'),
('dag-bulk-sync-3', 'test-dag'),
('dag-bulk-sync-3', 'test-dag2'),
} == set(session.query(DagTag.dag_id, DagTag.name).all())
# Removing tags
for dag in dags:
dag.tags.remove("test-dag")
with assert_queries_count(4):
DAG.bulk_write_to_db(dags)
with create_session() as session:
assert {'dag-bulk-sync-0', 'dag-bulk-sync-1', 'dag-bulk-sync-2', 'dag-bulk-sync-3'} == {
row[0] for row in session.query(DagModel.dag_id).all()
}
assert {
('dag-bulk-sync-0', 'test-dag2'),
('dag-bulk-sync-1', 'test-dag2'),
('dag-bulk-sync-2', 'test-dag2'),
('dag-bulk-sync-3', 'test-dag2'),
} == set(session.query(DagTag.dag_id, DagTag.name).all())
for row in session.query(DagModel.last_parsed_time).all():
assert row[0] is not None
def test_bulk_write_to_db_max_active_runs(self):
"""
Test that DagModel.next_dagrun_create_after is set to NULL when the dag cannot be created due to max
active runs being hit.
"""
dag = DAG(dag_id='test_scheduler_verify_max_active_runs', start_date=DEFAULT_DATE)
dag.max_active_runs = 1
DummyOperator(task_id='dummy', dag=dag, owner='airflow')
session = settings.Session()
dag.clear()
DAG.bulk_write_to_db([dag], session)
model = session.query(DagModel).get((dag.dag_id,))
period_end = dag.following_schedule(DEFAULT_DATE)
assert model.next_dagrun == DEFAULT_DATE
assert model.next_dagrun_create_after == period_end
dr = dag.create_dagrun(
state=State.RUNNING,
execution_date=model.next_dagrun,
run_type=DagRunType.SCHEDULED,
session=session,
)
assert dr is not None
DAG.bulk_write_to_db([dag])
model = session.query(DagModel).get((dag.dag_id,))
assert model.next_dagrun == period_end
# Next dagrun after is not None because the dagrun would be in queued state
assert model.next_dagrun_create_after is not None
def test_sync_to_db(self):
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
)
with dag:
DummyOperator(task_id='task', owner='owner1')
subdag = DAG(
'dag.subtask',
start_date=DEFAULT_DATE,
)
            # parent_dag and is_subdag are normally set by DagBag; we don't use DagBag here, so set them manually.
subdag.parent_dag = dag
subdag.is_subdag = True
SubDagOperator(task_id='subtask', owner='owner2', subdag=subdag)
session = settings.Session()
dag.sync_to_db(session=session)
orm_dag = session.query(DagModel).filter(DagModel.dag_id == 'dag').one()
assert set(orm_dag.owners.split(', ')) == {'owner1', 'owner2'}
assert orm_dag.is_active
assert orm_dag.default_view is not None
assert orm_dag.default_view == conf.get('webserver', 'dag_default_view').lower()
assert orm_dag.safe_dag_id == 'dag'
orm_subdag = session.query(DagModel).filter(DagModel.dag_id == 'dag.subtask').one()
assert set(orm_subdag.owners.split(', ')) == {'owner1', 'owner2'}
assert orm_subdag.is_active
assert orm_subdag.safe_dag_id == 'dag__dot__subtask'
assert orm_subdag.fileloc == orm_dag.fileloc
session.close()
def test_sync_to_db_default_view(self):
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
default_view="graph",
)
with dag:
DummyOperator(task_id='task', owner='owner1')
SubDagOperator(
task_id='subtask',
owner='owner2',
subdag=DAG(
'dag.subtask',
start_date=DEFAULT_DATE,
),
)
session = settings.Session()
dag.sync_to_db(session=session)
orm_dag = session.query(DagModel).filter(DagModel.dag_id == 'dag').one()
assert orm_dag.default_view is not None
assert orm_dag.default_view == "graph"
session.close()
@provide_session
def test_is_paused_subdag(self, session):
subdag_id = 'dag.subdag'
subdag = DAG(
subdag_id,
start_date=DEFAULT_DATE,
)
with subdag:
DummyOperator(
task_id='dummy_task',
)
dag_id = 'dag'
dag = DAG(
dag_id,
start_date=DEFAULT_DATE,
)
with dag:
SubDagOperator(task_id='subdag', subdag=subdag)
        # parent_dag and is_subdag are normally set by DagBag; we don't use DagBag here, so set them manually.
subdag.parent_dag = dag
subdag.is_subdag = True
session.query(DagModel).filter(DagModel.dag_id.in_([subdag_id, dag_id])).delete(
synchronize_session=False
)
dag.sync_to_db(session=session)
unpaused_dags = (
session.query(DagModel.dag_id, DagModel.is_paused)
.filter(
DagModel.dag_id.in_([subdag_id, dag_id]),
)
.all()
)
assert {
(dag_id, False),
(subdag_id, False),
} == set(unpaused_dags)
DagModel.get_dagmodel(dag.dag_id).set_is_paused(is_paused=True, including_subdags=False)
paused_dags = (
session.query(DagModel.dag_id, DagModel.is_paused)
.filter(
DagModel.dag_id.in_([subdag_id, dag_id]),
)
.all()
)
assert {
(dag_id, True),
(subdag_id, False),
} == set(paused_dags)
DagModel.get_dagmodel(dag.dag_id).set_is_paused(is_paused=True)
paused_dags = (
session.query(DagModel.dag_id, DagModel.is_paused)
.filter(
DagModel.dag_id.in_([subdag_id, dag_id]),
)
.all()
)
assert {
(dag_id, True),
(subdag_id, True),
} == set(paused_dags)
def test_existing_dag_is_paused_upon_creation(self):
dag = DAG('dag_paused')
dag.sync_to_db()
assert not dag.get_is_paused()
dag = DAG('dag_paused', is_paused_upon_creation=True)
dag.sync_to_db()
# Since the dag existed before, it should not follow the pause flag upon creation
assert not dag.get_is_paused()
def test_new_dag_is_paused_upon_creation(self):
dag = DAG('new_nonexisting_dag', is_paused_upon_creation=True)
session = settings.Session()
dag.sync_to_db(session=session)
orm_dag = session.query(DagModel).filter(DagModel.dag_id == 'new_nonexisting_dag').one()
# Since the dag didn't exist before, it should follow the pause flag upon creation
assert orm_dag.is_paused
session.close()
def test_existing_dag_default_view(self):
with create_session() as session:
session.add(DagModel(dag_id='dag_default_view_old', default_view=None))
session.commit()
orm_dag = session.query(DagModel).filter(DagModel.dag_id == 'dag_default_view_old').one()
assert orm_dag.default_view is None
assert orm_dag.get_default_view() == conf.get('webserver', 'dag_default_view').lower()
def test_dag_is_deactivated_upon_dagfile_deletion(self):
dag_id = 'old_existing_dag'
dag_fileloc = "/usr/local/airflow/dags/non_existing_path.py"
dag = DAG(
dag_id,
is_paused_upon_creation=True,
)
dag.fileloc = dag_fileloc
session = settings.Session()
with mock.patch.object(settings, "STORE_DAG_CODE", False):
dag.sync_to_db(session=session)
orm_dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one()
assert orm_dag.is_active
assert orm_dag.fileloc == dag_fileloc
DagModel.deactivate_deleted_dags(list_py_file_paths(settings.DAGS_FOLDER))
orm_dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one()
assert not orm_dag.is_active
session.execute(DagModel.__table__.delete().where(DagModel.dag_id == dag_id))
session.close()
def test_dag_naive_default_args_start_date_with_timezone(self):
local_tz = pendulum.timezone('Europe/Zurich')
default_args = {'start_date': datetime.datetime(2018, 1, 1, tzinfo=local_tz)}
dag = DAG('DAG', default_args=default_args)
assert dag.timezone.name == local_tz.name
dag = DAG('DAG', default_args=default_args)
assert dag.timezone.name == local_tz.name
def test_roots(self):
"""Verify if dag.roots returns the root tasks of a DAG."""
with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
op1 = DummyOperator(task_id="t1")
op2 = DummyOperator(task_id="t2")
op3 = DummyOperator(task_id="t3")
op4 = DummyOperator(task_id="t4")
op5 = DummyOperator(task_id="t5")
[op1, op2] >> op3 >> [op4, op5]
assert set(dag.roots) == {op1, op2}
def test_leaves(self):
"""Verify if dag.leaves returns the leaf tasks of a DAG."""
with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
op1 = DummyOperator(task_id="t1")
op2 = DummyOperator(task_id="t2")
op3 = DummyOperator(task_id="t3")
op4 = DummyOperator(task_id="t4")
op5 = DummyOperator(task_id="t5")
[op1, op2] >> op3 >> [op4, op5]
assert set(dag.leaves) == {op4, op5}
def test_tree_view(self):
"""Verify correctness of dag.tree_view()."""
with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
op1 = DummyOperator(task_id="t1")
op2 = DummyOperator(task_id="t2")
op3 = DummyOperator(task_id="t3")
op1 >> op2 >> op3
with redirect_stdout(io.StringIO()) as stdout:
dag.tree_view()
stdout = stdout.getvalue()
stdout_lines = stdout.split("\n")
assert 't1' in stdout_lines[0]
assert 't2' in stdout_lines[1]
assert 't3' in stdout_lines[2]
def test_duplicate_task_ids_not_allowed_with_dag_context_manager(self):
"""Verify tasks with Duplicate task_id raises error"""
with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has already been added to the DAG"):
with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
op1 = DummyOperator(task_id="t1")
op2 = BashOperator(task_id="t1", bash_command="sleep 1")
op1 >> op2
assert dag.task_dict == {op1.task_id: op1}
def test_duplicate_task_ids_not_allowed_without_dag_context_manager(self):
"""Verify tasks with Duplicate task_id raises error"""
with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has already been added to the DAG"):
dag = DAG("test_dag", start_date=DEFAULT_DATE)
op1 = DummyOperator(task_id="t1", dag=dag)
op2 = DummyOperator(task_id="t1", dag=dag)
op1 >> op2
assert dag.task_dict == {op1.task_id: op1}
def test_duplicate_task_ids_for_same_task_is_allowed(self):
"""Verify that same tasks with Duplicate task_id do not raise error"""
with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
op1 = op2 = DummyOperator(task_id="t1")
op3 = DummyOperator(task_id="t3")
op1 >> op3
op2 >> op3
assert op1 == op2
assert dag.task_dict == {op1.task_id: op1, op3.task_id: op3}
assert dag.task_dict == {op2.task_id: op2, op3.task_id: op3}
def test_sub_dag_updates_all_references_while_deepcopy(self):
with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
op1 = DummyOperator(task_id='t1')
op2 = DummyOperator(task_id='t2')
op3 = DummyOperator(task_id='t3')
op1 >> op2
op2 >> op3
sub_dag = dag.sub_dag('t2', include_upstream=True, include_downstream=False)
assert id(sub_dag.task_dict['t1'].downstream_list[0].dag) == id(sub_dag)
# Copied DAG should not include unused task IDs in used_group_ids
assert 't3' not in sub_dag._task_group.used_group_ids
def test_schedule_dag_no_previous_runs(self):
"""
Tests scheduling a dag with no previous runs
"""
dag_id = "test_schedule_dag_no_previous_runs"
dag = DAG(dag_id=dag_id)
dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=TEST_DATE))
dag_run = dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
execution_date=TEST_DATE,
state=State.RUNNING,
)
assert dag_run is not None
assert dag.dag_id == dag_run.dag_id
assert dag_run.run_id is not None
assert '' != dag_run.run_id
assert (
TEST_DATE == dag_run.execution_date
), f'dag_run.execution_date did not match expectation: {dag_run.execution_date}'
assert State.RUNNING == dag_run.state
assert not dag_run.external_trigger
dag.clear()
self._clean_up(dag_id)
@patch('airflow.models.dag.Stats')
def test_dag_handle_callback_crash(self, mock_stats):
"""
Tests avoid crashes from calling dag callbacks exceptions
"""
dag_id = "test_dag_callback_crash"
mock_callback_with_exception = mock.MagicMock()
mock_callback_with_exception.side_effect = Exception
dag = DAG(
dag_id=dag_id,
# callback with invalid signature should not cause crashes
on_success_callback=lambda: 1,
on_failure_callback=mock_callback_with_exception,
)
when = TEST_DATE
dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=when))
dag_run = dag.create_dagrun(State.RUNNING, when, run_type=DagRunType.MANUAL)
# should not raise any exception
dag.handle_callback(dag_run, success=False)
dag.handle_callback(dag_run, success=True)
mock_stats.incr.assert_called_with("dag.callback_exceptions")
dag.clear()
self._clean_up(dag_id)
def test_next_dagrun_after_fake_scheduled_previous(self):
"""
Test scheduling a dag where there is a prior DagRun
which has the same run_id as the next run should have
"""
delta = datetime.timedelta(hours=1)
dag_id = "test_schedule_dag_fake_scheduled_previous"
dag = DAG(dag_id=dag_id, schedule_interval=delta, start_date=DEFAULT_DATE)
dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=DEFAULT_DATE))
dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
execution_date=DEFAULT_DATE,
state=State.SUCCESS,
external_trigger=True,
)
dag.sync_to_db()
with create_session() as session:
model = session.query(DagModel).get((dag.dag_id,))
# Even though there is a run for this date already, it is marked as manual/external, so we should
# create a scheduled one anyway!
assert model.next_dagrun == DEFAULT_DATE
assert model.next_dagrun_create_after == dag.following_schedule(DEFAULT_DATE)
self._clean_up(dag_id)
def test_schedule_dag_once(self):
"""
Tests scheduling a dag scheduled for @once - should be scheduled the first time
it is called, and not scheduled the second.
"""
dag_id = "test_schedule_dag_once"
dag = DAG(dag_id=dag_id)
dag.schedule_interval = '@once'
assert dag.normalized_schedule_interval is None
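        # '@once' normalizes to no recurring schedule (None), so after the first
        # run nothing further should be scheduled.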
dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=TEST_DATE))
# Sync once to create the DagModel
dag.sync_to_db()
dag.create_dagrun(run_type=DagRunType.SCHEDULED, execution_date=TEST_DATE, state=State.SUCCESS)
# Then sync again after creating the dag run -- this should update next_dagrun
dag.sync_to_db()
with create_session() as session:
model = session.query(DagModel).get((dag.dag_id,))
assert model.next_dagrun is None
assert model.next_dagrun_create_after is None
self._clean_up(dag_id)
def test_fractional_seconds(self):
"""
Tests if fractional seconds are stored in the database
"""
dag_id = "test_fractional_seconds"
dag = DAG(dag_id=dag_id)
dag.schedule_interval = '@once'
dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=TEST_DATE))
start_date = timezone.utcnow()
run = dag.create_dagrun(
run_id='test_' + start_date.isoformat(),
execution_date=start_date,
start_date=start_date,
state=State.RUNNING,
external_trigger=False,
)
run.refresh_from_db()
assert start_date == run.execution_date, "dag run execution_date loses precision"
        assert start_date == run.start_date, "dag run start_date loses precision"
self._clean_up(dag_id)
def test_pickling(self):
test_dag_id = 'test_pickling'
args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
dag = DAG(test_dag_id, default_args=args)
dag_pickle = dag.pickle()
assert dag_pickle.pickle.dag_id == dag.dag_id
def test_rich_comparison_ops(self):
test_dag_id = 'test_rich_comparison_ops'
class DAGsubclass(DAG):
pass
args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
dag = DAG(test_dag_id, default_args=args)
dag_eq = DAG(test_dag_id, default_args=args)
dag_diff_load_time = DAG(test_dag_id, default_args=args)
dag_diff_name = DAG(test_dag_id + '_neq', default_args=args)
dag_subclass = DAGsubclass(test_dag_id, default_args=args)
dag_subclass_diff_name = DAGsubclass(test_dag_id + '2', default_args=args)
for dag_ in [dag_eq, dag_diff_name, dag_subclass, dag_subclass_diff_name]:
dag_.last_loaded = dag.last_loaded
# test identity equality
assert dag == dag
# test dag (in)equality based on _comps
assert dag_eq == dag
assert dag_diff_name != dag
assert dag_diff_load_time != dag
# test dag inequality based on type even if _comps happen to match
assert dag_subclass != dag
# a dag should equal an unpickled version of itself
dump = pickle.dumps(dag)
assert pickle.loads(dump) == dag
# dags are ordered based on dag_id no matter what the type is
assert dag < dag_diff_name
assert dag > dag_diff_load_time
assert dag < dag_subclass_diff_name
# greater than should have been created automatically by functools
assert dag_diff_name > dag
# hashes are non-random and match equality
assert hash(dag) == hash(dag)
assert hash(dag_eq) == hash(dag)
assert hash(dag_diff_name) != hash(dag)
assert hash(dag_subclass) != hash(dag)
def test_get_paused_dag_ids(self):
dag_id = "test_get_paused_dag_ids"
dag = DAG(dag_id, is_paused_upon_creation=True)
dag.sync_to_db()
assert DagModel.get_dagmodel(dag_id) is not None
paused_dag_ids = DagModel.get_paused_dag_ids([dag_id])
assert paused_dag_ids == {dag_id}
with create_session() as session:
session.query(DagModel).filter(DagModel.dag_id == dag_id).delete(synchronize_session=False)
@parameterized.expand(
[
(None, None),
("@daily", "0 0 * * *"),
("@weekly", "0 0 * * 0"),
("@monthly", "0 0 1 * *"),
("@quarterly", "0 0 1 */3 *"),
("@yearly", "0 0 1 1 *"),
("@once", None),
(datetime.timedelta(days=1), datetime.timedelta(days=1)),
]
)
def test_normalized_schedule_interval(self, schedule_interval, expected_n_schedule_interval):
dag = DAG("test_schedule_interval", schedule_interval=schedule_interval)
assert dag.normalized_schedule_interval == expected_n_schedule_interval
assert dag.schedule_interval == schedule_interval
def test_set_dag_runs_state(self):
clear_db_runs()
dag_id = "test_set_dag_runs_state"
dag = DAG(dag_id=dag_id)
for i in range(3):
dag.create_dagrun(run_id=f"test{i}", state=State.RUNNING)
dag.set_dag_runs_state(state=State.NONE)
drs = DagRun.find(dag_id=dag_id)
assert len(drs) == 3
assert all(dr.state == State.NONE for dr in drs)
def test_create_dagrun_run_id_is_generated(self):
dag = DAG(dag_id="run_id_is_generated")
dr = dag.create_dagrun(run_type=DagRunType.MANUAL, execution_date=DEFAULT_DATE, state=State.NONE)
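        # Generated run_ids follow the '<run type>__<execution_date ISO>' pattern.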
assert dr.run_id == f"manual__{DEFAULT_DATE.isoformat()}"
def test_create_dagrun_run_type_is_obtained_from_run_id(self):
dag = DAG(dag_id="run_type_is_obtained_from_run_id")
dr = dag.create_dagrun(run_id="scheduled__", state=State.NONE)
assert dr.run_type == DagRunType.SCHEDULED
dr = dag.create_dagrun(run_id="custom_is_set_to_manual", state=State.NONE)
assert dr.run_type == DagRunType.MANUAL
def test_create_dagrun_job_id_is_set(self):
job_id = 42
dag = DAG(dag_id="test_create_dagrun_job_id_is_set")
dr = dag.create_dagrun(
run_id="test_create_dagrun_job_id_is_set", state=State.NONE, creating_job_id=job_id
)
assert dr.creating_job_id == job_id
@parameterized.expand(
[
(State.NONE,),
(State.RUNNING,),
]
)
def test_clear_set_dagrun_state(self, dag_run_state):
dag_id = 'test_clear_set_dagrun_state'
self._clean_up(dag_id)
task_id = 't1'
dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1)
t_1 = DummyOperator(task_id=task_id, dag=dag)
session = settings.Session()
dagrun_1 = dag.create_dagrun(
run_type=DagRunType.BACKFILL_JOB,
state=State.FAILED,
start_date=DEFAULT_DATE,
execution_date=DEFAULT_DATE,
)
session.merge(dagrun_1)
task_instance_1 = TI(t_1, execution_date=DEFAULT_DATE, state=State.RUNNING)
session.merge(task_instance_1)
session.commit()
dag.clear(
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=1),
dag_run_state=dag_run_state,
include_subdags=False,
include_parentdag=False,
session=session,
)
dagruns = (
session.query(
DagRun,
)
.filter(
DagRun.dag_id == dag_id,
)
.all()
)
assert len(dagruns) == 1
dagrun = dagruns[0] # type: DagRun
assert dagrun.state == dag_run_state
@parameterized.expand(
[
(State.NONE,),
(State.RUNNING,),
]
)
def test_clear_set_dagrun_state_for_subdag(self, dag_run_state):
dag_id = 'test_clear_set_dagrun_state_subdag'
self._clean_up(dag_id)
task_id = 't1'
dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1)
t_1 = DummyOperator(task_id=task_id, dag=dag)
subdag = DAG(dag_id + '.test', start_date=DEFAULT_DATE, max_active_runs=1)
SubDagOperator(task_id='test', subdag=subdag, dag=dag)
t_2 = DummyOperator(task_id='task', dag=subdag)
session = settings.Session()
dagrun_1 = dag.create_dagrun(
run_type=DagRunType.BACKFILL_JOB,
state=State.FAILED,
start_date=DEFAULT_DATE,
execution_date=DEFAULT_DATE,
)
dagrun_2 = subdag.create_dagrun(
run_type=DagRunType.BACKFILL_JOB,
state=State.FAILED,
start_date=DEFAULT_DATE,
execution_date=DEFAULT_DATE,
)
session.merge(dagrun_1)
session.merge(dagrun_2)
task_instance_1 = TI(t_1, execution_date=DEFAULT_DATE, state=State.RUNNING)
task_instance_2 = TI(t_2, execution_date=DEFAULT_DATE, state=State.RUNNING)
session.merge(task_instance_1)
session.merge(task_instance_2)
session.commit()
dag.clear(
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=1),
dag_run_state=dag_run_state,
include_subdags=True,
include_parentdag=False,
session=session,
)
dagrun = (
session.query(
DagRun,
)
.filter(DagRun.dag_id == subdag.dag_id)
.one()
)
assert dagrun.state == dag_run_state
@parameterized.expand(
[
(State.NONE,),
(State.RUNNING,),
]
)
def test_clear_set_dagrun_state_for_parent_dag(self, dag_run_state):
dag_id = 'test_clear_set_dagrun_state_parent_dag'
self._clean_up(dag_id)
task_id = 't1'
dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1)
t_1 = DummyOperator(task_id=task_id, dag=dag)
subdag = DAG(dag_id + '.test', start_date=DEFAULT_DATE, max_active_runs=1)
SubDagOperator(task_id='test', subdag=subdag, dag=dag)
t_2 = DummyOperator(task_id='task', dag=subdag)
subdag.parent_dag = dag
subdag.is_subdag = True
session = settings.Session()
dagrun_1 = dag.create_dagrun(
run_type=DagRunType.BACKFILL_JOB,
state=State.FAILED,
start_date=DEFAULT_DATE,
execution_date=DEFAULT_DATE,
)
dagrun_2 = subdag.create_dagrun(
run_type=DagRunType.BACKFILL_JOB,
state=State.FAILED,
start_date=DEFAULT_DATE,
execution_date=DEFAULT_DATE,
)
session.merge(dagrun_1)
session.merge(dagrun_2)
task_instance_1 = TI(t_1, execution_date=DEFAULT_DATE, state=State.RUNNING)
task_instance_2 = TI(t_2, execution_date=DEFAULT_DATE, state=State.RUNNING)
session.merge(task_instance_1)
session.merge(task_instance_2)
session.commit()
subdag.clear(
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=1),
dag_run_state=dag_run_state,
include_subdags=True,
include_parentdag=True,
session=session,
)
dagrun = (
session.query(
DagRun,
)
.filter(DagRun.dag_id == dag_id)
.one()
)
assert dagrun.state == dag_run_state
@parameterized.expand(
[(state, State.NONE) for state in State.task_states if state != State.RUNNING]
+ [(State.RUNNING, State.SHUTDOWN)]
) # type: ignore
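    # The parameterization encodes clear()'s state mapping: every non-running
    # TI ends up with no state, while a RUNNING TI is marked SHUTDOWN instead.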
def test_clear_dag(self, ti_state_begin, ti_state_end: Optional[str]):
dag_id = 'test_clear_dag'
self._clean_up(dag_id)
task_id = 't1'
dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1)
t_1 = DummyOperator(task_id=task_id, dag=dag)
session = settings.Session() # type: ignore
dagrun_1 = dag.create_dagrun(
run_type=DagRunType.BACKFILL_JOB,
state=State.RUNNING,
start_date=DEFAULT_DATE,
execution_date=DEFAULT_DATE,
)
session.merge(dagrun_1)
task_instance_1 = TI(t_1, execution_date=DEFAULT_DATE, state=ti_state_begin)
task_instance_1.job_id = 123
session.merge(task_instance_1)
session.commit()
dag.clear(
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=1),
session=session,
)
task_instances = (
session.query(
TI,
)
.filter(
TI.dag_id == dag_id,
)
.all()
)
assert len(task_instances) == 1
task_instance = task_instances[0] # type: TI
assert task_instance.state == ti_state_end
self._clean_up(dag_id)
def test_next_dagrun_after_date_once(self):
dag = DAG(
'test_scheduler_dagrun_once', start_date=timezone.datetime(2015, 1, 1), schedule_interval="@once"
)
next_date = dag.next_dagrun_after_date(None)
assert next_date == timezone.datetime(2015, 1, 1)
next_date = dag.next_dagrun_after_date(next_date)
assert next_date is None
def test_next_dagrun_after_date_start_end_dates(self):
"""
Tests that an attempt to schedule a task after the Dag's end_date
does not succeed.
"""
delta = datetime.timedelta(hours=1)
runs = 3
start_date = DEFAULT_DATE
end_date = start_date + (runs - 1) * delta
dag_id = "test_schedule_dag_start_end_dates"
dag = DAG(dag_id=dag_id, start_date=start_date, end_date=end_date, schedule_interval=delta)
dag.add_task(BaseOperator(task_id='faketastic', owner='Also fake'))
# Create and schedule the dag runs
dates = []
date = None
for _ in range(runs):
date = dag.next_dagrun_after_date(date)
dates.append(date)
for date in dates:
assert date is not None
assert dates[-1] == end_date
assert dag.next_dagrun_after_date(date) is None
    def test_next_dagrun_after_date_catchup(self):
        """
        Test that a DAG with catchup=False only schedules runs beginning now,
        not back to the start date
        """
def make_dag(dag_id, schedule_interval, start_date, catchup):
default_args = {
'owner': 'airflow',
'depends_on_past': False,
}
dag = DAG(
dag_id,
schedule_interval=schedule_interval,
start_date=start_date,
catchup=catchup,
default_args=default_args,
)
op1 = DummyOperator(task_id='t1', dag=dag)
op2 = DummyOperator(task_id='t2', dag=dag)
op3 = DummyOperator(task_id='t3', dag=dag)
op1 >> op2 >> op3
return dag
now = timezone.utcnow()
six_hours_ago_to_the_hour = (now - datetime.timedelta(hours=6)).replace(
minute=0, second=0, microsecond=0
)
half_an_hour_ago = now - datetime.timedelta(minutes=30)
two_hours_ago = now - datetime.timedelta(hours=2)
dag1 = make_dag(
dag_id='dag_without_catchup_ten_minute',
schedule_interval='*/10 * * * *',
start_date=six_hours_ago_to_the_hour,
catchup=False,
)
next_date = dag1.next_dagrun_after_date(None)
# The DR should be scheduled in the last half an hour, not 6 hours ago
assert next_date > half_an_hour_ago
assert next_date < timezone.utcnow()
dag2 = make_dag(
dag_id='dag_without_catchup_hourly',
schedule_interval='@hourly',
start_date=six_hours_ago_to_the_hour,
catchup=False,
)
next_date = dag2.next_dagrun_after_date(None)
# The DR should be scheduled in the last 2 hours, not 6 hours ago
assert next_date > two_hours_ago
# The DR should be scheduled BEFORE now
assert next_date < timezone.utcnow()
dag3 = make_dag(
dag_id='dag_without_catchup_once',
schedule_interval='@once',
start_date=six_hours_ago_to_the_hour,
catchup=False,
)
next_date = dag3.next_dagrun_after_date(None)
        # With '@once', the DR should be scheduled exactly at the start date
assert next_date == six_hours_ago_to_the_hour
@freeze_time(timezone.datetime(2020, 1, 5))
def test_next_dagrun_after_date_timedelta_schedule_and_catchup_false(self):
"""
Test that the dag file processor does not create multiple dagruns
if a dag is scheduled with 'timedelta' and catchup=False
"""
dag = DAG(
'test_scheduler_dagrun_once_with_timedelta_and_catchup_false',
start_date=timezone.datetime(2015, 1, 1),
schedule_interval=timedelta(days=1),
catchup=False,
)
next_date = dag.next_dagrun_after_date(None)
assert next_date == timezone.datetime(2020, 1, 4)
# The date to create is in the future, this is handled by "DagModel.dags_needing_dagruns"
next_date = dag.next_dagrun_after_date(next_date)
assert next_date == timezone.datetime(2020, 1, 5)
@freeze_time(timezone.datetime(2020, 5, 4))
def test_next_dagrun_after_date_timedelta_schedule_and_catchup_true(self):
"""
Test that the dag file processor creates multiple dagruns
if a dag is scheduled with 'timedelta' and catchup=True
"""
dag = DAG(
'test_scheduler_dagrun_once_with_timedelta_and_catchup_true',
start_date=timezone.datetime(2020, 5, 1),
schedule_interval=timedelta(days=1),
catchup=True,
)
next_date = dag.next_dagrun_after_date(None)
assert next_date == timezone.datetime(2020, 5, 1)
next_date = dag.next_dagrun_after_date(next_date)
assert next_date == timezone.datetime(2020, 5, 2)
next_date = dag.next_dagrun_after_date(next_date)
assert next_date == timezone.datetime(2020, 5, 3)
# The date to create is in the future, this is handled by "DagModel.dags_needing_dagruns"
next_date = dag.next_dagrun_after_date(next_date)
assert next_date == timezone.datetime(2020, 5, 4)
def test_next_dagrun_after_auto_align(self):
"""
Test if the schedule_interval will be auto aligned with the start_date
such that if the start_date coincides with the schedule the first
execution_date will be start_date, otherwise it will be start_date +
interval.
"""
dag = DAG(
dag_id='test_scheduler_auto_align_1',
start_date=timezone.datetime(2016, 1, 1, 10, 10, 0),
schedule_interval="4 5 * * *",
)
DummyOperator(task_id='dummy', dag=dag, owner='airflow')
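        # start_date (10:10) is already past that day's 05:04 slot, so the first
        # run aligns forward to the next slot on Jan 2.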
next_date = dag.next_dagrun_after_date(None)
assert next_date == timezone.datetime(2016, 1, 2, 5, 4)
dag = DAG(
dag_id='test_scheduler_auto_align_2',
start_date=timezone.datetime(2016, 1, 1, 10, 10, 0),
schedule_interval="10 10 * * *",
)
DummyOperator(task_id='dummy', dag=dag, owner='airflow')
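        # Here the start_date coincides exactly with a 10:10 slot, so the first
        # run is the start_date itself.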
next_date = dag.next_dagrun_after_date(None)
assert next_date == timezone.datetime(2016, 1, 1, 10, 10)
def test_next_dagrun_after_not_for_subdags(self):
"""
Test the subdags are never marked to have dagruns created, as they are
handled by the SubDagOperator, not the scheduler
"""
def subdag(parent_dag_name, child_dag_name, args):
"""
Create a subdag.
"""
dag_subdag = DAG(
dag_id=f'{parent_dag_name}.{child_dag_name}',
schedule_interval="@daily",
default_args=args,
)
for i in range(2):
DummyOperator(task_id=f'{child_dag_name}-task-{i + 1}', dag=dag_subdag)
return dag_subdag
with DAG(
dag_id='test_subdag_operator',
start_date=datetime.datetime(2019, 1, 1),
max_active_runs=1,
schedule_interval=timedelta(minutes=1),
) as dag:
section_1 = SubDagOperator(
task_id='section-1',
subdag=subdag(dag.dag_id, 'section-1', {'start_date': dag.start_date}),
)
subdag = section_1.subdag
        # parent_dag and is_subdag are normally set by DagBag; we don't use DagBag here, so set them manually.
subdag.parent_dag = dag
subdag.is_subdag = True
next_date = dag.next_dagrun_after_date(None)
assert next_date == timezone.datetime(2019, 1, 1, 0, 0)
next_subdag_date = subdag.next_dagrun_after_date(None)
assert next_subdag_date is None, "SubDags should never have DagRuns created by the scheduler"
def test_replace_outdated_access_control_actions(self):
outdated_permissions = {
'role1': {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT},
'role2': {permissions.DEPRECATED_ACTION_CAN_DAG_READ, permissions.DEPRECATED_ACTION_CAN_DAG_EDIT},
}
updated_permissions = {
'role1': {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT},
'role2': {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT},
}
with pytest.warns(DeprecationWarning):
dag = DAG(dag_id='dag_with_outdated_perms', access_control=outdated_permissions)
assert dag.access_control == updated_permissions
with pytest.warns(DeprecationWarning):
dag.access_control = outdated_permissions
assert dag.access_control == updated_permissions
class TestDagModel:
def test_dags_needing_dagruns_not_too_early(self):
dag = DAG(dag_id='far_future_dag', start_date=timezone.datetime(2038, 1, 1))
DummyOperator(task_id='dummy', dag=dag, owner='airflow')
session = settings.Session()
orm_dag = DagModel(
dag_id=dag.dag_id,
concurrency=1,
has_task_concurrency_limits=False,
next_dagrun=dag.start_date,
next_dagrun_create_after=timezone.datetime(2038, 1, 2),
is_active=True,
)
session.add(orm_dag)
session.flush()
dag_models = DagModel.dags_needing_dagruns(session).all()
assert dag_models == []
session.rollback()
session.close()
def test_max_active_runs_not_none(self):
dag = DAG(dag_id='test_max_active_runs_not_none', start_date=timezone.datetime(2038, 1, 1))
DummyOperator(task_id='dummy', dag=dag, owner='airflow')
session = settings.Session()
orm_dag = DagModel(
dag_id=dag.dag_id,
has_task_concurrency_limits=False,
next_dagrun=None,
next_dagrun_create_after=None,
is_active=True,
)
session.add(orm_dag)
session.flush()
assert orm_dag.max_active_runs is not None
session.rollback()
session.close()
def test_dags_needing_dagruns_only_unpaused(self):
"""
We should never create dagruns for unpaused DAGs
"""
dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE)
DummyOperator(task_id='dummy', dag=dag, owner='airflow')
session = settings.Session()
orm_dag = DagModel(
dag_id=dag.dag_id,
has_task_concurrency_limits=False,
next_dagrun=dag.start_date,
next_dagrun_create_after=dag.following_schedule(DEFAULT_DATE),
is_active=True,
)
session.add(orm_dag)
session.flush()
needed = DagModel.dags_needing_dagruns(session).all()
assert needed == [orm_dag]
orm_dag.is_paused = True
session.flush()
dag_models = DagModel.dags_needing_dagruns(session).all()
assert dag_models == []
session.rollback()
session.close()
class TestQueries(unittest.TestCase):
def setUp(self) -> None:
clear_db_runs()
def tearDown(self) -> None:
clear_db_runs()
@parameterized.expand(
[
(3,),
(12,),
]
)
def test_count_number_queries(self, tasks_count):
dag = DAG('test_dagrun_query_count', start_date=DEFAULT_DATE)
for i in range(tasks_count):
DummyOperator(task_id=f'dummy_task_{i}', owner='test', dag=dag)
with assert_queries_count(2):
dag.create_dagrun(
run_id="test_dagrun_query_count",
state=State.RUNNING,
execution_date=TEST_DATE,
)
class TestDagDecorator(unittest.TestCase):
DEFAULT_ARGS = {
"owner": "test",
"depends_on_past": True,
"start_date": timezone.utcnow(),
"retries": 1,
"retry_delay": timedelta(minutes=1),
}
DEFAULT_DATE = timezone.datetime(2016, 1, 1)
VALUE = 42
def setUp(self):
super().setUp()
self.operator = None
def tearDown(self):
super().tearDown()
clear_db_runs()
def test_set_dag_id(self):
"""Test that checks you can set dag_id from decorator."""
@dag_decorator('test', default_args=self.DEFAULT_ARGS)
def noop_pipeline():
@task_decorator
def return_num(num):
return num
return_num(4)
dag = noop_pipeline()
assert isinstance(dag, DAG)
        assert dag.dag_id == 'test'
def test_default_dag_id(self):
"""Test that @dag uses function name as default dag id."""
@dag_decorator(default_args=self.DEFAULT_ARGS)
def noop_pipeline():
@task_decorator
def return_num(num):
return num
return_num(4)
dag = noop_pipeline()
assert isinstance(dag, DAG)
        assert dag.dag_id == 'noop_pipeline'
def test_documentation_added(self):
"""Test that @dag uses function docs as doc_md for DAG object"""
@dag_decorator(default_args=self.DEFAULT_ARGS)
def noop_pipeline():
"""
Regular DAG documentation
"""
@task_decorator
def return_num(num):
return num
return_num(4)
dag = noop_pipeline()
assert isinstance(dag, DAG)
        assert dag.dag_id == 'noop_pipeline'
        assert dag.doc_md.strip() == "Regular DAG documentation"
def test_fails_if_arg_not_set(self):
"""Test that @dag decorated function fails if positional argument is not set"""
@dag_decorator(default_args=self.DEFAULT_ARGS)
def noop_pipeline(value):
@task_decorator
def return_num(num):
return num
return_num(value)
# Test that if arg is not passed it raises a type error as expected.
with pytest.raises(TypeError):
noop_pipeline()
def test_dag_param_resolves(self):
"""Test that dag param is correctly resolved by operator"""
@dag_decorator(default_args=self.DEFAULT_ARGS)
def xcom_pass_to_op(value=self.VALUE):
@task_decorator
def return_num(num):
return num
xcom_arg = return_num(value)
self.operator = xcom_arg.operator
dag = xcom_pass_to_op()
dr = dag.create_dagrun(
run_id=DagRunType.MANUAL.value,
start_date=timezone.utcnow(),
execution_date=self.DEFAULT_DATE,
state=State.RUNNING,
)
self.operator.run(start_date=self.DEFAULT_DATE, end_date=self.DEFAULT_DATE)
ti = dr.get_task_instances()[0]
assert ti.xcom_pull() == self.VALUE
def test_dag_param_dagrun_parameterized(self):
"""Test that dag param is correctly overwritten when set in dag run"""
@dag_decorator(default_args=self.DEFAULT_ARGS)
def xcom_pass_to_op(value=self.VALUE):
@task_decorator
def return_num(num):
return num
assert isinstance(value, DagParam)
xcom_arg = return_num(value)
self.operator = xcom_arg.operator
dag = xcom_pass_to_op()
new_value = 52
dr = dag.create_dagrun(
run_id=DagRunType.MANUAL.value,
start_date=timezone.utcnow(),
execution_date=self.DEFAULT_DATE,
state=State.RUNNING,
conf={'value': new_value},
)
self.operator.run(start_date=self.DEFAULT_DATE, end_date=self.DEFAULT_DATE)
ti = dr.get_task_instances()[0]
        assert ti.xcom_pull() == new_value
def test_set_params_for_dag(self):
"""Test that dag param is correctly set when using dag decorator"""
@dag_decorator(default_args=self.DEFAULT_ARGS)
def xcom_pass_to_op(value=self.VALUE):
@task_decorator
def return_num(num):
return num
xcom_arg = return_num(value)
self.operator = xcom_arg.operator
dag = xcom_pass_to_op()
assert dag.params['value'] == self.VALUE