blob: 87c37f66406d8a1b7d903acfa92717b5313dc185 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Module for provider's custom Sphinx extensions that will be loaded in all providers' documentation."""
from __future__ import annotations
import ast
import os
from pathlib import Path
from typing import Any, Iterable
import yaml
# No stub exists for docutils.parsers.rst.directives. See https://github.com/python/typeshed/issues/5755.
from provider_yaml_utils import get_provider_yaml_paths
from docs.exts.operators_and_hooks_ref import (
DEFAULT_HEADER_SEPARATOR,
BaseJinjaReferenceDirective,
_render_template,
)
def _get_module_class_registry(
    module_filepath: str, class_extras: dict[str, Any]
) -> dict[str, dict[str, Any]]:
    """Extracts classes and its information from a Python module file.

    The function parses the specified module file and registers all top-level classes.
    The registry for each class includes the module filename, methods, base classes
    and any additional class extras provided.

    :param module_filepath: The file path of the module.
    :param class_extras: Additional information to include in each class's registry.
    :return: A dictionary with class names as keys and their corresponding information.
    """
    # Read as UTF-8 explicitly so parsing does not depend on the build machine's locale.
    with open(module_filepath, encoding="utf-8") as file:
        source = file.read()
    # Passing the filename makes any SyntaxError point at the offending module.
    ast_obj = ast.parse(source, filename=module_filepath)
    module_class_registry = {
        node.name: {
            "module_filepath": module_filepath,
            # ast.walk also descends into nested functions/classes inside the class body,
            # so their (sync) function names are collected here as well.
            "methods": {n.name for n in ast.walk(node) if isinstance(n, ast.FunctionDef)},
            # Only plain-name bases (``class A(B)``) are recorded; dotted bases such as
            # ``class A(mod.B)`` are ast.Attribute nodes and are skipped.
            "base_classes": [b.id for b in node.bases if isinstance(b, ast.Name)],
            **class_extras,
        }
        # Only classes defined at module top level are registered.
        for node in ast_obj.body
        if isinstance(node, ast.ClassDef)
    }
    return module_class_registry
def _has_method(
    class_name: str, method_names: Iterable[str], class_registry: dict[str, dict[str, Any]]
) -> bool:
    """Determines if a class or its bases in the registry have any of the specified methods.

    Base classes are followed transitively. Names missing from the registry are ignored.
    Each class is visited at most once, so cyclic or self-referencing base names do not
    cause infinite recursion. These can occur because the registry merges same-named
    classes from different modules.

    :param class_name: The name of the class to check.
    :param method_names: A list of names of methods to search for.
    :param class_registry: A dictionary representing the class registry, where each key is a class name
        and the value is its metadata.
    :return: True if any of the specified methods are found in the class or its base classes; False otherwise.

    Example:
    >>> example_class_registry = {
    ...     "MyClass": {"methods": {"foo", "bar"}, "base_classes": ["BaseClass"]},
    ...     "BaseClass": {"methods": {"base_foo"}, "base_classes": []},
    ... }
    >>> _has_method("MyClass", ["foo"], example_class_registry)
    True
    >>> _has_method("MyClass", ["base_foo"], example_class_registry)
    True
    >>> _has_method("MyClass", ["not_a_method"], example_class_registry)
    False
    """
    # Materialize once: method_names is checked against every visited class, and a
    # one-shot iterable (e.g. a generator) would otherwise be exhausted after the first.
    wanted = set(method_names)
    if not wanted:
        return False

    pending = [class_name]
    visited: set[str] = set()
    while pending:
        current = pending.pop()
        if current in visited or current not in class_registry:
            continue
        visited.add(current)
        info = class_registry[current]
        if not wanted.isdisjoint(info["methods"]):
            return True
        pending.extend(info["base_classes"])
    return False
def _get_providers_class_registry() -> dict[str, dict[str, Any]]:
    """Builds a registry of classes from YAML configuration files.

    This function scans through YAML configuration files to build a registry of classes.
    It parses each YAML file to get the provider's package name. It then registers classes
    from every Python module file under that provider's directory, excluding ``__init__.py``.

    Classes are keyed by bare class name only. When two modules define a class with the
    same name, the module processed later overwrites the earlier entry.

    :return: A dictionary with class names as keys and each class's registry entry
        (including ``provider_name``) as values.
    """
    class_registry: dict[str, dict[str, Any]] = {}
    for provider_yaml_path in get_provider_yaml_paths():
        provider_yaml_file = Path(provider_yaml_path)
        provider_yaml_content = yaml.safe_load(provider_yaml_file.read_text(encoding="utf-8"))
        provider_name = provider_yaml_content["package-name"]
        for root, _, file_names in os.walk(provider_yaml_file.parent):
            for file_name in file_names:
                # Compare the bare file name: the joined "<dir>/<file>" path can never
                # equal "__init__.py", so checking it would never exclude anything.
                if not file_name.endswith(".py") or file_name == "__init__.py":
                    continue
                # "/"-joined so the caller can turn it into a dotted module name.
                module_filepath = f"{os.path.relpath(root)}/{file_name}"
                module_registry = _get_module_class_registry(
                    module_filepath=module_filepath,
                    class_extras={"provider_name": provider_name},
                )
                class_registry.update(module_registry)
    return class_registry
def _render_openlineage_supported_classes_content():
    """Render the documentation listing provider classes that implement OpenLineage methods.

    A class counts as supported as follows:

    * An operator qualifies when its name ends with "operator" (case-insensitive) and it,
      or a registered base class, defines one of the OpenLineage facet methods.
    * A hook qualifies when its name ends with "hook" (case-insensitive) and it, or a
      registered base class, defines one of the OpenLineage database-info methods.

    ``DbApiHook`` and ``SQLExecuteQueryOperator`` are removed from the registry first,
    so inheriting from them alone does not qualify a class.

    :return: The result of rendering ``openlineage.rst.jinja2`` with ``providers``
        (provider name -> {"operators": [dotted class paths]}) and ``db_hooks``
        (sorted list of (db_type, dotted hook class path)).
    """
    openlineage_operator_methods = ("get_openlineage_facets_on_complete", "get_openlineage_facets_on_start")
    openlineage_db_hook_methods = (
        "get_openlineage_database_info",
        "get_openlineage_database_specific_lineage",
    )
    class_registry = _get_providers_class_registry()
    # These excluded classes will be included in docs directly
    class_registry.pop("DbApiHook")
    class_registry.pop("SQLExecuteQueryOperator")

    providers: dict[str, dict[str, list[str]]] = {}
    db_hooks: list[tuple[str, str]] = []
    for class_name, info in class_registry.items():
        if class_name.startswith("_"):
            continue
        module_filepath = info["module_filepath"]
        # Strip only the trailing ".py". A blanket .replace(".py", "") would also mangle
        # path parts such as "python.py" -> "thon" or "pyspark.py" -> "spark".
        if module_filepath.endswith(".py"):
            module_filepath = module_filepath[: -len(".py")]
        module_name = module_filepath.replace("/", ".").lstrip(".")
        class_path = f"{module_name}.{class_name}"
        provider_entry = providers.setdefault(info["provider_name"], {"operators": []})

        if class_name.lower().endswith("operator"):
            if _has_method(
                class_name=class_name,
                method_names=openlineage_operator_methods,
                class_registry=class_registry,
            ):
                provider_entry["operators"].append(class_path)
        elif class_name.lower().endswith("hook"):
            if _has_method(
                class_name=class_name,
                method_names=openlineage_db_hook_methods,
                class_registry=class_registry,
            ):
                db_type = class_name.replace("SqlApiHook", "").replace("Hook", "")
                db_hooks.append((db_type, class_path))

    # Drop providers with nothing to list; sort each list by the bare class name.
    providers = {
        provider: {key: sorted(set(value), key=lambda x: x.split(".")[-1]) for key, value in details.items()}
        for provider, details in sorted(providers.items())
        if any(details.values())
    }
    # One hook per db_type (the last one seen wins), ordered by db_type.
    db_hooks = sorted(dict(db_hooks).items(), key=lambda x: x[0])
    return _render_template(
        "openlineage.rst.jinja2",
        providers=providers,
        db_hooks=db_hooks,
    )
class OpenLineageSupportedClassesDirective(BaseJinjaReferenceDirective):
    """Directive that emits the list of provider classes with OpenLineage support.

    The output depends on neither ``tags`` nor ``header_separator``. They are accepted,
    presumably so the signature matches the base directive's ``render_content`` hook.
    """

    def render_content(self, *, tags: set[str] | None, header_separator: str = DEFAULT_HEADER_SEPARATOR):
        """Return the rendered OpenLineage supported-classes content."""
        rendered = _render_openlineage_supported_classes_content()
        return rendered
def setup(app):
    """Register this extension's directive with the Sphinx application.

    :param app: The Sphinx application passed in when the extension is loaded.
    :return: Extension metadata marking it safe for parallel reading and writing.
    """
    directive_name = "airflow-providers-openlineage-supported-classes"
    app.add_directive(directive_name, OpenLineageSupportedClassesDirective)
    return dict(parallel_read_safe=True, parallel_write_safe=True)