# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import inspect
import os
import shutil
import textwrap
import unittest
from datetime import datetime, timezone
from tempfile import NamedTemporaryFile, mkdtemp
from unittest import mock
from unittest.mock import patch
from freezegun import freeze_time
from sqlalchemy import func
from sqlalchemy.exc import OperationalError
import airflow.example_dags
from airflow import models
from airflow.exceptions import SerializationError
from airflow.models import DagBag, DagModel
from airflow.models.serialized_dag import SerializedDagModel
from airflow.utils.dates import timezone as tz
from airflow.utils.session import create_session
from tests import cluster_policies
from tests.models import TEST_DAGS_FOLDER
from tests.test_utils import db
from tests.test_utils.asserts import assert_queries_count
from tests.test_utils.config import conf_vars
class TestDagBag(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.empty_dir = mkdtemp()
@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.empty_dir)
def setUp(self) -> None:
db.clear_db_dags()
db.clear_db_serialized_dags()
def tearDown(self) -> None:
db.clear_db_dags()
db.clear_db_serialized_dags()
def test_get_existing_dag(self):
"""
Test that we're able to parse some example DAGs and retrieve them
"""
dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=True)
some_expected_dag_ids = ["example_bash_operator", "example_branch_operator"]
for dag_id in some_expected_dag_ids:
dag = dagbag.get_dag(dag_id)
assert dag is not None
assert dag_id == dag.dag_id
assert dagbag.size() >= 7
def test_get_non_existing_dag(self):
"""
        Test that retrieving a non-existent dag id returns None without crashing
"""
dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
non_existing_dag_id = "non_existing_dag_id"
assert dagbag.get_dag(non_existing_dag_id) is None
def test_dont_load_example(self):
"""
        Test that the example DAGs are not loaded when include_examples is False
"""
dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
assert dagbag.size() == 0
def test_safe_mode_heuristic_match(self):
"""With safe mode enabled, a file matching the discovery heuristics
should be discovered.
"""
with NamedTemporaryFile(dir=self.empty_dir, suffix=".py") as f:
f.write(b"# airflow")
f.write(b"# DAG")
f.flush()
with conf_vars({('core', 'dags_folder'): self.empty_dir}):
dagbag = models.DagBag(include_examples=False, safe_mode=True)
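                # dagbag_stats reports paths relative to the dags_folder, hence
                # the leading "/" in front of the basename below.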
assert len(dagbag.dagbag_stats) == 1
assert dagbag.dagbag_stats[0].file == "/{}".format(os.path.basename(f.name))
def test_safe_mode_heuristic_mismatch(self):
"""With safe mode enabled, a file not matching the discovery heuristics
should not be discovered.
"""
with NamedTemporaryFile(dir=self.empty_dir, suffix=".py"):
with conf_vars({('core', 'dags_folder'): self.empty_dir}):
dagbag = models.DagBag(include_examples=False, safe_mode=True)
assert len(dagbag.dagbag_stats) == 0
def test_safe_mode_disabled(self):
"""With safe mode disabled, an empty python file should be discovered."""
with NamedTemporaryFile(dir=self.empty_dir, suffix=".py") as f:
with conf_vars({('core', 'dags_folder'): self.empty_dir}):
dagbag = models.DagBag(include_examples=False, safe_mode=False)
assert len(dagbag.dagbag_stats) == 1
assert dagbag.dagbag_stats[0].file == "/{}".format(os.path.basename(f.name))
def test_process_file_that_contains_multi_bytes_char(self):
"""
        Test that we're able to parse a file that contains a multi-byte char
"""
        with NamedTemporaryFile() as f:
            f.write('\u3042'.encode())  # write multi-byte char (hiragana)
            f.flush()
            dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
            assert [] == dagbag.process_file(f.name)
def test_zip_skip_log(self):
"""
        Test that loading DAGs from a zip file logs a skip message for a file
        inside the archive that contains neither "airflow" nor "DAG"
"""
with self.assertLogs() as cm:
test_zip_path = os.path.join(TEST_DAGS_FOLDER, "test_zip.zip")
dagbag = models.DagBag(dag_folder=test_zip_path, include_examples=False)
assert dagbag.has_logged
assert (
f'INFO:airflow.models.dagbag.DagBag:File {test_zip_path}:file_no_airflow_dag.py '
'assumed to contain no DAGs. Skipping.' in cm.output
)
def test_zip(self):
"""
        Test the loading of a DAG within a zip file that includes dependencies
"""
dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, "test_zip.zip"))
assert dagbag.get_dag("test_zip_dag")
def test_process_file_cron_validity_check(self):
"""
        Test that an invalid cron expression used as a schedule interval
        is reported as an import error
"""
invalid_dag_files = ["test_invalid_cron.py", "test_zip_invalid_cron.zip"]
dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
assert len(dagbag.import_errors) == 0
for file in invalid_dag_files:
dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, file))
assert len(dagbag.import_errors) == len(invalid_dag_files)
assert len(dagbag.dags) == 0
@patch.object(DagModel, 'get_current')
def test_get_dag_without_refresh(self, mock_dagmodel):
"""
        Test that, once a DAG is loaded, it doesn't get refreshed again if it
        hasn't expired.
"""
dag_id = 'example_bash_operator'
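        # last_expired=None marks the cached DAG as fresh, so get_dag() should
        # serve it from memory instead of re-parsing the file.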
mock_dagmodel.return_value = DagModel()
mock_dagmodel.return_value.last_expired = None
mock_dagmodel.return_value.fileloc = 'foo'
class _TestDagBag(models.DagBag):
process_file_calls = 0
def process_file(self, filepath, only_if_updated=True, safe_mode=True):
if os.path.basename(filepath) == 'example_bash_operator.py':
_TestDagBag.process_file_calls += 1
                return super().process_file(filepath, only_if_updated, safe_mode)
dagbag = _TestDagBag(include_examples=True)
# Should not call process_file again, since it's already loaded during init.
assert 1 == dagbag.process_file_calls
assert dagbag.get_dag(dag_id) is not None
assert 1 == dagbag.process_file_calls
def test_get_dag_fileloc(self):
"""
Test that fileloc is correctly set when we load example DAGs,
specifically SubDAGs and packaged DAGs.
"""
dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=True)
dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, "test_zip.zip"))
expected = {
'example_bash_operator': 'airflow/example_dags/example_bash_operator.py',
'example_subdag_operator': 'airflow/example_dags/example_subdag_operator.py',
'example_subdag_operator.section-1': 'airflow/example_dags/subdags/subdag.py',
'test_zip_dag': 'dags/test_zip.zip/test_zip.py',
}
for dag_id, path in expected.items():
dag = dagbag.get_dag(dag_id)
assert dag.fileloc.endswith(path)
@patch.object(DagModel, "get_current")
def test_refresh_py_dag(self, mock_dagmodel):
"""
Test that we can refresh an ordinary .py DAG
"""
example_dags_folder = airflow.example_dags.__path__[0]
dag_id = "example_bash_operator"
fileloc = os.path.realpath(os.path.join(example_dags_folder, "example_bash_operator.py"))
mock_dagmodel.return_value = DagModel()
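        # A last_expired far in the future marks the cached DAG as stale,
        # forcing get_dag() to re-process the file (hence the second
        # process_file_calls assertion below).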
mock_dagmodel.return_value.last_expired = datetime.max.replace(tzinfo=timezone.utc)
mock_dagmodel.return_value.fileloc = fileloc
class _TestDagBag(DagBag):
process_file_calls = 0
def process_file(self, filepath, only_if_updated=True, safe_mode=True):
if filepath == fileloc:
_TestDagBag.process_file_calls += 1
return super().process_file(filepath, only_if_updated, safe_mode)
dagbag = _TestDagBag(dag_folder=self.empty_dir, include_examples=True)
assert 1 == dagbag.process_file_calls
dag = dagbag.get_dag(dag_id)
assert dag is not None
assert dag_id == dag.dag_id
assert 2 == dagbag.process_file_calls
@patch.object(DagModel, "get_current")
def test_refresh_packaged_dag(self, mock_dagmodel):
"""
Test that we can refresh a packaged DAG
"""
dag_id = "test_zip_dag"
fileloc = os.path.realpath(os.path.join(TEST_DAGS_FOLDER, "test_zip.zip/test_zip.py"))
mock_dagmodel.return_value = DagModel()
mock_dagmodel.return_value.last_expired = datetime.max.replace(tzinfo=timezone.utc)
mock_dagmodel.return_value.fileloc = fileloc
class _TestDagBag(DagBag):
process_file_calls = 0
def process_file(self, filepath, only_if_updated=True, safe_mode=True):
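                # fileloc points at a file inside the zip, while process_file is
                # invoked with the zip path itself, so use a substring check
                # rather than equality.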
if filepath in fileloc:
_TestDagBag.process_file_calls += 1
return super().process_file(filepath, only_if_updated, safe_mode)
dagbag = _TestDagBag(dag_folder=os.path.realpath(TEST_DAGS_FOLDER), include_examples=False)
assert 1 == dagbag.process_file_calls
dag = dagbag.get_dag(dag_id)
assert dag is not None
assert dag_id == dag.dag_id
assert 2 == dagbag.process_file_calls
def process_dag(self, create_dag):
"""
Helper method to process a file generated from the input create_dag function.
"""
# write source to file
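        # inspect.getsource returns the whole nested function; drop the "def"
        # header line and the trailing "return dag" line, then dedent the body.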
source = textwrap.dedent(''.join(inspect.getsource(create_dag).splitlines(True)[1:-1]))
        with NamedTemporaryFile() as f:
            f.write(source.encode('utf8'))
            f.flush()
            dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
            found_dags = dagbag.process_file(f.name)
            return dagbag, found_dags, f.name
def validate_dags(self, expected_parent_dag, actual_found_dags, actual_dagbag, should_be_found=True):
        expected_dag_ids = [dag.dag_id for dag in expected_parent_dag.subdags]
expected_dag_ids.append(expected_parent_dag.dag_id)
        actual_found_dag_ids = [dag.dag_id for dag in actual_found_dags]
for dag_id in expected_dag_ids:
            actual_dagbag.log.info('validating %s', dag_id)
assert (
dag_id in actual_found_dag_ids
) == should_be_found, 'dag "{}" should {}have been found after processing dag "{}"'.format(
dag_id,
'' if should_be_found else 'not ',
expected_parent_dag.dag_id,
)
assert (
dag_id in actual_dagbag.dags
) == should_be_found, 'dag "{}" should {}be in dagbag.dags after processing dag "{}"'.format(
dag_id,
'' if should_be_found else 'not ',
expected_parent_dag.dag_id,
)
def test_load_subdags(self):
# Define Dag to load
def standard_subdag():
import datetime # pylint: disable=redefined-outer-name,reimported
from airflow.models import DAG
from airflow.operators.dummy import DummyOperator
from airflow.operators.subdag import SubDagOperator
dag_name = 'master'
default_args = {'owner': 'owner1', 'start_date': datetime.datetime(2016, 1, 1)}
dag = DAG(dag_name, default_args=default_args)
            # master:
            #     A -> op_subdag_0
            #          master.op_subdag_0:
            #              -> subdag_0.task
            #     A -> op_subdag_1
            #          master.op_subdag_1:
            #              -> subdag_1.task
with dag:
def subdag_0():
subdag_0 = DAG('master.op_subdag_0', default_args=default_args)
DummyOperator(task_id='subdag_0.task', dag=subdag_0)
return subdag_0
def subdag_1():
subdag_1 = DAG('master.op_subdag_1', default_args=default_args)
DummyOperator(task_id='subdag_1.task', dag=subdag_1)
return subdag_1
op_subdag_0 = SubDagOperator(task_id='op_subdag_0', dag=dag, subdag=subdag_0())
op_subdag_1 = SubDagOperator(task_id='op_subdag_1', dag=dag, subdag=subdag_1())
op_a = DummyOperator(task_id='A')
op_a.set_downstream(op_subdag_0)
op_a.set_downstream(op_subdag_1)
return dag
test_dag = standard_subdag()
        # sanity check to make sure DAG.subdags is still functioning properly
assert len(test_dag.subdags) == 2
# Perform processing dag
dagbag, found_dags, _ = self.process_dag(standard_subdag)
# Validate correctness
# all dags from test_dag should be listed
self.validate_dags(test_dag, found_dags, dagbag)
# Define Dag to load
def nested_subdags():
import datetime # pylint: disable=redefined-outer-name,reimported
from airflow.models import DAG
from airflow.operators.dummy import DummyOperator
from airflow.operators.subdag import SubDagOperator
dag_name = 'master'
default_args = {'owner': 'owner1', 'start_date': datetime.datetime(2016, 1, 1)}
dag = DAG(dag_name, default_args=default_args)
            # master:
            #     A -> op_subdag_0
            #          master.op_subdag_0:
            #              -> opSubdag_A
            #                 master.op_subdag_0.opSubdag_A:
            #                     -> subdag_a.task
            #              -> opSubdag_B
            #                 master.op_subdag_0.opSubdag_B:
            #                     -> subdag_b.task
            #     A -> op_subdag_1
            #          master.op_subdag_1:
            #              -> opSubdag_C
            #                 master.op_subdag_1.opSubdag_C:
            #                     -> subdag_c.task
            #              -> opSubdag_D
            #                 master.op_subdag_1.opSubdag_D:
            #                     -> subdag_d.task
with dag:
def subdag_a():
subdag_a = DAG('master.op_subdag_0.opSubdag_A', default_args=default_args)
DummyOperator(task_id='subdag_a.task', dag=subdag_a)
return subdag_a
def subdag_b():
subdag_b = DAG('master.op_subdag_0.opSubdag_B', default_args=default_args)
DummyOperator(task_id='subdag_b.task', dag=subdag_b)
return subdag_b
def subdag_c():
subdag_c = DAG('master.op_subdag_1.opSubdag_C', default_args=default_args)
DummyOperator(task_id='subdag_c.task', dag=subdag_c)
return subdag_c
def subdag_d():
subdag_d = DAG('master.op_subdag_1.opSubdag_D', default_args=default_args)
DummyOperator(task_id='subdag_d.task', dag=subdag_d)
return subdag_d
def subdag_0():
subdag_0 = DAG('master.op_subdag_0', default_args=default_args)
SubDagOperator(task_id='opSubdag_A', dag=subdag_0, subdag=subdag_a())
SubDagOperator(task_id='opSubdag_B', dag=subdag_0, subdag=subdag_b())
return subdag_0
def subdag_1():
subdag_1 = DAG('master.op_subdag_1', default_args=default_args)
SubDagOperator(task_id='opSubdag_C', dag=subdag_1, subdag=subdag_c())
SubDagOperator(task_id='opSubdag_D', dag=subdag_1, subdag=subdag_d())
return subdag_1
op_subdag_0 = SubDagOperator(task_id='op_subdag_0', dag=dag, subdag=subdag_0())
op_subdag_1 = SubDagOperator(task_id='op_subdag_1', dag=dag, subdag=subdag_1())
op_a = DummyOperator(task_id='A')
op_a.set_downstream(op_subdag_0)
op_a.set_downstream(op_subdag_1)
return dag
test_dag = nested_subdags()
        # sanity check to make sure DAG.subdags is still functioning properly
assert len(test_dag.subdags) == 6
# Perform processing dag
dagbag, found_dags, _ = self.process_dag(nested_subdags)
# Validate correctness
# all dags from test_dag should be listed
self.validate_dags(test_dag, found_dags, dagbag)
def test_skip_cycle_dags(self):
"""
        Don't crash when loading an invalid DAG file (one that contains a cycle).
        Don't load the dag into the DagBag either.
"""
# Define Dag to load
def basic_cycle():
import datetime # pylint: disable=redefined-outer-name,reimported
from airflow.models import DAG
from airflow.operators.dummy import DummyOperator
dag_name = 'cycle_dag'
default_args = {'owner': 'owner1', 'start_date': datetime.datetime(2016, 1, 1)}
dag = DAG(dag_name, default_args=default_args)
# A -> A
with dag:
op_a = DummyOperator(task_id='A')
op_a.set_downstream(op_a)
return dag
test_dag = basic_cycle()
        # sanity check to make sure DAG.subdags is still functioning properly
assert len(test_dag.subdags) == 0
# Perform processing dag
dagbag, found_dags, file_path = self.process_dag(basic_cycle)
        # Validate correctness
# None of the dags should be found
self.validate_dags(test_dag, found_dags, dagbag, should_be_found=False)
assert file_path in dagbag.import_errors
# Define Dag to load
def nested_subdag_cycle():
import datetime # pylint: disable=redefined-outer-name,reimported
from airflow.models import DAG
from airflow.operators.dummy import DummyOperator
from airflow.operators.subdag import SubDagOperator
dag_name = 'nested_cycle'
default_args = {'owner': 'owner1', 'start_date': datetime.datetime(2016, 1, 1)}
dag = DAG(dag_name, default_args=default_args)
            # nested_cycle:
            #     A -> op_subdag_0
            #          nested_cycle.op_subdag_0:
            #              -> opSubdag_A
            #                 nested_cycle.op_subdag_0.opSubdag_A:
            #                     -> subdag_a.task
            #              -> opSubdag_B
            #                 nested_cycle.op_subdag_0.opSubdag_B:
            #                     -> subdag_b.task
            #     A -> op_subdag_1
            #          nested_cycle.op_subdag_1:
            #              -> opSubdag_C
            #                 nested_cycle.op_subdag_1.opSubdag_C:
            #                     -> subdag_c.task -> subdag_c.task  >Invalid Loop<
            #              -> opSubdag_D
            #                 nested_cycle.op_subdag_1.opSubdag_D:
            #                     -> subdag_d.task
with dag:
def subdag_a():
subdag_a = DAG('nested_cycle.op_subdag_0.opSubdag_A', default_args=default_args)
DummyOperator(task_id='subdag_a.task', dag=subdag_a)
return subdag_a
def subdag_b():
subdag_b = DAG('nested_cycle.op_subdag_0.opSubdag_B', default_args=default_args)
DummyOperator(task_id='subdag_b.task', dag=subdag_b)
return subdag_b
def subdag_c():
subdag_c = DAG('nested_cycle.op_subdag_1.opSubdag_C', default_args=default_args)
op_subdag_c_task = DummyOperator(task_id='subdag_c.task', dag=subdag_c)
# introduce a loop in opSubdag_C
op_subdag_c_task.set_downstream(op_subdag_c_task)
return subdag_c
def subdag_d():
subdag_d = DAG('nested_cycle.op_subdag_1.opSubdag_D', default_args=default_args)
DummyOperator(task_id='subdag_d.task', dag=subdag_d)
return subdag_d
def subdag_0():
subdag_0 = DAG('nested_cycle.op_subdag_0', default_args=default_args)
SubDagOperator(task_id='opSubdag_A', dag=subdag_0, subdag=subdag_a())
SubDagOperator(task_id='opSubdag_B', dag=subdag_0, subdag=subdag_b())
return subdag_0
def subdag_1():
subdag_1 = DAG('nested_cycle.op_subdag_1', default_args=default_args)
SubDagOperator(task_id='opSubdag_C', dag=subdag_1, subdag=subdag_c())
SubDagOperator(task_id='opSubdag_D', dag=subdag_1, subdag=subdag_d())
return subdag_1
op_subdag_0 = SubDagOperator(task_id='op_subdag_0', dag=dag, subdag=subdag_0())
op_subdag_1 = SubDagOperator(task_id='op_subdag_1', dag=dag, subdag=subdag_1())
op_a = DummyOperator(task_id='A')
op_a.set_downstream(op_subdag_0)
op_a.set_downstream(op_subdag_1)
return dag
test_dag = nested_subdag_cycle()
        # sanity check to make sure DAG.subdags is still functioning properly
assert len(test_dag.subdags) == 6
# Perform processing dag
dagbag, found_dags, file_path = self.process_dag(nested_subdag_cycle)
# Validate correctness
# None of the dags should be found
self.validate_dags(test_dag, found_dags, dagbag, should_be_found=False)
assert file_path in dagbag.import_errors
def test_process_file_with_none(self):
"""
test that process_file can handle Nones
"""
dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
assert [] == dagbag.process_file(None)
def test_deactivate_unknown_dags(self):
"""
        Test that DAGs whose dag_ids are not passed to deactivate_unknown_dags
        are deactivated when the function is invoked
"""
dagbag = DagBag(include_examples=True)
dag_id = "test_deactivate_unknown_dags"
expected_active_dags = dagbag.dags.keys()
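        # The test DAG id is deliberately absent from expected_active_dags, so
        # it should be flagged inactive after the call below.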
model_before = DagModel(dag_id=dag_id, is_active=True)
with create_session() as session:
session.merge(model_before)
models.DAG.deactivate_unknown_dags(expected_active_dags)
after_model = DagModel.get_dagmodel(dag_id)
assert model_before.is_active
assert not after_model.is_active
# clean up
with create_session() as session:
session.query(DagModel).filter(DagModel.dag_id == 'test_deactivate_unknown_dags').delete()
def test_serialized_dags_are_written_to_db_on_sync(self):
"""
        Test that when dagbag.sync_to_db is called the DAGs are serialized and
        written to the DB, even when dagbag.read_dags_from_db is False
"""
with create_session() as session:
serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar()
assert serialized_dags_count == 0
dagbag = DagBag(
dag_folder=os.path.join(TEST_DAGS_FOLDER, "test_example_bash_operator.py"),
include_examples=False,
)
dagbag.sync_to_db()
assert not dagbag.read_dags_from_db
new_serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar()
assert new_serialized_dags_count == 1
@patch("airflow.models.serialized_dag.SerializedDagModel.write_dag")
def test_serialized_dag_errors_are_import_errors(self, mock_serialize):
"""
Test that errors serializing a DAG are recorded as import_errors in the DB
"""
mock_serialize.side_effect = SerializationError
with create_session() as session:
path = os.path.join(TEST_DAGS_FOLDER, "test_example_bash_operator.py")
dagbag = DagBag(
dag_folder=path,
include_examples=False,
)
assert dagbag.import_errors == {}
dagbag.sync_to_db(session=session)
assert path in dagbag.import_errors
err = dagbag.import_errors[path]
assert "SerializationError" in err
session.rollback()
@patch("airflow.models.dagbag.DagBag.collect_dags")
@patch("airflow.models.serialized_dag.SerializedDagModel.write_dag")
@patch("airflow.models.dag.DAG.bulk_write_to_db")
def test_sync_to_db_is_retried(self, mock_bulk_write_to_db, mock_s10n_write_dag, mock_collect_dags):
"""Test that dagbag.sync_to_db is retried on OperationalError"""
dagbag = DagBag("/dev/null")
mock_dag = mock.MagicMock(spec=models.DAG)
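        # Subdags are serialized as part of their parent DAG, so mark the mock
        # as a top-level DAG to make sure it reaches the serialization step.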
mock_dag.is_subdag = False
dagbag.dags['mock_dag'] = mock_dag
op_error = OperationalError(statement=mock.ANY, params=mock.ANY, orig=mock.ANY)
# Mock error for the first 2 tries and a successful third try
side_effect = [op_error, op_error, mock.ANY]
mock_bulk_write_to_db.side_effect = side_effect
mock_session = mock.MagicMock()
dagbag.sync_to_db(session=mock_session)
# Test that 3 attempts were made to run 'DAG.bulk_write_to_db' successfully
mock_bulk_write_to_db.assert_has_calls(
[
mock.call(mock.ANY, session=mock.ANY),
mock.call(mock.ANY, session=mock.ANY),
mock.call(mock.ANY, session=mock.ANY),
]
)
# Assert that rollback is called twice (i.e. whenever OperationalError occurs)
mock_session.rollback.assert_has_calls([mock.call(), mock.call()])
        # Check that 'SerializedDagModel.write_dag' is also called
        # Only called once, since the other two times 'DAG.bulk_write_to_db' errored
        # and the session was rolled back before even reaching 'SerializedDagModel.write_dag'
mock_s10n_write_dag.assert_has_calls(
[
mock.call(mock_dag, min_update_interval=mock.ANY, session=mock_session),
]
)
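    # Both MIN_SERIALIZED_DAG_* intervals are patched down to 5 seconds so the
    # test below can cross the fetch threshold with small frozen-time jumps.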
@patch("airflow.models.dagbag.settings.MIN_SERIALIZED_DAG_UPDATE_INTERVAL", 5)
@patch("airflow.models.dagbag.settings.MIN_SERIALIZED_DAG_FETCH_INTERVAL", 5)
def test_get_dag_with_dag_serialization(self):
"""
Test that Serialized DAG is updated in DagBag when it is updated in
Serialized DAG table after 'min_serialized_dag_fetch_interval' seconds are passed.
"""
with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 0)):
example_bash_op_dag = DagBag(include_examples=True).dags.get("example_bash_operator")
SerializedDagModel.write_dag(dag=example_bash_op_dag)
dag_bag = DagBag(read_dags_from_db=True)
ser_dag_1 = dag_bag.get_dag("example_bash_operator")
ser_dag_1_update_time = dag_bag.dags_last_fetched["example_bash_operator"]
assert example_bash_op_dag.tags == ser_dag_1.tags
assert ser_dag_1_update_time == tz.datetime(2020, 1, 5, 0, 0, 0)
# Check that if min_serialized_dag_fetch_interval has not passed we do not fetch the DAG
# from DB
with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 4)):
with assert_queries_count(0):
assert dag_bag.get_dag("example_bash_operator").tags == ["example", "example2"]
# Make a change in the DAG and write Serialized DAG to the DB
with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 6)):
example_bash_op_dag.tags += ["new_tag"]
SerializedDagModel.write_dag(dag=example_bash_op_dag)
        # Since min_serialized_dag_fetch_interval has passed, verify that calling
        # 'dag_bag.get_dag' fetches the serialized DAG from the DB
with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 8)):
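            # Two queries are expected: a last_updated check plus the fetch of
            # the updated serialized DAG.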
with assert_queries_count(2):
updated_ser_dag_1 = dag_bag.get_dag("example_bash_operator")
updated_ser_dag_1_update_time = dag_bag.dags_last_fetched["example_bash_operator"]
assert set(updated_ser_dag_1.tags) == {"example", "example2", "new_tag"}
assert updated_ser_dag_1_update_time > ser_dag_1_update_time
def test_collect_dags_from_db(self):
"""DAGs are collected from Database"""
example_dags_folder = airflow.example_dags.__path__[0]
dagbag = DagBag(example_dags_folder)
example_dags = dagbag.dags
for dag in example_dags.values():
SerializedDagModel.write_dag(dag)
new_dagbag = DagBag(read_dags_from_db=True)
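        # A DagBag created with read_dags_from_db=True starts out empty; DAGs
        # are only loaded once collect_dags_from_db() is called.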
assert len(new_dagbag.dags) == 0
new_dagbag.collect_dags_from_db()
new_dags = new_dagbag.dags
assert len(example_dags) == len(new_dags)
for dag_id, dag in example_dags.items():
serialized_dag = new_dags[dag_id]
assert serialized_dag.dag_id == dag.dag_id
assert set(serialized_dag.task_dict) == set(dag.task_dict)
@patch("airflow.settings.task_policy", cluster_policies.cluster_policy)
def test_task_cluster_policy_violation(self):
"""
        Test that file processing results in an import error when a task does
        not obey the cluster policy.
"""
dag_file = os.path.join(TEST_DAGS_FOLDER, "test_missing_owner.py")
dagbag = DagBag(dag_folder=dag_file, include_smart_sensor=False, include_examples=False)
assert set() == set(dagbag.dag_ids)
expected_import_errors = {
dag_file: (
f"""DAG policy violation (DAG ID: test_missing_owner, Path: {dag_file}):\n"""
"""Notices:\n"""
""" * Task must have non-None non-default owner. Current value: airflow"""
)
}
assert expected_import_errors == dagbag.import_errors
@patch("airflow.settings.task_policy", cluster_policies.cluster_policy)
def test_task_cluster_policy_obeyed(self):
"""
        Test that a dag is successfully imported without import errors when its
        tasks obey the cluster policy.
"""
dag_file = os.path.join(TEST_DAGS_FOLDER, "test_with_non_default_owner.py")
dagbag = DagBag(dag_folder=dag_file, include_examples=False, include_smart_sensor=False)
assert {"test_with_non_default_owner"} == set(dagbag.dag_ids)
assert {} == dagbag.import_errors
@patch("airflow.settings.dag_policy", cluster_policies.dag_policy)
def test_dag_cluster_policy_obeyed(self):
dag_file = os.path.join(TEST_DAGS_FOLDER, "test_dag_with_no_tags.py")
dagbag = DagBag(dag_folder=dag_file, include_examples=False, include_smart_sensor=False)
assert len(dagbag.dag_ids) == 0
assert "has no tags" in dagbag.import_errors[dag_file]