blob: c1ef3b72c92f4c1ae5a1d87aebb99a57ebfed34d [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, NoReturn, Protocol
import pytest
from tests_common.test_utils.config import conf_vars
# Register the shared pytest plugin used across Airflow test suites.
pytest_plugins = "tests_common.pytest_plugin"

# Task SDK does not need access to the Airflow database
os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true"
# NOTE(review): presumably signals that Airflow is being imported as a library
# (skipping application-level setup) — confirm against airflow.settings
os.environ["_AIRFLOW__AS_LIBRARY"] = "true"
if TYPE_CHECKING:
from collections.abc import Sequence
from datetime import datetime
from structlog.typing import EventDict, WrappedLogger
from airflow.sdk.api.datamodels._generated import AssetEventDagRunReference, TIRunContext
@pytest.hookimpl()
def pytest_addhooks(pluginmanager: pytest.PytestPluginManager):
    """Block the pytest-rerunfailures plugin before it can configure itself.

    Python 3.12 starts warning about mixing os.fork + Threads, and the
    pytest-rerunfailures plugin uses threads internally. Since this is new code,
    and it should be flake free, we disable the re-run failures plugin early
    (so that it doesn't run its pytest_configure, which is where the thread
    starts up if xdist is discovered).
    """
    pluginmanager.set_blocked("rerunfailures")
@pytest.hookimpl(tryfirst=True)
def pytest_configure(config: pytest.Config) -> None:
    """Apply Task SDK specific ini values and warm up the policy plugin manager."""
    config.inicfg["airflow_deprecations_ignore"] = []

    # Always skip looking for tests in these folders, and register our marker.
    for ini_name, ini_value in (
        ("norecursedirs", "tests/test_dags"),
        ("markers", "log_level: "),
    ):
        config.addinivalue_line(ini_name, ini_value)

    import airflow.settings

    airflow.settings.get_policy_plugin_manager()
@pytest.fixture(scope="session", autouse=True)
def _init_log():
    """Configure Task SDK logging once for the entire test session."""
    from airflow.sdk.log import configure_logging

    configure_logging()
@pytest.hookimpl(tryfirst=True)
def pytest_runtest_setup(item):
    """Fail fast if a Task SDK test was (incorrectly) marked as a database test."""
    db_marker = next(item.iter_markers(name="db_test"), None)
    if db_marker is not None:
        pytest.fail("Task SDK tests must not use database")
class LogCapture:
    """Structlog processor that records events instead of rendering them.

    Similar to ``structlog.typing.LogCapture``, but it does not inject a
    ``log_level`` key into the event dict; instead the method name is stored
    under ``_log_level`` only when no ``level`` key is already present.
    """

    # Captured events, in emission order (dicts, or raw bytes payloads).
    entries: list[EventDict | bytes]

    def __init__(self) -> None:
        self.entries = []

    def __call__(self, _: WrappedLogger, method_name: str, event: EventDict | bytes) -> NoReturn:
        from structlog.exceptions import DropEvent

        # Preserve the log method name for dict events that lack a level key.
        if isinstance(event, dict) and "level" not in event:
            event["_log_level"] = method_name

        self.entries.append(event)
        # Stop the processor chain — nothing should be rendered downstream.
        raise DropEvent
@pytest.fixture
def test_dags_dir():
    """Directory holding the DAG files used by Task SDK tests."""
    return Path(__file__).parent / "task_sdk" / "dags"
@pytest.fixture
def captured_logs(request, monkeypatch):
    """Capture structlog output emitted during a test as a list of entries.

    Yields the (live) list of captured events. The log level — and optionally
    the output type — can be supplied via fixture parametrization or the
    ``log_level`` marker; defaults to ``logging.INFO`` with ``"dict"`` output.
    """
    import structlog
    import structlog.processors

    from airflow.sdk._shared.logging.structlog import PER_LOGGER_LEVELS
    from airflow.sdk.log import configure_logging, reset_logging

    # Use our real log config
    reset_logging()
    configure_logging(json_output=False, colored_console_log=False)

    # Get log level from test parameter, which can either be a single log level or a
    # tuple of log level and desired output type, defaulting to INFO if not provided
    log_level = logging.INFO
    output = "dict"
    param = getattr(request, "param", None)
    if not param:
        # Fall back to the `log_level` marker when no fixture param was given.
        mark = next(request.node.iter_markers(name="log_level"), None)
        param = mark.args[0] if mark is not None else None
    if isinstance(param, int):
        log_level = param
    elif isinstance(param, tuple):
        log_level = param[0]
        output = param[1]

    # Apply the requested level to the root ("") logger entry.
    monkeypatch.setitem(PER_LOGGER_LEVELS, "", log_level)

    cur_processors = structlog.get_config()["processors"]
    processors = cur_processors.copy()
    if not any(isinstance(proc, structlog.processors.MaybeTimeStamper) for proc in processors):
        timestamper = structlog.processors.MaybeTimeStamper(fmt="iso")
        processors.append(timestamper)
    if output == "dict":
        # We need to remove the last processor (the one that turns the event
        # dict into text), as we want the raw event dict for tests
        proc = processors.pop()
        assert isinstance(proc, (structlog.dev.ConsoleRenderer, structlog.processors.JSONRenderer)), (
            "Pre-condition"
        )
    try:
        cap = LogCapture()
        processors.append(cap)
        structlog.configure(processors=processors)

        # Also attach the secrets masker so captured output matches production.
        task_logger = logging.getLogger("airflow.task")
        from airflow.sdk._shared.secrets_masker import SecretsMasker

        task_logger.addFilter(SecretsMasker())

        yield cap.entries
    finally:
        # Restore the original processor chain regardless of test outcome.
        structlog.configure(processors=cur_processors)
@pytest.fixture(autouse=True, scope="session")
def _disable_ol_plugin():
    """Disable SDK plugin loading for the whole test session.

    The OpenLineage plugin imports setproctitle, and that now causes (C) level
    thread calls, which on Py 3.12+ issues a warning when os.fork happens. So
    for this plugin we disable it. (We load plugins when setting the
    priority_weight field.)
    """
    import airflow.sdk.plugins_manager

    old = airflow.sdk.plugins_manager._get_plugins
    airflow.sdk.plugins_manager._get_plugins = lambda: ([], {})
    try:
        yield
    finally:
        # Restore even if an exception is thrown into the generator, so the
        # patched stub cannot leak into the rest of the session.
        airflow.sdk.plugins_manager._get_plugins = old
@pytest.fixture(autouse=True)
def _cleanup_async_resources(request):
    """
    Clean up async resources that can cause Python 3.12 fork warnings.

    Problem: asgiref.sync.sync_to_async (used in _async_get_connection) creates
    ThreadPoolExecutors that persist between tests. When supervisor.py calls
    os.fork() in subsequent tests, Python 3.12+ warns about forking a
    multi-threaded process.

    Solution: Clean up asgiref's ThreadPoolExecutors after async tests to ensure
    subsequent tests start with a clean thread environment.
    """
    yield

    # Only clean up after async tests to avoid unnecessary overhead.
    if "asyncio" not in request.keywords:
        return

    try:
        from asgiref.sync import SyncToAsync

        # SyncToAsync maintains a class-level executor for performance; shut it
        # down so os.fork() in later tests doesn't see its threads.
        main_executor = getattr(SyncToAsync, "single_thread_executor", None)
        if main_executor:
            if not main_executor._shutdown:
                main_executor.shutdown(wait=True)
            SyncToAsync.single_thread_executor = None

        # SyncToAsync also keeps a WeakKeyDictionary of context-specific
        # executors; drain and clear those too for a complete thread cleanup.
        context_executors = getattr(SyncToAsync, "context_to_thread_executor", None)
        if context_executors is not None:
            for pool in list(context_executors.values()):
                if hasattr(pool, "shutdown") and not getattr(pool, "_shutdown", True):
                    pool.shutdown(wait=True)
            context_executors.clear()
    except (ImportError, AttributeError):
        # If asgiref structure changes, fail gracefully.
        pass
class MakeTIContextCallable(Protocol):
    """Signature of the factory returned by the ``make_ti_context`` fixture."""

    def __call__(
        self,
        dag_id: str = ...,
        run_id: str = ...,
        logical_date: str | datetime = ...,
        data_interval_start: str | datetime = ...,
        data_interval_end: str | datetime = ...,
        clear_number: int = ...,
        start_date: str | datetime = ...,
        run_after: str | datetime = ...,
        run_type: str = ...,
        task_reschedule_count: int = ...,
        conf: dict[str, Any] | None = ...,
        should_retry: bool = ...,
        max_tries: int = ...,
        consumed_asset_events: Sequence[AssetEventDagRunReference] = ...,
    ) -> TIRunContext: ...
class MakeTIContextDictCallable(Protocol):
    """Signature of the factory returned by the ``make_ti_context_dict`` fixture."""

    def __call__(
        self,
        dag_id: str = ...,
        run_id: str = ...,
        # Widened to match the actual implementation, which also accepts datetime.
        logical_date: str | datetime = ...,
        data_interval_start: str | datetime = ...,
        data_interval_end: str | datetime = ...,
        clear_number: int = ...,
        start_date: str | datetime = ...,
        run_after: str | datetime = ...,
        run_type: str = ...,
        task_reschedule_count: int = ...,
        conf: dict[str, Any] | None = None,
        consumed_asset_events: Sequence[AssetEventDagRunReference] = ...,
    ) -> dict[str, Any]: ...
@pytest.fixture
def make_ti_context() -> MakeTIContextCallable:
    """Factory for creating TIRunContext objects."""
    from airflow.sdk import DagRunState
    from airflow.sdk.api.datamodels._generated import DagRun, TIRunContext

    def _make_context(
        dag_id: str = "test_dag",
        run_id: str = "test_run",
        logical_date: str | datetime = "2024-12-01T01:00:00Z",
        data_interval_start: str | datetime = "2024-12-01T00:00:00Z",
        data_interval_end: str | datetime = "2024-12-01T01:00:00Z",
        clear_number: int = 0,
        start_date: str | datetime = "2024-12-01T01:00:00Z",
        run_after: str | datetime = "2024-12-01T01:00:00Z",
        run_type: str = "manual",
        task_reschedule_count: int = 0,
        conf: dict[str, Any] | None = None,
        should_retry: bool = False,
        max_tries: int = 0,
        consumed_asset_events: Sequence[AssetEventDagRunReference] = (),
    ) -> TIRunContext:
        # Build the nested DagRun first, then wrap it in the TIRunContext.
        dag_run = DagRun(
            dag_id=dag_id,
            run_id=run_id,
            logical_date=logical_date,  # type: ignore
            data_interval_start=data_interval_start,  # type: ignore
            data_interval_end=data_interval_end,  # type: ignore
            clear_number=clear_number,  # type: ignore
            start_date=start_date,  # type: ignore
            run_type=run_type,  # type: ignore
            run_after=run_after,  # type: ignore
            state=DagRunState.RUNNING,
            conf=conf,  # type: ignore
            consumed_asset_events=list(consumed_asset_events),
        )
        return TIRunContext(
            dag_run=dag_run,
            task_reschedule_count=task_reschedule_count,
            max_tries=max_tries,
            should_retry=should_retry,
        )

    return _make_context
@pytest.fixture
def make_ti_context_dict(make_ti_context: MakeTIContextCallable) -> MakeTIContextDictCallable:
    """Factory for creating context dictionaries suited for API Server response."""

    def _make_context_dict(
        dag_id: str = "test_dag",
        run_id: str = "test_run",
        logical_date: str | datetime = "2024-12-01T00:00:00Z",
        data_interval_start: str | datetime = "2024-12-01T00:00:00Z",
        data_interval_end: str | datetime = "2024-12-01T01:00:00Z",
        clear_number: int = 0,
        start_date: str | datetime = "2024-12-01T00:00:00Z",
        run_after: str | datetime = "2024-12-01T00:00:00Z",
        run_type: str = "manual",
        task_reschedule_count: int = 0,
        conf=None,
        consumed_asset_events: Sequence[AssetEventDagRunReference] = (),
    ) -> dict[str, Any]:
        # Delegate to the TIRunContext factory, then serialize it to a plain
        # JSON-compatible dict, omitting anything that wasn't explicitly set.
        factory_kwargs = dict(
            dag_id=dag_id,
            run_id=run_id,
            logical_date=logical_date,
            data_interval_start=data_interval_start,
            data_interval_end=data_interval_end,
            clear_number=clear_number,
            start_date=start_date,
            run_after=run_after,
            run_type=run_type,
            conf=conf,
            task_reschedule_count=task_reschedule_count,
            consumed_asset_events=consumed_asset_events,
        )
        return make_ti_context(**factory_kwargs).model_dump(exclude_unset=True, mode="json")

    return _make_context_dict
@pytest.fixture(scope="class", autouse=True)
def allow_test_classes_deserialization():
    """
    Allow test classes and airflow SDK classes to be deserialized. In airflow-core tests, this is provided by
    unit_tests.cfg which sets allowed_deserialization_classes = airflow.* tests.*

    SDK tests may not inherit that configuration, so we explicitly allow airflow.sdk.* and tests.* here.
    """
    # Space-separated glob patterns, mirroring the airflow-core config format.
    with conf_vars({("core", "allowed_deserialization_classes"): "airflow.sdk.* tests.*"}):
        yield