blob: da6600a6fda5d1f1c6f6a883c1c635062c589a01 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import json
import logging
import sys
import warnings
from unittest.mock import patch
import pytest
from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.sdk._shared.providers_discovery import (
HookClassProvider,
LazyDictWithCache,
ProviderInfo,
)
from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime
from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker, skip_if_not_on_main
from tests_common.test_utils.paths import AIRFLOW_ROOT_PATH
PY313 = sys.version_info >= (3, 13)
def test_cleanup_providers_manager_runtime(cleanup_providers_manager):
"""Check the cleanup provider manager functionality."""
provider_manager = ProvidersManagerTaskRuntime()
# Check by type name since symlinks create different module paths
assert type(provider_manager.hooks).__name__ == "LazyDictWithCache"
hooks = provider_manager.hooks
ProvidersManagerTaskRuntime()._cleanup()
assert not len(hooks)
assert ProvidersManagerTaskRuntime().hooks is hooks
@skip_if_force_lowest_dependencies_marker
class TestProvidersManagerRuntime:
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog, cleanup_providers_manager_runtime):
self._caplog = caplog
def test_hooks_deprecation_warnings_generated(self):
providers_manager = ProvidersManagerTaskRuntime()
providers_manager._provider_dict["test-package"] = ProviderInfo(
version="0.0.1",
data={"hook-class-names": ["airflow.providers.sftp.hooks.sftp.SFTPHook"]},
)
with pytest.warns(expected_warning=DeprecationWarning, match="hook-class-names") as warning_records:
providers_manager._discover_hooks()
assert warning_records
def test_hooks_deprecation_warnings_not_generated(self):
with warnings.catch_warnings(record=True) as warning_records:
providers_manager = ProvidersManagerTaskRuntime()
providers_manager._provider_dict["apache-airflow-providers-sftp"] = ProviderInfo(
version="0.0.1",
data={
"hook-class-names": ["airflow.providers.sftp.hooks.sftp.SFTPHook"],
"connection-types": [
{
"hook-class-name": "airflow.providers.sftp.hooks.sftp.SFTPHook",
"connection-type": "sftp",
}
],
},
)
providers_manager._discover_hooks()
assert [w.message for w in warning_records if "hook-class-names" in str(w.message)] == []
def test_warning_logs_generated(self):
providers_manager = ProvidersManagerTaskRuntime()
providers_manager._hooks_lazy_dict = LazyDictWithCache()
with self._caplog.at_level(logging.WARNING):
providers_manager._provider_dict["apache-airflow-providers-sftp"] = ProviderInfo(
version="0.0.1",
data={
"hook-class-names": ["airflow.providers.sftp.hooks.sftp.SFTPHook"],
"connection-types": [
{
"hook-class-name": "airflow.providers.sftp.hooks.sftp.SFTPHook",
"connection-type": "wrong-connection-type",
}
],
},
)
providers_manager._discover_hooks()
_ = providers_manager._hooks_lazy_dict["wrong-connection-type"]
assert len(self._caplog.entries) == 1
assert "Inconsistency!" in self._caplog[0]["event"]
assert "sftp" not in providers_manager._hooks_lazy_dict
def test_warning_logs_not_generated(self):
with self._caplog.at_level(logging.WARNING):
providers_manager = ProvidersManagerTaskRuntime()
providers_manager._provider_dict["apache-airflow-providers-sftp"] = ProviderInfo(
version="0.0.1",
data={
"hook-class-names": ["airflow.providers.sftp.hooks.sftp.SFTPHook"],
"connection-types": [
{
"hook-class-name": "airflow.providers.sftp.hooks.sftp.SFTPHook",
"connection-type": "sftp",
}
],
},
)
providers_manager._discover_hooks()
_ = providers_manager._hooks_lazy_dict["sftp"]
assert not self._caplog.records
assert "sftp" in providers_manager.hooks
def test_already_registered_conn_type_in_provide(self):
with self._caplog.at_level(logging.WARNING):
providers_manager = ProvidersManagerTaskRuntime()
providers_manager._provider_dict["apache-airflow-providers-dummy"] = ProviderInfo(
version="0.0.1",
data={
"connection-types": [
{
"hook-class-name": "airflow.providers.dummy.hooks.dummy.DummyHook",
"connection-type": "dummy",
},
{
"hook-class-name": "airflow.providers.dummy.hooks.dummy.DummyHook2",
"connection-type": "dummy",
},
],
},
)
providers_manager._discover_hooks()
_ = providers_manager._hooks_lazy_dict["dummy"]
assert len(self._caplog.records) == 1
msg = self._caplog.messages[0]
assert msg.startswith("The connection type 'dummy' is already registered")
assert (
"different class names: 'airflow.providers.dummy.hooks.dummy.DummyHook'"
" and 'airflow.providers.dummy.hooks.dummy.DummyHook2'."
) in msg
def test_hooks(self):
with warnings.catch_warnings(record=True) as warning_records:
with self._caplog.at_level(logging.WARNING):
provider_manager = ProvidersManagerTaskRuntime()
connections_list = list(provider_manager.hooks.keys())
assert len(connections_list) > 60
if len(self._caplog.records) != 0:
for record in self._caplog.records:
print(record.message, file=sys.stderr)
print(record.exc_info, file=sys.stderr)
raise AssertionError("There are warnings generated during hook imports. Please fix them")
assert [w.message for w in warning_records if "hook-class-names" in str(w.message)] == []
@skip_if_not_on_main
@pytest.mark.execution_timeout(150)
def test_hook_values(self):
provider_dependencies = json.loads(
(AIRFLOW_ROOT_PATH / "generated" / "provider_dependencies.json").read_text()
)
python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
excluded_providers: list[str] = []
for provider_name, provider_info in provider_dependencies.items():
if python_version in provider_info.get("excluded-python-versions", []):
excluded_providers.append(f"apache-airflow-providers-{provider_name.replace('.', '-')}")
with warnings.catch_warnings(record=True) as warning_records:
with self._caplog.at_level(logging.WARNING):
provider_manager = ProvidersManagerTaskRuntime()
connections_list = list(provider_manager.hooks.values())
assert len(connections_list) > 60
if len(self._caplog.records) != 0:
real_warning_count = 0
for record in self._caplog.entries:
# When there is error importing provider that is excluded the provider name is in the message
if any(excluded_provider in record["event"] for excluded_provider in excluded_providers):
continue
print(record["event"], file=sys.stderr)
print(record.get("exc_info"), file=sys.stderr)
real_warning_count += 1
if real_warning_count:
if PY313:
only_ydb_and_yandexcloud_warnings = True
for record in warning_records:
if "ydb" in str(record.message) or "yandexcloud" in str(record.message):
continue
only_ydb_and_yandexcloud_warnings = False
if only_ydb_and_yandexcloud_warnings:
print(
"Only warnings from ydb and yandexcloud providers are generated, "
"which is expected in Python 3.13+",
file=sys.stderr,
)
return
raise AssertionError("There are warnings generated during hook imports. Please fix them")
assert [w.message for w in warning_records if "hook-class-names" in str(w.message)] == []
@patch("airflow.sdk.providers_manager_runtime.import_string")
def test_optional_feature_no_warning(self, mock_importlib_import_string):
with self._caplog.at_level(logging.WARNING):
mock_importlib_import_string.side_effect = AirflowOptionalProviderFeatureException()
providers_manager = ProvidersManagerTaskRuntime()
providers_manager._hook_provider_dict["test_connection"] = HookClassProvider(
package_name="test_package", hook_class_name="HookClass"
)
providers_manager._import_hook(
hook_class_name=None, provider_info=None, package_name=None, connection_type="test_connection"
)
assert self._caplog.messages == []
@patch("airflow.sdk.providers_manager_runtime.import_string")
def test_optional_feature_debug(self, mock_importlib_import_string):
with self._caplog.at_level(logging.INFO):
mock_importlib_import_string.side_effect = AirflowOptionalProviderFeatureException()
providers_manager = ProvidersManagerTaskRuntime()
providers_manager._hook_provider_dict["test_connection"] = HookClassProvider(
package_name="test_package", hook_class_name="HookClass"
)
providers_manager._import_hook(
hook_class_name=None, provider_info=None, package_name=None, connection_type="test_connection"
)
assert self._caplog.messages == [
"Optional provider feature disabled when importing 'HookClass' from 'test_package' package"
]