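plugins: load plugins from providers (#32692)

Providers can now declare plugins in their provider metadata under a
"plugins" list, where each entry carries a "name" and a "plugin-class".
ProvidersManager collects these entries as PluginInfo tuples (name,
plugin_class, provider_name) and exposes them through a new `plugins`
property, sorted by plugin class.

plugins_manager gains load_providers_plugins(), which imports every
declared plugin class, skips classes that are not valid plugins with a
warning, and logs import failures instead of raising. It runs after the
plugin-directory and entry-point loaders, and only when
LAZY_LOAD_PROVIDERS is disabled.

register_plugin() now tracks registered plugin names and ignores a
plugin whose name was already registered, so the same plugin is not
loaded twice when more than one loader finds it.

In airflow/__init__.py the eager plugin loading block is moved below
the eager provider initialization block.

Signed-off-by: Maciej Obuchowski <obuchowski.maciej@gmail.com>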
diff --git a/airflow/__init__.py b/airflow/__init__.py
index d330b93..2e8e088 100644
--- a/airflow/__init__.py
+++ b/airflow/__init__.py
@@ -106,11 +106,6 @@
return val
-if not settings.LAZY_LOAD_PLUGINS:
- from airflow import plugins_manager
-
- plugins_manager.ensure_plugins_loaded()
-
if not settings.LAZY_LOAD_PROVIDERS:
from airflow import providers_manager
@@ -118,6 +113,10 @@
manager.initialize_providers_list()
manager.initialize_providers_hooks()
manager.initialize_providers_extra_links()
+if not settings.LAZY_LOAD_PLUGINS:
+ from airflow import plugins_manager
+
+ plugins_manager.ensure_plugins_loaded()
# This is never executed, but tricks static analyzers (PyDev, PyCharm,)
diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py
index 946f064..17acdd2 100644
--- a/airflow/plugins_manager.py
+++ b/airflow/plugins_manager.py
@@ -38,7 +38,7 @@
from airflow import settings
from airflow.utils.entry_points import entry_points_with_dist
from airflow.utils.file import find_path_from_directory
-from airflow.utils.module_loading import qualname
+from airflow.utils.module_loading import import_string, qualname
if TYPE_CHECKING:
from airflow.hooks.base import BaseHook
@@ -50,6 +50,7 @@
import_errors: dict[str, str] = {}
plugins: list[AirflowPlugin] | None = None
+loaded_plugins: set[str] = set()
# Plugin components to integrate as modules
registered_hooks: list[BaseHook] | None = None
@@ -205,10 +206,16 @@
def register_plugin(plugin_instance):
"""
Start plugin load and register it after success initialization.
+ If the plugin is already registered, do nothing.
:param plugin_instance: subclass of AirflowPlugin
"""
global plugins
+
+ if plugin_instance.name in loaded_plugins:
+ return
+
+ loaded_plugins.add(plugin_instance.name)
plugin_instance.on_load()
plugins.append(plugin_instance)
@@ -267,6 +274,26 @@
import_errors[file_path] = str(e)
+def load_providers_plugins():
+ from airflow.providers_manager import ProvidersManager
+
+ log.debug("Loading plugins from providers")
+ providers_manager = ProvidersManager()
+ providers_manager.initialize_providers_plugins()
+ for plugin in providers_manager.plugins:
+ log.debug("Importing plugin %s from class %s", plugin.name, plugin.plugin_class)
+
+ try:
+ plugin_instance = import_string(plugin.plugin_class)
+ if not is_valid_plugin(plugin_instance):
+ log.warning("Plugin %s is not a valid plugin", plugin.name)
+ continue
+ register_plugin(plugin_instance)
+ except ImportError:
+ log.exception("Failed to load plugin %s from class name %s", plugin.name, plugin.plugin_class)
+ continue
+
+
def make_module(name: str, objects: list[Any]):
"""Creates new module."""
if not objects:
@@ -306,6 +333,9 @@
load_plugins_from_plugin_directory()
load_entrypoint_plugins()
+ if not settings.LAZY_LOAD_PROVIDERS:
+ load_providers_plugins()
+
# We don't do anything with these for now, but we want to keep track of
# them so we can integrate them in to the UI's Connection screens
for plugin in plugins:
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index bda6df3..b7689a92d 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -217,6 +217,14 @@
integration_name: str
+class PluginInfo(NamedTuple):
+ """Plugin class, name and provider it comes from."""
+
+ name: str
+ plugin_class: str
+ provider_name: str
+
+
class HookInfo(NamedTuple):
"""Hook information."""
@@ -421,6 +429,8 @@
self._customized_form_fields_schema_validator = (
_create_customized_form_field_behaviours_schema_validator()
)
+ # Set of plugins contained in providers
+ self._plugins_set: set[PluginInfo] = set()
@provider_info_cache("list")
def initialize_providers_list(self):
@@ -516,6 +526,11 @@
self.initialize_providers_list()
self._discover_auth_backends()
+ @provider_info_cache("plugins")
+ def initialize_providers_plugins(self):
+ self.initialize_providers_list()
+ self._discover_plugins()
+
def _discover_all_providers_from_packages(self) -> None:
"""
Discover all providers by scanning packages installed.
@@ -1024,6 +1039,21 @@
if provider.data.get("config"):
self._provider_configs[provider_package] = provider.data.get("config")
+ def _discover_plugins(self) -> None:
+ """Retrieve all plugins defined in the providers."""
+ for provider_package, provider in self._provider_dict.items():
+ for plugin_dict in provider.data.get("plugins", ()):
+ if not _correctness_check(provider_package, plugin_dict["plugin-class"], provider):
+ log.warning("Plugin not loaded due to above correctness check problem.")
+ continue
+ self._plugins_set.add(
+ PluginInfo(
+ name=plugin_dict["name"],
+ plugin_class=plugin_dict["plugin-class"],
+ provider_name=provider_package,
+ )
+ )
+
@provider_info_cache("triggers")
def initialize_providers_triggers(self):
"""Initialization of providers triggers."""
@@ -1063,6 +1093,12 @@
return self._hooks_lazy_dict
@property
+ def plugins(self) -> list[PluginInfo]:
+ """Returns information about plugins available in providers."""
+ self.initialize_providers_plugins()
+ return sorted(self._plugins_set, key=lambda x: x.plugin_class)
+
+ @property
def taskflow_decorators(self) -> dict[str, TaskDecorator]:
self.initialize_providers_taskflow_decorator()
return self._taskflow_decorators
diff --git a/tests/always/test_providers_manager.py b/tests/always/test_providers_manager.py
index b99dbcb..7e05d1c 100644
--- a/tests/always/test_providers_manager.py
+++ b/tests/always/test_providers_manager.py
@@ -28,7 +28,13 @@
from wtforms import BooleanField, Field, StringField
from airflow.exceptions import AirflowOptionalProviderFeatureException
-from airflow.providers_manager import HookClassProvider, LazyDictWithCache, ProviderInfo, ProvidersManager
+from airflow.providers_manager import (
+ HookClassProvider,
+ LazyDictWithCache,
+ PluginInfo,
+ ProviderInfo,
+ ProvidersManager,
+)
class TestProviderManager:
@@ -157,6 +163,28 @@
" and 'airflow.providers.dummy.hooks.dummy.DummyHook2'."
) in self._caplog.records[0].message
+ def test_providers_manager_register_plugins(self):
+ providers_manager = ProvidersManager()
+ providers_manager._provider_dict["apache-airflow-providers-apache-hive"] = ProviderInfo(
+ version="0.0.1",
+ data={
+ "plugins": [
+ {
+ "name": "plugin1",
+ "plugin-class": "airflow.providers.apache.hive.plugins.hive.HivePlugin",
+ }
+ ]
+ },
+ package_or_source="package",
+ )
+ providers_manager._discover_plugins()
+ assert len(providers_manager._plugins_set) == 1
+ assert providers_manager._plugins_set.pop() == PluginInfo(
+ name="plugin1",
+ plugin_class="airflow.providers.apache.hive.plugins.hive.HivePlugin",
+ provider_name="apache-airflow-providers-apache-hive",
+ )
+
def test_hooks(self):
with pytest.warns(expected_warning=None) as warning_records:
with self._caplog.at_level(logging.WARNING):
diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py
index 019e2a6..bf74863 100644
--- a/tests/plugins/test_plugins_manager.py
+++ b/tests/plugins/test_plugins_manager.py
@@ -160,6 +160,13 @@
class TestPluginsManager:
+ @pytest.fixture(autouse=True, scope="function")
+ def clean_plugins(self):
+ from airflow import plugins_manager
+
+ plugins_manager.loaded_plugins = set()
+ plugins_manager.plugins = []
+
def test_no_log_when_no_plugins(self, caplog):
with mock_plugin_manager(plugins=[]):
@@ -378,6 +385,32 @@
assert get_listener_manager().has_listeners
assert get_listener_manager().pm.get_plugins().pop().__name__ == "tests.listeners.empty_listener"
+ def test_should_import_plugin_from_providers(self):
+ from airflow import plugins_manager
+
+ with mock.patch("airflow.plugins_manager.plugins", []):
+ assert len(plugins_manager.plugins) == 0
+ plugins_manager.load_providers_plugins()
+ assert len(plugins_manager.plugins) >= 2
+
+ def test_does_not_double_import_entrypoint_provider_plugins(self):
+ from airflow import plugins_manager
+
+ mock_entrypoint = mock.Mock()
+ mock_entrypoint.name = "test-entrypoint-plugin"
+ mock_entrypoint.module = "module_name_plugin"
+
+ mock_dist = mock.Mock()
+ mock_dist.metadata = {"Name": "test-entrypoint-plugin"}
+ mock_dist.version = "1.0.0"
+ mock_dist.entry_points = [mock_entrypoint]
+
+ with mock.patch("airflow.plugins_manager.plugins", []):
+ assert len(plugins_manager.plugins) == 0
+ plugins_manager.load_entrypoint_plugins()
+ plugins_manager.load_providers_plugins()
+ assert len(plugins_manager.plugins) == 2
+
class TestPluginsDirectorySource:
def test_should_return_correct_path_name(self):