# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import contextlib
import inspect
import logging
import os
import pathlib
import re
import sys
import textwrap
import warnings
import zipfile
from copy import deepcopy
from datetime import datetime, timedelta, timezone
from unittest import mock
from unittest.mock import patch
import pytest
from sqlalchemy import select
from airflow import settings
from airflow.dag_processing.dagbag import DagBag, _capture_with_reraise, _validate_executor_fields
from airflow.exceptions import UnknownExecutorException
from airflow.executors.executor_loader import ExecutorLoader
from airflow.models.dag import DagModel
from airflow.models.dagwarning import DagWarning, DagWarningType
from airflow.models.serialized_dag import SerializedDagModel
from airflow.sdk import DAG, BaseOperator
from tests_common.pytest_plugin import AIRFLOW_ROOT_PATH
from tests_common.test_utils import db
from tests_common.test_utils.config import conf_vars
from unit import cluster_policies
from unit.models import TEST_DAGS_FOLDER
pytestmark = pytest.mark.db_test
example_dags_folder = AIRFLOW_ROOT_PATH / "airflow-core" / "src" / "airflow" / "example_dags" / "standard"
PY311 = sys.version_info >= (3, 11)
PY313 = sys.version_info >= (3, 13)
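# Python 3.11 adds "^^^" column markers to traceback frames and Python 3.13 adds
# "~~~" markers as well; _make_test_traceback below accounts for both.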
# Include the words "airflow" and "dag" in the file contents, tricking Airflow
# into thinking the file contains a DAG (otherwise it will be skipped).
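# Executing the file raises a NameError ("airflow_DAG" is not defined) two
# frames deep, which the traceback-depth tests below rely on.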
INVALID_DAG_WITH_DEPTH_FILE_CONTENTS = "def something():\n return airflow_DAG\nsomething()"
def db_clean_up():
db.clear_db_dags()
db.clear_db_runs()
db.clear_db_serialized_dags()
db.clear_dag_specific_permissions()
class TestValidateExecutorFields:
"""Comprehensive tests for _validate_executor_fields function."""
@patch.object(ExecutorLoader, "lookup_executor_name_by_str")
def test_multi_team_disabled_ignores_bundle_name(self, mock_lookup):
"""Test that when multi_team is disabled, bundle_name is ignored and no team lookup occurs."""
with DAG("test-dag", schedule=None) as dag:
BaseOperator(task_id="t1", executor="test.executor")
# multi_team disabled by default, no need to add conf_vars
_validate_executor_fields(dag, bundle_name="some_bundle")
# Should call ExecutorLoader without team_name (defaults to None)
mock_lookup.assert_called_once_with("test.executor", team_name=None)
@patch("airflow.dag_processing.bundles.manager.DagBundlesManager")
@patch.object(ExecutorLoader, "lookup_executor_name_by_str")
def test_multi_team_enabled_bundle_exists_with_team(self, mock_lookup, mock_manager_class):
"""Test successful team lookup when bundle exists and has team_name."""
# Setup mock bundle manager
mock_bundle_config = mock.MagicMock()
mock_bundle_config.team_name = "test_team"
mock_manager = mock_manager_class.return_value
mock_manager._bundle_config = {"test_bundle": mock_bundle_config}
with DAG("test-dag", schedule=None) as dag:
BaseOperator(task_id="t1", executor="team.executor")
with conf_vars({("core", "multi_team"): "True"}):
_validate_executor_fields(dag, bundle_name="test_bundle")
# Should call ExecutorLoader with team from bundle config
mock_lookup.assert_called_once_with("team.executor", team_name="test_team")
@patch("airflow.dag_processing.bundles.manager.DagBundlesManager")
@patch.object(ExecutorLoader, "lookup_executor_name_by_str")
def test_multi_team_enabled_bundle_exists_no_team(self, mock_lookup, mock_manager_class):
"""Test when bundle exists but has no team_name (None or empty)."""
mock_bundle_config = mock.MagicMock()
mock_bundle_config.team_name = None # No team associated
mock_manager = mock_manager_class.return_value
mock_manager._bundle_config = {"test_bundle": mock_bundle_config}
with DAG("test-dag", schedule=None) as dag:
BaseOperator(task_id="t1", executor="test.executor")
with conf_vars({("core", "multi_team"): "True"}):
_validate_executor_fields(dag, bundle_name="test_bundle")
mock_lookup.assert_called_once_with("test.executor", team_name=None)
@patch.object(ExecutorLoader, "lookup_executor_name_by_str")
def test_multiple_tasks_with_executors(self, mock_lookup):
"""Test that all tasks with executors are validated."""
with DAG("test-dag", schedule=None) as dag:
BaseOperator(task_id="t1", executor="executor1")
BaseOperator(task_id="t2", executor="executor2")
BaseOperator(task_id="t3") # No executor, should be skipped
with conf_vars({("core", "multi_team"): "True"}):
_validate_executor_fields(dag)
# Should be called for each task with executor
assert mock_lookup.call_count == 2
mock_lookup.assert_any_call("executor1", team_name=None)
mock_lookup.assert_any_call("executor2", team_name=None)
@patch("airflow.dag_processing.bundles.manager.DagBundlesManager")
@patch.object(ExecutorLoader, "lookup_executor_name_by_str")
def test_executor_validation_failure_with_team(self, mock_lookup, mock_manager_class):
"""Test executor validation failure when team is associated (team-specific error)."""
mock_bundle_config = mock.MagicMock()
mock_bundle_config.team_name = "test_team"
mock_manager = mock_manager_class.return_value
mock_manager._bundle_config = {"test_bundle": mock_bundle_config}
# ExecutorLoader raises exception
mock_lookup.side_effect = UnknownExecutorException("Executor not found")
with DAG("test-dag", schedule=None) as dag:
BaseOperator(task_id="task1", executor="invalid.executor")
with conf_vars({("core", "multi_team"): "True"}):
with pytest.raises(
UnknownExecutorException,
match=re.escape(
"Task 'task1' specifies executor 'invalid.executor', which is not available "
"for team 'test_team' (the team associated with DAG 'test-dag') or as a global executor. "
"Make sure 'invalid.executor' is configured for team 'test_team' or globally in your "
"[core] executors configuration, or update the task's executor to use one of the "
"configured executors for team 'test_team' or available global executors."
),
):
_validate_executor_fields(dag, bundle_name="test_bundle")
@patch.object(ExecutorLoader, "lookup_executor_name_by_str")
def test_executor_validation_failure_no_team(self, mock_lookup):
"""Test executor validation failure when no team is associated (generic error)."""
mock_lookup.side_effect = UnknownExecutorException("Executor not found")
with DAG("test-dag", schedule=None) as dag:
BaseOperator(task_id="task1", executor="invalid.executor")
with conf_vars({("core", "multi_team"): "True"}):
with pytest.raises(
UnknownExecutorException,
match=re.escape(
"Task 'task1' specifies executor 'invalid.executor', which is not available. "
"Make sure it is listed in your [core] executors configuration, or update the task's "
"executor to use one of the configured executors."
),
):
_validate_executor_fields(dag) # No bundle_name
@patch("airflow.dag_processing.bundles.manager.DagBundlesManager")
@patch.object(ExecutorLoader, "lookup_executor_name_by_str")
def test_global_executor_fallback_success(self, mock_lookup, mock_manager_class):
"""Test that team-specific executor failure falls back to global executor successfully."""
mock_bundle_config = mock.MagicMock()
mock_bundle_config.team_name = "test_team"
mock_manager = mock_manager_class.return_value
mock_manager._bundle_config = {"test_bundle": mock_bundle_config}
# First call (team-specific) fails, second call (global) succeeds
mock_lookup.side_effect = [UnknownExecutorException("Team executor not found"), None]
with DAG("test-dag", schedule=None) as dag:
BaseOperator(task_id="task1", executor="global.executor")
with conf_vars({("core", "multi_team"): "True"}):
# Should not raise exception due to global fallback
_validate_executor_fields(dag, bundle_name="test_bundle")
# Should call lookup twice: first for team, then for global
assert mock_lookup.call_count == 2
mock_lookup.assert_any_call("global.executor", team_name="test_team")
mock_lookup.assert_any_call("global.executor", team_name=None)
@patch("airflow.dag_processing.bundles.manager.DagBundlesManager")
@patch.object(ExecutorLoader, "lookup_executor_name_by_str")
def test_global_executor_fallback_failure(self, mock_lookup, mock_manager_class):
"""Test that when both team-specific and global executors fail, appropriate error is raised."""
mock_bundle_config = mock.MagicMock()
mock_bundle_config.team_name = "test_team"
mock_manager = mock_manager_class.return_value
mock_manager._bundle_config = {"test_bundle": mock_bundle_config}
# Both calls fail
mock_lookup.side_effect = UnknownExecutorException("Executor not found")
with DAG("test-dag", schedule=None) as dag:
BaseOperator(task_id="task1", executor="unknown.executor")
with conf_vars({("core", "multi_team"): "True"}):
with pytest.raises(
UnknownExecutorException,
match=re.escape(
"Task 'task1' specifies executor 'unknown.executor', which is not available "
"for team 'test_team' (the team associated with DAG 'test-dag') or as a global executor. "
"Make sure 'unknown.executor' is configured for team 'test_team' or globally in your "
"[core] executors configuration, or update the task's executor to use one of the "
"configured executors for team 'test_team' or available global executors."
),
):
_validate_executor_fields(dag, bundle_name="test_bundle")
# Should call lookup twice: first for team, then for global fallback
assert mock_lookup.call_count == 2
mock_lookup.assert_any_call("unknown.executor", team_name="test_team")
mock_lookup.assert_any_call("unknown.executor", team_name=None)
@patch("airflow.dag_processing.bundles.manager.DagBundlesManager")
@patch.object(ExecutorLoader, "lookup_executor_name_by_str")
def test_team_specific_executor_success_no_fallback(self, mock_lookup, mock_manager_class):
"""Test that when team-specific executor succeeds, global fallback is not attempted."""
mock_bundle_config = mock.MagicMock()
mock_bundle_config.team_name = "test_team"
mock_manager = mock_manager_class.return_value
mock_manager._bundle_config = {"test_bundle": mock_bundle_config}
# First call (team-specific) succeeds
mock_lookup.return_value = None
with DAG("test-dag", schedule=None) as dag:
BaseOperator(task_id="task1", executor="team.executor")
with conf_vars({("core", "multi_team"): "True"}):
_validate_executor_fields(dag, bundle_name="test_bundle")
# Should only call lookup once for team-specific executor
mock_lookup.assert_called_once_with("team.executor", team_name="test_team")
def test_validate_executor_field_executor_not_configured():
with DAG("test-dag", schedule=None) as dag:
BaseOperator(task_id="t1", executor="test.custom.executor")
with pytest.raises(
UnknownExecutorException,
match=re.escape(
"Task 't1' specifies executor 'test.custom.executor', which is not available. "
"Make sure it is listed in your [core] executors configuration, or update the task's "
"executor to use one of the configured executors."
),
):
_validate_executor_fields(dag)
def test_validate_executor_field():
with DAG("test-dag", schedule=None) as dag:
BaseOperator(task_id="t1", executor="test.custom.executor")
with patch.object(ExecutorLoader, "lookup_executor_name_by_str"):
_validate_executor_fields(dag)
class TestDagBag:
def setup_class(self):
db_clean_up()
def teardown_class(self):
db_clean_up()
def test_dagbag_with_bundle_name(self, tmp_path):
"""Test that DagBag constructor accepts and stores bundle_name parameter."""
dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False, bundle_name="test_bundle")
assert dagbag.bundle_name == "test_bundle"
# Test with None (default)
dagbag2 = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
assert dagbag2.bundle_name is None
def test_timeout_context_manager_raises_exception(self):
"""Test that the timeout context manager raises AirflowTaskTimeout when time limit is exceeded."""
import time
from airflow.dag_processing.dagbag import timeout
from airflow.exceptions import AirflowTaskTimeout
with pytest.raises(AirflowTaskTimeout):
with timeout(1, "Test timeout"):
time.sleep(2)
def test_get_existing_dag(self, tmp_path):
"""
Test that we're able to parse some example DAGs and retrieve them
"""
dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=True, bundle_name="test_bundle")
some_expected_dag_ids = ["example_bash_operator", "example_branch_operator"]
for dag_id in some_expected_dag_ids:
dag = dagbag.get_dag(dag_id)
assert dag is not None
assert dag_id == dag.dag_id
assert dagbag.size() >= 7
def test_get_non_existing_dag(self, tmp_path):
"""
        Test that retrieving a non-existent dag id returns None without crashing.
"""
dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
non_existing_dag_id = "non_existing_dag_id"
assert dagbag.get_dag(non_existing_dag_id) is None
def test_serialized_dag_not_existing_doesnt_raise(self, tmp_path, session):
"""
        Test that querying for a serialized dag that does not exist returns None without raising.
"""
non_existing_dag_id = "non_existing_dag_id"
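        # scalar() returns None when the WHERE clause matches no rows.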
assert session.scalar(select(True).where(SerializedDagModel.dag_id == non_existing_dag_id)) is None
def test_dont_load_example(self, tmp_path):
"""
        Test that the example DAGs are not loaded.
"""
dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
assert dagbag.size() == 0
def test_safe_mode_heuristic_match(self, tmp_path):
"""
With safe mode enabled, a file matching the discovery heuristics
should be discovered.
"""
path = tmp_path / "testfile.py"
path.write_text("# airflow\n# DAG")
with conf_vars({("core", "dags_folder"): os.fspath(path.parent)}):
dagbag = DagBag(include_examples=False, safe_mode=True)
assert len(dagbag.dagbag_stats) == 1
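        # DagBag stats report file paths relative to the dag folder, hence the leading slash.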
assert dagbag.dagbag_stats[0].file == f"/{path.name}"
def test_safe_mode_heuristic_mismatch(self, tmp_path):
"""
With safe mode enabled, a file not matching the discovery heuristics
should not be discovered.
"""
path = tmp_path / "testfile.py"
path.write_text("")
with conf_vars({("core", "dags_folder"): os.fspath(path.parent)}):
dagbag = DagBag(include_examples=False, safe_mode=True)
assert len(dagbag.dagbag_stats) == 0
def test_safe_mode_disabled(self, tmp_path):
"""With safe mode disabled, an empty python file should be discovered."""
path = tmp_path / "testfile.py"
path.write_text("")
with conf_vars({("core", "dags_folder"): os.fspath(path.parent)}):
dagbag = DagBag(include_examples=False, safe_mode=False)
assert len(dagbag.dagbag_stats) == 1
assert dagbag.dagbag_stats[0].file == f"/{path.name}"
def test_process_file_that_contains_multi_bytes_char(self, tmp_path):
"""
        Test that we're able to parse a file that contains a multi-byte char.
"""
path = tmp_path / "testfile"
path.write_text("\u3042") # write multi-byte char (hiragana)
dagbag = DagBag(dag_folder=os.fspath(path.parent), include_examples=False)
assert dagbag.process_file(os.fspath(path)) == []
def test_process_file_duplicated_dag_id(self, tmp_path):
"""Loading a DAG with ID that already existed in a DAG bag should result in an import error."""
dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
def create_dag():
from airflow.sdk import dag
@dag(schedule=None, default_args={"owner": "owner1"})
def my_flow():
pass
my_dag = my_flow() # noqa: F841
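        # Drop the ``def`` line and strip 12 leading spaces (three indent levels)
        # so the nested function body becomes top-level module source.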
source_lines = [line[12:] for line in inspect.getsource(create_dag).splitlines(keepends=True)[1:]]
path1 = tmp_path / "testfile1"
path2 = tmp_path / "testfile2"
path1.write_text("".join(source_lines))
path2.write_text("".join(source_lines))
found_1 = dagbag.process_file(os.fspath(path1))
assert len(found_1) == 1
assert found_1[0].dag_id == "my_flow"
assert dagbag.import_errors == {}
dags_in_bag = dagbag.dags
found_2 = dagbag.process_file(os.fspath(path2))
assert len(found_2) == 0
assert dagbag.import_errors[os.fspath(path2)].startswith(
"AirflowDagDuplicatedIdException: Ignoring DAG"
)
assert dagbag.dags == dags_in_bag # Should not change.
def test_zip_skip_log(self, caplog, test_zip_path):
"""
        Test loading a DAG from within a zip file while skipping another file in
        the archive because it doesn't contain "airflow" and "DAG".
"""
caplog.set_level(logging.INFO)
dagbag = DagBag(dag_folder=test_zip_path, include_examples=False)
assert dagbag.has_logged
assert (
f"File {test_zip_path}:file_no_airflow_dag.py "
"assumed to contain no DAGs. Skipping." in caplog.text
)
def test_zip(self, tmp_path, test_zip_path):
"""
test the loading of a DAG within a zip file that includes dependencies
"""
syspath_before = deepcopy(sys.path)
dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
dagbag.process_file(test_zip_path)
assert dagbag.get_dag("test_zip_dag")
assert sys.path == syspath_before # sys.path doesn't change
assert not dagbag.import_errors
@patch("airflow.dag_processing.dagbag.timeout")
@patch("airflow.dag_processing.dagbag.settings.get_dagbag_import_timeout")
def test_process_dag_file_without_timeout(
self, mocked_get_dagbag_import_timeout, mocked_timeout, tmp_path
):
"""
Test dag file parsing without timeout
"""
mocked_get_dagbag_import_timeout.return_value = 0
dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, "test_sensor.py"))
mocked_timeout.assert_not_called()
mocked_get_dagbag_import_timeout.return_value = -1
dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, "test_sensor.py"))
mocked_timeout.assert_not_called()
@patch("airflow.dag_processing.dagbag.timeout")
@patch("airflow.dag_processing.dagbag.settings.get_dagbag_import_timeout")
def test_process_dag_file_with_non_default_timeout(
self, mocked_get_dagbag_import_timeout, mocked_timeout, tmp_path
):
"""
Test customized dag file parsing timeout
"""
timeout_value = 100
mocked_get_dagbag_import_timeout.return_value = timeout_value
# ensure the test value is not equal to the default value
assert timeout_value != settings.conf.getfloat("core", "DAGBAG_IMPORT_TIMEOUT")
dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, "test_sensor.py"))
mocked_timeout.assert_called_once_with(timeout_value, error_message=mock.ANY)
@patch("airflow.dag_processing.dagbag.settings.get_dagbag_import_timeout")
def test_check_value_type_from_get_dagbag_import_timeout(
self, mocked_get_dagbag_import_timeout, tmp_path
):
"""
Test correctness of value from get_dagbag_import_timeout
"""
mocked_get_dagbag_import_timeout.return_value = "1"
dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
with pytest.raises(
TypeError, match=r"Value \(1\) from get_dagbag_import_timeout must be int or float"
):
dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, "test_sensor.py"))
@pytest.fixture
def invalid_cron_dag(self) -> str:
return os.path.join(TEST_DAGS_FOLDER, "test_invalid_cron.py")
@pytest.fixture
def invalid_cron_zipped_dag(self, invalid_cron_dag: str, tmp_path: pathlib.Path) -> str:
zipped = tmp_path / "test_zip_invalid_cron.zip"
with zipfile.ZipFile(zipped, "w") as zf:
zf.write(invalid_cron_dag, os.path.basename(invalid_cron_dag))
return os.fspath(zipped)
@pytest.mark.parametrize("invalid_dag_name", ["invalid_cron_dag", "invalid_cron_zipped_dag"])
def test_process_file_cron_validity_check(
self, request: pytest.FixtureRequest, invalid_dag_name: str, tmp_path
):
"""Test if an invalid cron expression as schedule interval can be identified"""
dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
assert len(dagbag.import_errors) == 0
dagbag.process_file(request.getfixturevalue(invalid_dag_name))
assert len(dagbag.import_errors) == 1
assert len(dagbag.dags) == 0
def test_process_file_invalid_param_check(self, tmp_path):
"""
test if an invalid param in the dags can be identified
"""
invalid_dag_files = [
"test_invalid_param.py",
"test_invalid_param2.py",
"test_invalid_param3.py",
"test_invalid_param4.py",
]
dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
assert len(dagbag.import_errors) == 0
for file in invalid_dag_files:
dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, file))
assert len(dagbag.import_errors) == len(invalid_dag_files)
assert len(dagbag.dags) == 0
def test_process_file_valid_param_check(self, tmp_path):
"""
        Test that valid params in the dags can be validated (positive test).
"""
valid_dag_files = [
"test_valid_param.py",
"test_valid_param2.py",
]
dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
assert len(dagbag.import_errors) == 0
for file in valid_dag_files:
dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, file))
assert len(dagbag.import_errors) == 0
assert len(dagbag.dags) == len(valid_dag_files)
@patch.object(DagModel, "get_current")
def test_get_dag_without_refresh(self, mock_dagmodel):
"""
Test that, once a DAG is loaded, it doesn't get refreshed again if it
hasn't been expired.
"""
dag_id = "example_bash_operator"
mock_dagmodel.return_value = DagModel()
mock_dagmodel.return_value.last_expired = None
mock_dagmodel.return_value.fileloc = "foo"
class _TestDagBag(DagBag):
process_file_calls = 0
def process_file(self, filepath, only_if_updated=True, safe_mode=True):
if os.path.basename(filepath) == "example_bash_operator.py":
_TestDagBag.process_file_calls += 1
                return super().process_file(filepath, only_if_updated, safe_mode)
dagbag = _TestDagBag(include_examples=True)
# Should not call process_file again, since it's already loaded during init.
assert dagbag.process_file_calls == 1
assert dagbag.get_dag(dag_id) is not None
assert dagbag.process_file_calls == 1
@pytest.mark.parametrize(
("file_to_load", "expected"),
(
pytest.param(
pathlib.Path(example_dags_folder) / "example_bash_operator.py",
{
"example_bash_operator": f"{example_dags_folder.relative_to(AIRFLOW_ROOT_PATH) / 'example_bash_operator.py'}"
},
id="example_bash_operator",
),
),
)
def test_get_dag_registration(self, file_to_load, expected):
pytest.importorskip("system.standard")
dagbag = DagBag(dag_folder=os.devnull, include_examples=False)
dagbag.process_file(os.fspath(file_to_load))
for dag_id, path in expected.items():
dag = dagbag.get_dag(dag_id)
assert dag, f"{dag_id} was bagged"
assert dag.fileloc.endswith(path)
@pytest.mark.parametrize(
("expected"),
(
pytest.param(
{
"test_zip_dag": "test_zip.zip/test_zip.py",
"test_zip_autoregister": "test_zip.zip/test_zip.py",
},
id="test_zip.zip",
),
),
)
def test_get_zip_dag_registration(self, test_zip_path, expected):
dagbag = DagBag(dag_folder=os.devnull, include_examples=False)
dagbag.process_file(test_zip_path)
for dag_id, path in expected.items():
dag = dagbag.get_dag(dag_id)
assert dag, f"{dag_id} was bagged"
assert dag.fileloc.endswith(f"{pathlib.Path(test_zip_path).parent}/{path}")
def test_dag_registration_with_failure(self):
dagbag = DagBag(dag_folder=os.devnull, include_examples=False)
found = dagbag.process_file(str(TEST_DAGS_FOLDER / "test_invalid_dup_task.py"))
assert found == []
@pytest.fixture
def zip_with_valid_dag_and_dup_tasks(self, tmp_path: pathlib.Path) -> str:
failing_dag_file = TEST_DAGS_FOLDER / "test_invalid_dup_task.py"
working_dag_file = TEST_DAGS_FOLDER / "test_example_bash_operator.py"
zipped = tmp_path / "test_zip_invalid_dup_task.zip"
with zipfile.ZipFile(zipped, "w") as zf:
zf.write(failing_dag_file, failing_dag_file.name)
zf.write(working_dag_file, working_dag_file.name)
return os.fspath(zipped)
def test_dag_registration_with_failure_zipped(self, zip_with_valid_dag_and_dup_tasks):
dagbag = DagBag(dag_folder=os.devnull, include_examples=False)
found = dagbag.process_file(zip_with_valid_dag_and_dup_tasks)
assert len(found) == 1
assert [dag.dag_id for dag in found] == ["test_example_bash_operator"]
@patch.object(DagModel, "get_current")
def test_refresh_py_dag(self, mock_dagmodel, tmp_path):
"""
Test that we can refresh an ordinary .py DAG
"""
dag_id = "example_bash_operator"
fileloc = str(example_dags_folder / "example_bash_operator.py")
mock_dagmodel.return_value = DagModel()
mock_dagmodel.return_value.last_expired = datetime.max.replace(tzinfo=timezone.utc)
mock_dagmodel.return_value.fileloc = fileloc
class _TestDagBag(DagBag):
process_file_calls = 0
def process_file(self, filepath, only_if_updated=True, safe_mode=True):
if filepath == fileloc:
_TestDagBag.process_file_calls += 1
return super().process_file(filepath, only_if_updated, safe_mode)
dagbag = _TestDagBag(dag_folder=os.fspath(tmp_path), include_examples=True)
assert dagbag.process_file_calls == 1
dag = dagbag.get_dag(dag_id)
assert dag is not None
assert dag_id == dag.dag_id
assert dagbag.process_file_calls == 2
@patch.object(DagModel, "get_current")
def test_refresh_packaged_dag(self, mock_dagmodel, test_zip_path):
"""
Test that we can refresh a packaged DAG
"""
dag_id = "test_zip_dag"
fileloc = os.path.realpath(os.path.join(test_zip_path, "test_zip.py"))
mock_dagmodel.return_value = DagModel()
mock_dagmodel.return_value.last_expired = datetime.max.replace(tzinfo=timezone.utc)
mock_dagmodel.return_value.fileloc = fileloc
class _TestDagBag(DagBag):
process_file_calls = 0
def process_file(self, filepath, only_if_updated=True, safe_mode=True):
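                # Membership is a substring check: the zip archive path passed to
                # process_file is a prefix of the in-zip fileloc.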
if filepath in fileloc:
_TestDagBag.process_file_calls += 1
return super().process_file(filepath, only_if_updated, safe_mode)
dagbag = _TestDagBag(dag_folder=os.path.realpath(test_zip_path), include_examples=False)
assert dagbag.process_file_calls == 1
dag = dagbag.get_dag(dag_id)
assert dag is not None
assert dag_id == dag.dag_id
assert dagbag.process_file_calls == 2
def process_dag(self, create_dag, tmp_path):
"""
Helper method to process a file generated from the input create_dag function.
"""
# write source to file
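        # Drop the ``def`` line and the trailing ``return``, then dedent the body
        # to module level.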
source = textwrap.dedent("".join(inspect.getsource(create_dag).splitlines(True)[1:-1]))
path = tmp_path / "testfile"
path.write_text(source)
dagbag = DagBag(dag_folder=os.fspath(path.parent), include_examples=False)
found_dags = dagbag.process_file(os.fspath(path))
return dagbag, found_dags, os.fspath(path)
def validate_dags(self, expected_dag, actual_found_dags, actual_dagbag, should_be_found=True):
actual_found_dag_ids = [dag.dag_id for dag in actual_found_dags]
dag_id = expected_dag.dag_id
actual_dagbag.log.info("validating %s", dag_id)
assert (dag_id in actual_found_dag_ids) == should_be_found, (
f'dag "{dag_id}" should {"" if should_be_found else "not "}'
f'have been found after processing dag "{expected_dag.dag_id}"'
)
assert (dag_id in actual_dagbag.dags) == should_be_found, (
f'dag "{dag_id}" should {"" if should_be_found else "not "}'
f'be in dagbag.dags after processing dag "{expected_dag.dag_id}"'
)
def test_skip_cycle_dags(self, tmp_path):
"""
        Don't crash when loading an invalid DAG file (one that contains a cycle),
        and don't load the dag into the DagBag either.
"""
# Define Dag to load
def basic_cycle():
import datetime
from airflow.models.dag import DAG
from airflow.providers.standard.operators.empty import EmptyOperator
dag_name = "cycle_dag"
default_args = {"owner": "owner1", "start_date": datetime.datetime(2016, 1, 1)}
dag = DAG(dag_name, schedule=timedelta(days=1), default_args=default_args)
# A -> A
with dag:
op_a = EmptyOperator(task_id="A")
op_a.set_downstream(op_a)
return dag
test_dag = basic_cycle()
# Perform processing dag
dagbag, found_dags, file_path = self.process_dag(basic_cycle, tmp_path)
        # Validate correctness
# None of the dags should be found
self.validate_dags(test_dag, found_dags, dagbag, should_be_found=False)
assert file_path in dagbag.import_errors
def test_process_file_with_none(self, tmp_path):
"""
        Test that process_file can handle None.
"""
dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
assert dagbag.process_file(None) == []
def test_timeout_dag_errors_are_import_errors(self, tmp_path, caplog):
"""
        Test that a DAG file exceeding the import timeout is still recorded in import_errors.
"""
code_to_save = """
# Define Dag to load
import datetime
import time
import airflow
from airflow.providers.standard.operators.python import PythonOperator
time.sleep(1)
with airflow.DAG(
"import_timeout",
start_date=datetime.datetime(2022, 1, 1),
schedule=None) as dag:
def f():
print("Sleeping")
time.sleep(1)
for ind in range(10):
PythonOperator(
dag=dag,
task_id=f"sleep_2_{ind}",
python_callable=f,
)
"""
with open("tmp_file.py", "w") as f:
f.write(code_to_save)
with conf_vars({("core", "DAGBAG_IMPORT_TIMEOUT"): "0.01"}):
dagbag = DagBag(dag_folder=os.fspath("tmp_file.py"), include_examples=False)
dag = dagbag._load_modules_from_file("tmp_file.py", safe_mode=False)
assert dag is not None
assert "tmp_file.py" in dagbag.import_errors
assert "DagBag import timeout for" in caplog.text
@staticmethod
def _make_test_traceback(unparseable_filename: str, depth=None) -> str:
python_311_marker = " ^^^^^^^^^^^\n" if PY311 else ""
python_313_marker = [" ~~~~~~~~~^^\n"] if PY313 else []
frames = (
f' File "{unparseable_filename}", line 3, in <module>\n something()\n',
*python_313_marker,
f' File "{unparseable_filename}", line 2, in something\n return airflow_DAG\n{python_311_marker}',
)
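        # depth=None keeps the full traceback; a positive depth keeps only the
        # last ``depth`` frames via a negative slice.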
depth = 0 if depth is None else -depth
return (
"Traceback (most recent call last):\n"
+ "".join(frames[depth:])
+ "NameError: name 'airflow_DAG' is not defined\n"
)
@pytest.mark.parametrize("depth", (None, 1))
def test_import_error_tracebacks(self, tmp_path, depth):
unparseable_filename = tmp_path.joinpath("dag.py").as_posix()
with open(unparseable_filename, "w") as unparseable_file:
            unparseable_file.write(INVALID_DAG_WITH_DEPTH_FILE_CONTENTS)
with contextlib.ExitStack() as cm:
if depth is not None:
cm.enter_context(conf_vars({("core", "dagbag_import_error_traceback_depth"): str(depth)}))
dagbag = DagBag(dag_folder=unparseable_filename, include_examples=False)
import_errors = dagbag.import_errors
assert unparseable_filename in import_errors
assert import_errors[unparseable_filename] == self._make_test_traceback(unparseable_filename, depth)
@pytest.mark.parametrize("depth", (None, 1))
def test_import_error_tracebacks_zip(self, tmp_path, depth):
invalid_zip_filename = (tmp_path / "test_zip_invalid.zip").as_posix()
invalid_dag_filename = os.path.join(invalid_zip_filename, "dag.py")
with zipfile.ZipFile(invalid_zip_filename, "w") as invalid_zip_file:
invalid_zip_file.writestr("dag.py", INVALID_DAG_WITH_DEPTH_FILE_CONTENTS)
with contextlib.ExitStack() as cm:
if depth is not None:
cm.enter_context(conf_vars({("core", "dagbag_import_error_traceback_depth"): str(depth)}))
dagbag = DagBag(dag_folder=invalid_zip_filename, include_examples=False)
import_errors = dagbag.import_errors
assert invalid_dag_filename in import_errors
assert import_errors[invalid_dag_filename] == self._make_test_traceback(invalid_dag_filename, depth)
@patch("airflow.settings.task_policy", cluster_policies.example_task_policy)
def test_task_cluster_policy_violation(self):
"""
        Test that file processing results in an import error when a task does
        not obey the cluster policy.
"""
dag_file = os.path.join(TEST_DAGS_FOLDER, "test_missing_owner.py")
dag_id = "test_missing_owner"
err_cls_name = "AirflowClusterPolicyViolation"
dagbag = DagBag(dag_folder=dag_file, include_examples=False)
assert set() == set(dagbag.dag_ids)
expected_import_errors = {
dag_file: (
f"""{err_cls_name}: DAG policy violation (DAG ID: {dag_id}, Path: {dag_file}):\n"""
"""Notices:\n"""
""" * Task must have non-None non-default owner. Current value: airflow"""
)
}
assert expected_import_errors == dagbag.import_errors
@patch("airflow.settings.task_policy", cluster_policies.example_task_policy)
def test_task_cluster_policy_nonstring_owner(self):
"""
        Test that file processing results in an import error when a task does
        not obey the cluster policy because its owner is not a string.
"""
TEST_DAGS_CORRUPTED_FOLDER = pathlib.Path(__file__).parent.with_name("dags_corrupted")
dag_file = os.path.join(TEST_DAGS_CORRUPTED_FOLDER, "test_nonstring_owner.py")
dag_id = "test_nonstring_owner"
err_cls_name = "AirflowClusterPolicyViolation"
dagbag = DagBag(dag_folder=dag_file, include_examples=False)
assert set() == set(dagbag.dag_ids)
expected_import_errors = {
dag_file: (
f"""{err_cls_name}: DAG policy violation (DAG ID: {dag_id}, Path: {dag_file}):\n"""
"""Notices:\n"""
""" * owner should be a string. Current value: ['a']"""
)
}
assert expected_import_errors == dagbag.import_errors
@patch("airflow.settings.task_policy", cluster_policies.example_task_policy)
def test_task_cluster_policy_obeyed(self):
"""
        Test that a dag is successfully imported without import errors when its
        tasks obey the cluster policy.
"""
dag_file = os.path.join(TEST_DAGS_FOLDER, "test_with_non_default_owner.py")
dagbag = DagBag(dag_folder=dag_file, include_examples=False)
assert {"test_with_non_default_owner"} == set(dagbag.dag_ids)
assert dagbag.import_errors == {}
@patch("airflow.settings.dag_policy", cluster_policies.dag_policy)
def test_dag_cluster_policy_obeyed(self):
dag_file = os.path.join(TEST_DAGS_FOLDER, "test_dag_with_no_tags.py")
dagbag = DagBag(dag_folder=dag_file, include_examples=False)
assert len(dagbag.dag_ids) == 0
assert "has no tags" in dagbag.import_errors[dag_file]
def test_dagbag_dag_collection(self):
dagbag = DagBag(
dag_folder=TEST_DAGS_FOLDER,
include_examples=False,
collect_dags=False,
bundle_name="test_collection",
)
# since collect_dags is False, dagbag.dags should be empty
assert not dagbag.dags
dagbag.collect_dags()
assert dagbag.dags
# test that dagbag.dags is not empty if collect_dags is True
dagbag = DagBag(dag_folder=TEST_DAGS_FOLDER, include_examples=False, bundle_name="test_collection")
assert dagbag.dags
    def test_dagbag_captured_warnings(self):
dag_file = os.path.join(TEST_DAGS_FOLDER, "test_dag_warnings.py")
dagbag = DagBag(dag_folder=dag_file, include_examples=False, collect_dags=False)
assert dag_file not in dagbag.captured_warnings
dagbag.collect_dags(dag_folder=dagbag.dag_folder, include_examples=False, only_if_updated=False)
assert dagbag.dagbag_stats[0].warning_num == 2
assert dagbag.captured_warnings == {
dag_file: (
f"{dag_file}:46: DeprecationWarning: Deprecated Parameter",
f"{dag_file}:48: UserWarning: Some Warning",
)
}
with warnings.catch_warnings():
            # Ignore DeprecationWarning; the change should be reflected in the captured warnings
warnings.simplefilter("ignore", DeprecationWarning)
dagbag.collect_dags(dag_folder=dagbag.dag_folder, include_examples=False, only_if_updated=False)
assert dag_file in dagbag.captured_warnings
assert len(dagbag.captured_warnings[dag_file]) == 1
assert dagbag.dagbag_stats[0].warning_num == 1
# Disable all warnings, no captured warnings expected
warnings.simplefilter("ignore")
dagbag.collect_dags(dag_folder=dagbag.dag_folder, include_examples=False, only_if_updated=False)
assert dag_file not in dagbag.captured_warnings
assert dagbag.dagbag_stats[0].warning_num == 0
@pytest.fixture
def warning_zipped_dag_path(self, tmp_path: pathlib.Path) -> str:
warnings_dag_file = TEST_DAGS_FOLDER / "test_dag_warnings.py"
zipped = tmp_path / "test_dag_warnings.zip"
with zipfile.ZipFile(zipped, "w") as zf:
zf.write(warnings_dag_file, warnings_dag_file.name)
return os.fspath(zipped)
    def test_dagbag_captured_warnings_zip(self, warning_zipped_dag_path: str):
in_zip_dag_file = f"{warning_zipped_dag_path}/test_dag_warnings.py"
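        # Captured warnings are keyed by the zip archive path, while each message
        # points at the file inside the archive.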
dagbag = DagBag(dag_folder=warning_zipped_dag_path, include_examples=False)
assert dagbag.dagbag_stats[0].warning_num == 2
assert dagbag.captured_warnings == {
warning_zipped_dag_path: (
f"{in_zip_dag_file}:46: DeprecationWarning: Deprecated Parameter",
f"{in_zip_dag_file}:48: UserWarning: Some Warning",
)
}
@pytest.mark.parametrize(
("known_pools", "expected"),
(
pytest.param(None, set(), id="disabled"),
pytest.param(
{"default_pool"},
{
DagWarning(
"test",
DagWarningType.NONEXISTENT_POOL,
"Dag 'test' references non-existent pools: ['pool1']",
),
},
id="only-default",
),
pytest.param(
{"default_pool", "pool1"},
set(),
id="known-pools",
),
),
)
def test_dag_warnings_invalid_pool(self, known_pools, expected):
with DAG(dag_id="test") as dag:
BaseOperator(task_id="1")
BaseOperator(task_id="2", pool="pool1")
dagbag = DagBag(dag_folder="", include_examples=False, collect_dags=False, known_pools=known_pools)
dagbag.bag_dag(dag)
assert dagbag.dag_warnings == expected
def test_sigsegv_handling(self, tmp_path, caplog):
"""
Test that a SIGSEGV in a DAG file is handled gracefully and does not crash the process.
"""
# Create a DAG file that will raise a SIGSEGV
dag_file = tmp_path / "bad_dag.py"
dag_file.write_text(
textwrap.dedent(
"""\
import signal
from airflow import DAG
import os
from airflow.decorators import task
os.kill(os.getpid(), signal.SIGSEGV)
with DAG('testbug'):
@task
def mytask():
print(1)
mytask()
"""
)
)
dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
assert "Received SIGSEGV signal while processing" in caplog.text
assert dag_file.as_posix() in dagbag.import_errors
def test_failed_signal_registration_does_not_crash_the_process(self, tmp_path, caplog):
"""Test that a ValueError raised by a signal setting on child process does not crash the main process.
This was raised in test_dag_report.py module in api_fastapi/core_api/routes/public tests
"""
dag_file = tmp_path / "test_dag.py"
dag_file.write_text(
textwrap.dedent(
"""\
from airflow import DAG
from airflow.decorators import task
with DAG('testbug'):
@task
def mytask():
print(1)
mytask()
"""
)
)
with mock.patch("airflow.dag_processing.dagbag.signal.signal") as mock_signal:
mock_signal.side_effect = ValueError("Invalid signal setting")
DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
assert "SIGSEGV signal handler registration failed. Not in the main thread" in caplog.text
class TestCaptureWithReraise:
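    """Tests for _capture_with_reraise, which records warnings while re-emitting them so outer filters (and pytest.warns) still observe them."""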
@staticmethod
def raise_warnings():
warnings.warn("Foo", UserWarning, stacklevel=2)
warnings.warn("Bar", UserWarning, stacklevel=2)
warnings.warn("Baz", UserWarning, stacklevel=2)
def test_capture_no_warnings(self):
with warnings.catch_warnings():
warnings.simplefilter("error")
with _capture_with_reraise() as cw:
pass
assert cw == []
def test_capture_warnings(self):
with pytest.warns(UserWarning, match="(Foo|Bar|Baz)") as ctx:
with _capture_with_reraise() as cw:
self.raise_warnings()
assert len(cw) == 3
assert len(ctx.list) == 3
def test_capture_warnings_with_parent_error_filter(self):
with warnings.catch_warnings(record=True) as records:
warnings.filterwarnings("error", message="Bar")
with _capture_with_reraise() as cw:
with pytest.raises(UserWarning, match="Bar"):
self.raise_warnings()
assert len(cw) == 1
assert len(records) == 1
def test_capture_warnings_with_parent_ignore_filter(self):
with warnings.catch_warnings(record=True) as records:
warnings.filterwarnings("ignore", message="Baz")
with _capture_with_reraise() as cw:
self.raise_warnings()
assert len(cw) == 2
assert len(records) == 2
def test_capture_warnings_with_filters(self):
with warnings.catch_warnings(record=True) as records:
with _capture_with_reraise() as cw:
warnings.filterwarnings("ignore", message="Foo")
self.raise_warnings()
assert len(cw) == 2
assert len(records) == 2
def test_capture_warnings_with_error_filters(self):
with warnings.catch_warnings(record=True) as records:
with _capture_with_reraise() as cw:
warnings.filterwarnings("error", message="Bar")
with pytest.raises(UserWarning, match="Bar"):
self.raise_warnings()
assert len(cw) == 1
assert len(records) == 1