blob: 39c582249088d38618b5b53f013bb238b8af6fe4 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Manages all providers."""
from __future__ import annotations
import fnmatch
import functools
import json
import logging
import os
import sys
import warnings
from collections import OrderedDict
from dataclasses import dataclass
from functools import wraps
from time import perf_counter
from typing import TYPE_CHECKING, Any, Callable, MutableMapping, NamedTuple, TypeVar, cast
from packaging.utils import canonicalize_name
from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.typing_compat import Literal
from airflow.utils import yaml
from airflow.utils.entry_points import entry_points_with_dist
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.module_loading import import_string
log = logging.getLogger(__name__)
if sys.version_info >= (3, 9):
from importlib.resources import files as resource_files
else:
from importlib_resources import files as resource_files
# Minimum supported versions for specific provider packages. Providers older than
# the listed version trigger a compatibility warning during discovery
# (checked by ProvidersManager._verify_all_providers_all_compatible).
MIN_PROVIDER_VERSIONS = {
    "apache-airflow-providers-celery": "2.1.0",
}
def _ensure_prefix_for_placeholders(field_behaviors: dict[str, Any], conn_type: str):
"""
Verify the correct placeholder prefix.
If the given field_behaviors dict contains a placeholder's node, and there
are placeholders for extra fields (i.e. anything other than the built-in conn
attrs), and if those extra fields are unprefixed, then add the prefix.
The reason we need to do this is, all custom conn fields live in the same dictionary,
so we need to namespace them with a prefix internally. But for user convenience,
and consistency between the `get_ui_field_behaviour` method and the extra dict itself,
we allow users to supply the unprefixed name.
"""
conn_attrs = {"host", "schema", "login", "password", "port", "extra"}
def ensure_prefix(field):
if field not in conn_attrs and not field.startswith("extra__"):
return f"extra__{conn_type}__{field}"
else:
return field
if "placeholders" in field_behaviors:
placeholders = field_behaviors["placeholders"]
field_behaviors["placeholders"] = {ensure_prefix(k): v for k, v in placeholders.items()}
return field_behaviors
if TYPE_CHECKING:
from airflow.decorators.base import TaskDecorator
from airflow.hooks.base import BaseHook
class LazyDictWithCache(MutableMapping):
    """
    Lazy-loaded cached dictionary.

    When a stored value is callable, the first lookup of its key invokes the
    callable; truthy results replace the stored callable so the work happens
    at most once per key. The resolver may itself return a callable, which is
    then served as-is on subsequent lookups.
    """

    __slots__ = ["_resolved", "_raw_dict"]

    def __init__(self, *args, **kw):
        self._resolved = set()
        self._raw_dict = dict(*args, **kw)

    def __setitem__(self, key, value):
        self._raw_dict[key] = value

    def __getitem__(self, key):
        value = self._raw_dict[key]
        if callable(value) and key not in self._resolved:
            # Resolve exactly once per key, even if the result is falsy.
            value = value()
            self._resolved.add(key)
            if value:
                # Only truthy results replace the callable in the backing dict.
                self._raw_dict[key] = value
        return value

    def __delitem__(self, key):
        del self._raw_dict[key]
        # Forget the resolution marker so a re-inserted callable resolves again.
        self._resolved.discard(key)

    def __iter__(self):
        return iter(self._raw_dict)

    def __len__(self):
        return len(self._raw_dict)

    def __contains__(self, key):
        return key in self._raw_dict
def _create_provider_info_schema_validator():
    """Creates JSON schema validator from the provider_info.schema.json."""
    import jsonschema

    # Load the runtime provider-info schema shipped inside the airflow package.
    schema_path = resource_files("airflow").joinpath("provider_info.schema.json")
    with schema_path.open("rb") as schema_file:
        schema = json.load(schema_file)
    validator_cls = jsonschema.validators.validator_for(schema)
    return validator_cls(schema)
def _create_customized_form_field_behaviours_schema_validator():
    """Creates JSON schema validator from the customized_form_field_behaviours.schema.json."""
    import jsonschema

    # Load the field-behaviour customization schema shipped inside the airflow package.
    schema_path = resource_files("airflow").joinpath("customized_form_field_behaviours.schema.json")
    with schema_path.open("rb") as schema_file:
        schema = json.load(schema_file)
    validator_cls = jsonschema.validators.validator_for(schema)
    return validator_cls(schema)
def _check_builtin_provider_prefix(provider_package: str, class_name: str) -> bool:
if provider_package.startswith("apache-airflow"):
provider_path = provider_package[len("apache-") :].replace("-", ".")
if not class_name.startswith(provider_path):
log.warning(
"Coherence check failed when importing '%s' from '%s' package. It should start with '%s'",
class_name,
provider_package,
provider_path,
)
return False
return True
@dataclass
class ProviderInfo:
    """
    Provider information.

    :param version: version string
    :param data: dictionary with information about the provider
    :param package_or_source: whether the provider comes from local source files
        ("source") or an installed PyPI package ("package"). Import errors are
        suppressed for providers loaded from sources.
    """

    version: str
    data: dict
    package_or_source: Literal["source"] | Literal["package"]

    def __post_init__(self):
        if self.package_or_source not in ("source", "package"):
            raise ValueError(
                f"Received {self.package_or_source!r} for `package_or_source`. "
                "Must be either 'package' or 'source'."
            )
        # Convenience flag used by the import coherence checks.
        self.is_source = self.package_or_source == "source"
class HookClassProvider(NamedTuple):
    """Hook class and Provider it comes from."""

    # Fully-qualified dotted path of the hook class.
    hook_class_name: str
    # Name of the provider package that registered the hook.
    package_name: str
class TriggerInfo(NamedTuple):
    """Trigger class and provider it comes from."""

    # Fully-qualified dotted path of the trigger class.
    trigger_class_name: str
    # Name of the provider package that registered the trigger.
    package_name: str
    # Integration the trigger belongs to, as declared in provider info.
    integration_name: str
class HookInfo(NamedTuple):
    """Hook information."""

    # Fully-qualified dotted path of the hook class.
    hook_class_name: str
    # Value of the hook class's `conn_name_attr` attribute.
    connection_id_attribute_name: str
    # Provider package that ships the hook.
    package_name: str
    # Value of the hook class's `hook_name` attribute.
    hook_name: str
    # Connection type handled by the hook.
    connection_type: str
    # Whether the hook class defines a `test_connection` attribute.
    connection_testable: bool
class ConnectionFormWidgetInfo(NamedTuple):
    """Connection Form Widget information."""

    # Short class name (``__name__``) of the hook that contributed the widget.
    hook_class_name: str
    # Provider package that ships the hook.
    package_name: str
    # The widget/field object itself.
    field: Any
    # Field identifier exactly as declared by the hook (possibly unprefixed).
    field_name: str
# Generic callable type used by the provider_info_cache decorator factory.
T = TypeVar("T", bound=Callable)

# NOTE(review): this duplicates the module-level `log` defined above — presumably
# both refer to the same underlying logger for __name__; confirm before unifying.
logger = logging.getLogger(__name__)
def log_debug_import_from_sources(class_name, e, provider_package):
    """Log debug imports from sources.

    Used when a provider installed from sources fails to import a class -
    this is expected (not all dependencies are installed) so only debug level.
    """
    # NOTE(review): the message text is identical to log_optional_feature_disabled's
    # debug line — confirm whether it should mention "sources" instead.
    log.debug(
        "Optional feature disabled on exception when importing '%s' from '%s' package",
        class_name,
        provider_package,
        exc_info=e,
    )
def log_optional_feature_disabled(class_name, e, provider_package):
    """Log optional feature disabled.

    Full exception details go to debug level; a short info message tells the
    user that the optional feature is unavailable.
    """
    log.debug(
        "Optional feature disabled on exception when importing '%s' from '%s' package",
        class_name,
        provider_package,
        exc_info=e,
    )
    log.info(
        "Optional provider feature disabled when importing '%s' from '%s' package",
        class_name,
        provider_package,
    )
def log_import_warning(class_name, e, provider_package):
    """Log import warning.

    Used for unexpected import failures - surfaces the full exception at
    warning level so the problem is visible in logs.
    """
    log.warning(
        "Exception when importing '%s' from '%s' package",
        class_name,
        provider_package,
        exc_info=e,
    )
# This is a temporary measure until all community providers will add AirflowOptionalProviderFeatureException
# where they have optional features. We are going to add tests in our CI to catch all such cases and will
# fix them, but until now all "known unhandled optional feature errors" from community providers
# should be added here
# Each entry is (provider package name, substring expected in the ImportError message).
KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS = [("apache-airflow-providers-google", "No module named 'paramiko'")]
def _sanity_check(
    provider_package: str, class_name: str, provider_info: ProviderInfo
) -> type[BaseHook] | None:
    """
    Performs coherence check on provider classes.

    For apache-airflow providers it verifies the class lives under the expected
    package path. For all providers it attempts the import, logging any problem
    at an appropriate level instead of raising.

    :param provider_package: name of the provider package
    :param class_name: name of the class to import
    :param provider_info: provider metadata (used to detect source installs)
    :return: the imported class when everything is OK, None otherwise.
    """
    if not _check_builtin_provider_prefix(provider_package, class_name):
        return None
    try:
        return import_string(class_name)
    except AirflowOptionalProviderFeatureException as e:
        # Expected when only some of the provider's classes are usable: log
        # details at debug level and print an info message for the user.
        log_optional_feature_disabled(class_name, e, provider_package)
    except ImportError as e:
        if provider_info.is_source:
            # Providers imported from sources routinely miss dependencies
            # (we always have all providers from sources until providers move
            # to a separate repo), so only debug-log these failures.
            log_debug_import_from_sources(class_name, e, provider_package)
        elif "No module named 'airflow.providers." in e.msg:
            # Another provider is missing. This can only happen for an optional
            # cross-provider feature, so treat it as a disabled feature.
            log_optional_feature_disabled(class_name, e, provider_package)
        elif any(
            package == provider_package and substring in e.msg
            for package, substring in KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS
        ):
            # Known community providers that have not yet adopted
            # AirflowOptionalProviderFeatureException for optional features.
            log_optional_feature_disabled(class_name, e, provider_package)
        else:
            # No idea what went wrong - surface a warning.
            log_import_warning(class_name, e, provider_package)
    except Exception as e:
        log_import_warning(class_name, e, provider_package)
    return None
# We want to have better control over initialization of parameters and be able to debug and test it
# So we add our own decorator
def provider_info_cache(cache_name: str) -> Callable[[T], T]:
    """
    Decorate and cache provider info.

    Decorator factory producing a decorator that runs the wrapped initializer
    at most once per ProvidersManager instance and logs timing at debug level.

    :param cache_name: Name of the cache
    """

    def provider_info_cache_decorator(func: T):
        @wraps(func)
        def wrapped_function(*args, **kwargs):
            # The wrapped function is a ProvidersManager method, so args[0] is `self`.
            providers_manager_instance = args[0]
            if cache_name in providers_manager_instance._initialized_cache:
                return
            logger.debug("Initializing Providers Manager[%s]", cache_name)
            started_at = perf_counter()
            func(*args, **kwargs)
            providers_manager_instance._initialized_cache[cache_name] = True
            logger.debug(
                "Initialization of Providers Manager[%s] took %.2f seconds",
                cache_name,
                perf_counter() - started_at,
            )

        return cast(T, wrapped_function)

    return provider_info_cache_decorator
class ProvidersManager(LoggingMixin):
    """
    Manages all provider packages.

    This is a Singleton class. The first time it is
    instantiated, it discovers all available providers in installed packages and
    local source folders (if airflow is run from sources).
    """

    # Shared singleton instance; populated by the first __new__ call.
    _instance = None
    resource_version = "0"

    def __new__(cls):
        # Classic singleton: create the shared instance once, then always return it.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance
def __init__(self):
    """Initializes the manager."""
    super().__init__()
    # Tracks which provider_info_cache sections have already been initialized.
    self._initialized_cache: dict[str, bool] = {}
    # Keeps dict of providers keyed by module name
    self._provider_dict: dict[str, ProviderInfo] = {}
    # Keeps dict of hooks keyed by connection type
    self._hooks_dict: dict[str, HookInfo] = {}
    # Taskflow decorators registered for lazy import, keyed by decorator name.
    self._taskflow_decorators: dict[str, Callable] = LazyDictWithCache()
    # keeps mapping between connection_types and hook class, package they come from
    self._hook_provider_dict: dict[str, HookClassProvider] = {}
    # Keeps dict of hooks keyed by connection type. They are lazy evaluated at access time
    self._hooks_lazy_dict: LazyDictWithCache[str, HookInfo | Callable] = LazyDictWithCache()
    # Keeps methods that should be used to add custom widgets tuple of keyed by name of the extra field
    self._connection_form_widgets: dict[str, ConnectionFormWidgetInfo] = {}
    # Customizations for javascript fields are kept here
    self._field_behaviours: dict[str, dict] = {}
    # Class/module names collected by the corresponding _discover_* methods.
    self._extra_link_class_name_set: set[str] = set()
    self._logging_class_name_set: set[str] = set()
    self._secrets_backend_class_name_set: set[str] = set()
    self._api_auth_backend_module_names: set[str] = set()
    self._trigger_info_set: set[TriggerInfo] = set()
    # JSON schema validators reused across all discovery/validation calls.
    self._provider_schema_validator = _create_provider_info_schema_validator()
    self._customized_form_fields_schema_validator = (
        _create_customized_form_field_behaviours_schema_validator()
    )
@provider_info_cache("list")
def initialize_providers_list(self):
    """Lazy initialization of providers list."""
    # Local source folders are loaded first. They should take precedence over the package ones for
    # Development purpose. In production provider.yaml files are not present in the 'airflow" directory
    # So there is no risk we are going to override package provider accidentally. This can only happen
    # in case of local development
    self._discover_all_airflow_builtin_providers_from_local_sources()
    self._discover_all_providers_from_packages()
    self._verify_all_providers_all_compatible()
    # Sort so downstream iteration order is deterministic.
    self._provider_dict = OrderedDict(sorted(self._provider_dict.items()))
def _verify_all_providers_all_compatible(self):
    """Warn about any discovered provider older than its minimum supported version."""
    from packaging import version as packaging_version

    for provider_id, info in self._provider_dict.items():
        min_version = MIN_PROVIDER_VERSIONS.get(provider_id)
        if not min_version:
            # No minimum declared for this provider - nothing to verify.
            continue
        if packaging_version.parse(info.version) < packaging_version.parse(min_version):
            log.warning(
                "The package %s is not compatible with this version of Airflow. "
                "The package has version %s but the minimum supported version "
                "of the package is %s",
                provider_id,
                info.version,
                min_version,
            )
@provider_info_cache("hooks")
def initialize_providers_hooks(self):
    """Lazy initialization of providers hooks."""
    self.initialize_providers_list()
    self._discover_hooks()
    # Sort for deterministic lookup/iteration order.
    self._hook_provider_dict = OrderedDict(sorted(self._hook_provider_dict.items()))
@provider_info_cache("taskflow_decorators")
def initialize_providers_taskflow_decorator(self):
    """Lazy initialization of providers taskflow decorators."""
    self.initialize_providers_list()
    self._discover_taskflow_decorators()
@provider_info_cache("extra_links")
def initialize_providers_extra_links(self):
    """Lazy initialization of providers extra links."""
    self.initialize_providers_list()
    self._discover_extra_links()
@provider_info_cache("logging")
def initialize_providers_logging(self):
    """Lazy initialization of providers logging information."""
    self.initialize_providers_list()
    self._discover_logging()
@provider_info_cache("secrets_backends")
def initialize_providers_secrets_backends(self):
    """Lazy initialization of providers secrets_backends information."""
    self.initialize_providers_list()
    self._discover_secrets_backends()
@provider_info_cache("auth_backends")
def initialize_providers_auth_backends(self):
    """Lazy initialization of providers API auth_backends information."""
    self.initialize_providers_list()
    self._discover_auth_backends()
def _discover_all_providers_from_packages(self) -> None:
    """
    Discover all providers by scanning packages installed.

    The list of providers should be returned via the 'apache_airflow_provider'
    entrypoint as a dictionary conforming to the 'airflow/provider_info.schema.json'
    schema. Note that the schema is different at runtime than provider.yaml.schema.json.
    The development version of provider schema is more strict and changes together with
    the code. The runtime version is more relaxed (allows for additional properties)
    and verifies only the subset of fields that are needed at runtime.
    """
    for entry_point, dist in entry_points_with_dist("apache_airflow_provider"):
        package_name = canonicalize_name(dist.metadata["name"])
        if package_name in self._provider_dict:
            # Already registered (e.g. from local sources) - skip loading entirely.
            continue
        log.debug("Loading %s from package %s", entry_point, package_name)
        version = dist.version
        provider_info = entry_point.load()()
        self._provider_schema_validator.validate(provider_info)
        provider_info_package_name = provider_info["package-name"]
        if package_name != provider_info_package_name:
            raise Exception(
                f"The package '{package_name}' from setuptools and "
                f"{provider_info_package_name} do not match. Please make sure they are aligned"
            )
        if package_name not in self._provider_dict:
            self._provider_dict[package_name] = ProviderInfo(version, provider_info, "package")
        else:
            # Defensive branch: normally unreachable because duplicates are skipped
            # by the `continue` above. Log message typo fixed ("registered from because").
            log.warning(
                "The provider for package '%s' could not be registered because providers for that "
                "package name have already been registered",
                package_name,
            )
def _discover_all_airflow_builtin_providers_from_local_sources(self) -> None:
    """
    Finds all built-in airflow providers if airflow is run from the local sources.

    It finds `provider.yaml` files for all such providers and registers the providers using those.

    This 'provider.yaml' scanning takes precedence over scanning packages installed
    in case you have both sources and packages installed, the providers will be loaded from
    the "airflow" sources rather than from the packages.
    """
    try:
        import airflow.providers
    except ImportError:
        log.info("You have no providers installed.")
        return
    try:
        seen: set[str] = set()
        for raw_path in airflow.providers.__path__:  # type: ignore[attr-defined]
            # __path__ may list the same directory twice under non-normalized
            # spellings (e.g. /repo/airflow/providers and /repo/./airflow/providers),
            # so normalize before de-duplicating.
            real_path = os.path.realpath(raw_path)
            if real_path not in seen:
                seen.add(real_path)
                self._add_provider_info_from_local_source_files_on_path(real_path)
    except Exception as e:
        log.warning("Error when loading 'provider.yaml' files from airflow sources: %s", e)
def _add_provider_info_from_local_source_files_on_path(self, path) -> None:
    """
    Finds all the provider.yaml files in the directory specified.

    :param path: path where to look for provider.yaml files
    """
    root_path = path
    for folder, subdirs, files in os.walk(path, topdown=True):
        for filename in fnmatch.filter(files, "provider.yaml"):
            # Derive the package name from the folder path relative to the root,
            # e.g. <root>/google -> apache-airflow-providers-google.
            package_name = "apache-airflow-providers" + folder[len(root_path) :].replace(os.sep, "-")
            self._add_provider_info_from_local_source_file(os.path.join(folder, filename), package_name)
            # Clearing subdirs in-place prunes os.walk so it does not descend
            # below a folder that already contains a provider.yaml.
            subdirs[:] = []
def _add_provider_info_from_local_source_file(self, path, package_name) -> None:
    """
    Parses found provider.yaml file and adds found provider to the dictionary.

    Any failure (unreadable file, schema violation, missing keys) is logged
    as a warning rather than propagated.

    :param path: full file path of the provider.yaml file
    :param package_name: name of the package
    """
    try:
        log.debug("Loading %s from %s", package_name, path)
        with open(path) as provider_yaml_file:
            provider_info = yaml.safe_load(provider_yaml_file)
        self._provider_schema_validator.validate(provider_info)
        # The first entry in "versions" is the current provider version.
        version = provider_info["versions"][0]
        if package_name in self._provider_dict:
            log.warning(
                "The providers for package '%s' could not be registered because providers for that "
                "package name have already been registered",
                package_name,
            )
        else:
            self._provider_dict[package_name] = ProviderInfo(version, provider_info, "source")
    except Exception as e:
        log.warning("Error when loading '%s'", path, exc_info=e)
def _discover_hooks_from_connection_types(
    self,
    hook_class_names_registered: set[str],
    already_registered_warning_connection_types: set[str],
    package_name: str,
    provider: ProviderInfo,
):
    """
    Discover hooks from the "connection-types" property.

    This is new, better method that replaces discovery from hook-class-names as it
    allows to lazy import individual Hook classes when they are accessed.
    The "connection-types" keeps information about both - connection type and class
    name so we can discover all connection-types without importing the classes.

    :param hook_class_names_registered: set of registered hook class names for this provider
    :param already_registered_warning_connection_types: set of connections for which warning should be
        printed in logs as they were already registered before
    :param package_name: name of the provider package being scanned
    :param provider: provider metadata holding the "connection-types" entries
    :return: True if the provider declares any "connection-types", False otherwise
    """
    provider_uses_connection_types = False
    connection_types = provider.data.get("connection-types")
    if connection_types:
        for connection_type_dict in connection_types:
            connection_type = connection_type_dict["connection-type"]
            hook_class_name = connection_type_dict["hook-class-name"]
            hook_class_names_registered.add(hook_class_name)
            already_registered = self._hook_provider_dict.get(connection_type)
            if already_registered:
                if already_registered.package_name != package_name:
                    # Cross-package conflict: remembered here, warned about once
                    # after the whole discovery finishes.
                    already_registered_warning_connection_types.add(connection_type)
                else:
                    # NOTE(review): this warns even when hook_class_name equals the
                    # registered class name - confirm whether a name comparison (as in
                    # _discover_hooks_from_hook_class_names) is intended here.
                    log.warning(
                        "The connection type '%s' is already registered in the"
                        " package '%s' with different class names: '%s' and '%s'. ",
                        connection_type,
                        package_name,
                        already_registered.hook_class_name,
                        hook_class_name,
                    )
            else:
                self._hook_provider_dict[connection_type] = HookClassProvider(
                    hook_class_name=hook_class_name, package_name=package_name
                )
                # Defer importing hook to access time by setting import hook method as dict value
                self._hooks_lazy_dict[connection_type] = functools.partial(
                    self._import_hook,
                    connection_type=connection_type,
                    provider_info=provider,
                )
        provider_uses_connection_types = True
    return provider_uses_connection_types
def _discover_hooks_from_hook_class_names(
    self,
    hook_class_names_registered: set[str],
    already_registered_warning_connection_types: set[str],
    package_name: str,
    provider: ProviderInfo,
    provider_uses_connection_types: bool,
):
    """
    Discover hooks from "hook-class-names' property.

    This property is deprecated but we should support it in Airflow 2.
    The hook-class-names array contained just Hook names without connection type,
    therefore we need to import all those classes immediately to know which connection types
    are supported. This makes it impossible to selectively only import those hooks that are used.

    :param hook_class_names_registered: hook class names already registered by the
        connection-types discovery (skipped here to avoid double registration)
    :param already_registered_warning_connection_types: list of connection hooks that we should warn
        about when finished discovery
    :param package_name: name of the provider package
    :param provider: class that keeps information about version and details of the provider
    :param provider_uses_connection_types: determines whether the provider uses "connection-types" new
        form of passing connection types
    """
    hook_class_names = provider.data.get("hook-class-names")
    if hook_class_names:
        for hook_class_name in hook_class_names:
            if hook_class_name in hook_class_names_registered:
                # Silently ignore the hook class - it's already marked for lazy-import by
                # connection-types discovery
                continue
            # Eager import: the connection type is only known after importing the class.
            hook_info = self._import_hook(
                connection_type=None,
                provider_info=provider,
                hook_class_name=hook_class_name,
                package_name=package_name,
            )
            if not hook_info:
                # Problem why importing class - we ignore it. Log is written at import time
                continue
            already_registered = self._hook_provider_dict.get(hook_info.connection_type)
            if already_registered:
                if already_registered.package_name != package_name:
                    # Cross-package conflict: warned about once, below.
                    already_registered_warning_connection_types.add(hook_info.connection_type)
                else:
                    if already_registered.hook_class_name != hook_class_name:
                        log.warning(
                            "The hook connection type '%s' is registered twice in the"
                            " package '%s' with different class names: '%s' and '%s'. "
                            " Please fix it!",
                            hook_info.connection_type,
                            package_name,
                            already_registered.hook_class_name,
                            hook_class_name,
                        )
            else:
                self._hook_provider_dict[hook_info.connection_type] = HookClassProvider(
                    hook_class_name=hook_class_name, package_name=package_name
                )
                # Already imported, so store the resolved HookInfo directly.
                self._hooks_lazy_dict[hook_info.connection_type] = hook_info
        if not provider_uses_connection_types:
            # Deprecation path: the provider declares only hook-class-names.
            warnings.warn(
                f"The provider {package_name} uses `hook-class-names` "
                "property in provider-info and has no `connection-types` one. "
                "The 'hook-class-names' property has been deprecated in favour "
                "of 'connection-types' in Airflow 2.2. Use **both** in case you want to "
                "have backwards compatibility with Airflow < 2.2",
                DeprecationWarning,
            )
    for already_registered_connection_type in already_registered_warning_connection_types:
        log.warning(
            "The connection_type '%s' has been already registered by provider '%s.'",
            already_registered_connection_type,
            self._hook_provider_dict[already_registered_connection_type].package_name,
        )
def _discover_hooks(self) -> None:
    """Retrieve all connections defined in the providers via Hooks."""
    for package_name, provider in self._provider_dict.items():
        duplicated_connection_types: set[str] = set()
        hook_class_names_registered: set[str] = set()
        # Preferred lazy "connection-types" discovery runs first...
        uses_connection_types = self._discover_hooks_from_connection_types(
            hook_class_names_registered, duplicated_connection_types, package_name, provider
        )
        # ...followed by the deprecated eager "hook-class-names" discovery.
        self._discover_hooks_from_hook_class_names(
            hook_class_names_registered,
            duplicated_connection_types,
            package_name,
            provider,
            uses_connection_types,
        )
    # Sort for deterministic lookup/iteration order.
    self._hook_provider_dict = OrderedDict(sorted(self._hook_provider_dict.items()))
@provider_info_cache("import_all_hooks")
def _import_info_from_all_hooks(self):
    """Force-import all hooks and initialize the connections/fields."""
    # Retrieve all hooks to make sure that all of them are imported
    # (accessing values resolves the lazy import callables).
    _ = list(self._hooks_lazy_dict.values())
    self._field_behaviours = OrderedDict(sorted(self._field_behaviours.items()))
    # Widgets for connection forms are currently used in two places:
    # 1. In the UI Connections, expected same order that it defined in Hook.
    # 2. cli command - `airflow providers widgets` and expected that it in alphabetical order.
    # It is not possible to recover original ordering after sorting,
    # that the main reason why original sorting moved to cli part:
    # self._connection_form_widgets = OrderedDict(sorted(self._connection_form_widgets.items()))
def _discover_taskflow_decorators(self) -> None:
    """Register every task decorator declared by the discovered providers."""
    for provider_name, provider_info in self._provider_dict.items():
        for decorator_spec in provider_info.data.get("task-decorators", []):
            self._add_taskflow_decorator(
                decorator_spec["name"], decorator_spec["class-name"], provider_name
            )
def _add_taskflow_decorator(self, name, decorator_class_name: str, provider_package: str) -> None:
    """
    Register a taskflow decorator for lazy import under the given name.

    Duplicate registrations are logged and ignored (first registration wins).

    :param name: decorator name exposed to users
    :param decorator_class_name: dotted path of the decorator class
    :param provider_package: provider package declaring the decorator
    """
    if not _check_builtin_provider_prefix(provider_package, decorator_class_name):
        return
    if name in self._taskflow_decorators:
        try:
            # Accessing the lazy dict may trigger the deferred import.
            existing = self._taskflow_decorators[name]
            other_name = f"{existing.__module__}.{existing.__name__}"
        except Exception:
            # If problem importing, then get the value from the functools.partial
            other_name = self._taskflow_decorators._raw_dict[name].args[0]  # type: ignore[attr-defined]
        log.warning(
            "The taskflow decorator '%s' has been already registered (by %s).",
            name,
            other_name,
        )
        return
    self._taskflow_decorators[name] = functools.partial(import_string, decorator_class_name)
@staticmethod
def _get_attr(obj: Any, attr_name: str):
    """Retrieve attributes of an object, or warn if not found."""
    try:
        return getattr(obj, attr_name)
    except AttributeError:
        log.warning("The object '%s' is missing %s attribute and cannot be registered", obj, attr_name)
        return None
def _import_hook(
    self,
    connection_type: str | None,
    provider_info: ProviderInfo,
    hook_class_name: str | None = None,
    package_name: str | None = None,
) -> HookInfo | None:
    """
    Import hook and retrieve hook information.

    Either connection_type (for lazy loading) or hook_class_name must be set - but not both).
    Only needs package_name if hook_class_name is passed (for lazy loading, package_name
    is retrieved from _connection_type_class_provider_dict together with hook_class_name).

    :param connection_type: type of the connection
    :param provider_info: provider metadata (used by the import coherence check)
    :param hook_class_name: name of the hook class
    :param package_name: provider package - only needed in case connection_type is missing
    :return: HookInfo for the imported hook, or None when import or validation failed
    """
    from wtforms import BooleanField, IntegerField, PasswordField, StringField

    if connection_type is None and hook_class_name is None:
        raise ValueError("Either connection_type or hook_class_name must be set")
    if connection_type is not None and hook_class_name is not None:
        raise ValueError(
            f"Both connection_type ({connection_type} and "
            f"hook_class_name {hook_class_name} are set. Only one should be set!"
        )
    if connection_type is not None:
        # Lazy path: class name and package were recorded during discovery.
        class_provider = self._hook_provider_dict[connection_type]
        package_name = class_provider.package_name
        hook_class_name = class_provider.hook_class_name
    else:
        if not hook_class_name:
            raise ValueError("Either connection_type or hook_class_name must be set")
        if not package_name:
            raise ValueError(
                f"Provider package name is not set when hook_class_name ({hook_class_name}) is used"
            )
    # Only these wtforms field classes are supported in connection forms.
    allowed_field_classes = [IntegerField, PasswordField, StringField, BooleanField]
    hook_class = _sanity_check(package_name, hook_class_name, provider_info)
    if hook_class is None:
        return None
    try:
        module, class_name = hook_class_name.rsplit(".", maxsplit=1)
        # Do not use attr here. We want to check only direct class fields not those
        # inherited from parent hook. This way we add form fields only once for the whole
        # hierarchy and we add it only from the parent hook that provides those!
        if "get_connection_form_widgets" in hook_class.__dict__:
            widgets = hook_class.get_connection_form_widgets()
            if widgets:
                for widget in widgets.values():
                    if widget.field_class not in allowed_field_classes:
                        log.warning(
                            "The hook_class '%s' uses field of unsupported class '%s'. "
                            "Only '%s' field classes are supported",
                            hook_class_name,
                            widget.field_class,
                            allowed_field_classes,
                        )
                        return None
                self._add_widgets(package_name, hook_class, widgets)
        if "get_ui_field_behaviour" in hook_class.__dict__:
            field_behaviours = hook_class.get_ui_field_behaviour()
            if field_behaviours:
                self._add_customized_fields(package_name, hook_class, field_behaviours)
    except Exception as e:
        log.warning(
            "Exception when importing '%s' from '%s' package: %s",
            hook_class_name,
            package_name,
            e,
        )
        return None
    hook_connection_type = self._get_attr(hook_class, "conn_type")
    if connection_type:
        if hook_connection_type != connection_type:
            log.warning(
                "Inconsistency! The hook class '%s' declares connection type '%s'"
                " but it is added by provider '%s' as connection_type '%s' in provider info. "
                "This should be fixed!",
                hook_class,
                hook_connection_type,
                package_name,
                connection_type,
            )
    # The class-declared conn_type is authoritative from here on.
    connection_type = hook_connection_type
    connection_id_attribute_name: str = self._get_attr(hook_class, "conn_name_attr")
    hook_name: str = self._get_attr(hook_class, "hook_name")
    if not connection_type or not connection_id_attribute_name or not hook_name:
        log.warning(
            "The hook misses one of the key attributes: "
            "conn_type: %s, conn_id_attribute_name: %s, hook_name: %s",
            connection_type,
            connection_id_attribute_name,
            hook_name,
        )
        return None
    return HookInfo(
        hook_class_name=hook_class_name,
        connection_id_attribute_name=connection_id_attribute_name,
        package_name=package_name,
        hook_name=hook_name,
        connection_type=connection_type,
        connection_testable=hasattr(hook_class, "test_connection"),
    )
def _add_widgets(self, package_name: str, hook_class: type, widgets: dict[str, Any]):
    """
    Register connection-form widgets contributed by a hook class.

    Custom field names are namespaced with the ``extra__<conn_type>__`` prefix
    unless already prefixed; the first provider registering a prefixed name wins.
    """
    conn_type = hook_class.conn_type  # type: ignore
    for raw_name, widget_field in widgets.items():
        prefixed_name = (
            raw_name if raw_name.startswith("extra__") else f"extra__{conn_type}__{raw_name}"
        )
        if prefixed_name in self._connection_form_widgets:
            # In case of inherited hooks this might be happening several times
            log.warning(
                "The field %s from class %s has already been added by another provider. Ignoring it.",
                raw_name,
                hook_class.__name__,
            )
            continue
        self._connection_form_widgets[prefixed_name] = ConnectionFormWidgetInfo(
            hook_class.__name__, package_name, widget_field, raw_name
        )
def _add_customized_fields(self, package_name: str, hook_class: type, customized_fields: dict):
try:
connection_type = getattr(hook_class, "conn_type")
self._customized_form_fields_schema_validator.validate(customized_fields)
if connection_type:
customized_fields = _ensure_prefix_for_placeholders(customized_fields, connection_type)
if connection_type in self._field_behaviours:
log.warning(
"The connection_type %s from package %s and class %s has already been added "
"by another provider. Ignoring it.",
connection_type,
package_name,
hook_class.__name__,
)
return
self._field_behaviours[connection_type] = customized_fields
except Exception as e:
log.warning(
"Error when loading customized fields from package '%s' hook class '%s': %s",
package_name,
hook_class.__name__,
e,
)
def _discover_extra_links(self) -> None:
"""Retrieves all extra links defined in the providers."""
for provider_package, provider in self._provider_dict.items():
if provider.data.get("extra-links"):
for extra_link_class_name in provider.data["extra-links"]:
if _sanity_check(provider_package, extra_link_class_name, provider):
self._extra_link_class_name_set.add(extra_link_class_name)
def _discover_logging(self) -> None:
"""Retrieve all logging defined in the providers."""
for provider_package, provider in self._provider_dict.items():
if provider.data.get("logging"):
for logging_class_name in provider.data["logging"]:
if _sanity_check(provider_package, logging_class_name, provider):
self._logging_class_name_set.add(logging_class_name)
def _discover_secrets_backends(self) -> None:
"""Retrieve all secrets backends defined in the providers."""
for provider_package, provider in self._provider_dict.items():
if provider.data.get("secrets-backends"):
for secrets_backends_class_name in provider.data["secrets-backends"]:
if _sanity_check(provider_package, secrets_backends_class_name, provider):
self._secrets_backend_class_name_set.add(secrets_backends_class_name)
def _discover_auth_backends(self) -> None:
"""Retrieve all API auth backends defined in the providers."""
for provider_package, provider in self._provider_dict.items():
if provider.data.get("auth-backends"):
for auth_backend_module_name in provider.data["auth-backends"]:
if _sanity_check(provider_package, auth_backend_module_name + ".init_app", provider):
self._api_auth_backend_module_names.add(auth_backend_module_name)
@provider_info_cache("triggers")
def initialize_providers_triggers(self):
"""Initialization of providers triggers."""
self.initialize_providers_list()
for provider_package, provider in self._provider_dict.items():
for trigger in provider.data.get("triggers", []):
for trigger_class_name in trigger.get("class-names"):
self._trigger_info_set.add(
TriggerInfo(
package_name=provider_package,
trigger_class_name=trigger_class_name,
integration_name=trigger.get("integration-name", ""),
)
)
@property
def trigger(self) -> list[TriggerInfo]:
"""Returns information about available providers trigger class."""
self.initialize_providers_triggers()
return sorted(self._trigger_info_set, key=lambda x: x.package_name)
@property
def providers(self) -> dict[str, ProviderInfo]:
"""Returns information about available providers."""
self.initialize_providers_list()
return self._provider_dict
@property
def hooks(self) -> MutableMapping[str, HookInfo | None]:
"""
Return dictionary of connection_type-to-hook mapping.
Note that the dict can contain None values if a hook discovered cannot be imported!
"""
self.initialize_providers_hooks()
# When we return hooks here it will only be used to retrieve hook information
return self._hooks_lazy_dict
@property
def taskflow_decorators(self) -> dict[str, TaskDecorator]:
self.initialize_providers_taskflow_decorator()
return self._taskflow_decorators
@property
def extra_links_class_names(self) -> list[str]:
"""Returns set of extra link class names."""
self.initialize_providers_extra_links()
return sorted(self._extra_link_class_name_set)
@property
def connection_form_widgets(self) -> dict[str, ConnectionFormWidgetInfo]:
"""
Returns widgets for connection forms.
Dictionary keys in the same order that it defined in Hook.
"""
self.initialize_providers_hooks()
self._import_info_from_all_hooks()
return self._connection_form_widgets
@property
def field_behaviours(self) -> dict[str, dict]:
"""Returns dictionary with field behaviours for connection types."""
self.initialize_providers_hooks()
self._import_info_from_all_hooks()
return self._field_behaviours
@property
def logging_class_names(self) -> list[str]:
"""Returns set of log task handlers class names."""
self.initialize_providers_logging()
return sorted(self._logging_class_name_set)
@property
def secrets_backend_class_names(self) -> list[str]:
"""Returns set of secret backend class names."""
self.initialize_providers_secrets_backends()
return sorted(self._secrets_backend_class_name_set)
@property
def auth_backend_module_names(self) -> list[str]:
"""Returns set of API auth backend class names."""
self.initialize_providers_auth_backends()
return sorted(self._api_auth_backend_module_names)