| #!/usr/bin/env python3 |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """ |
| Airflow Registry Parameter & Module Extractor |
| |
| Discovers provider modules (operators, hooks, sensors, triggers, etc.) at runtime |
| and extracts constructor/function parameters via MRO or signature inspection. |
| Produces both modules.json (the full module catalog) and per-provider |
| parameters.json files. |
| |
| Must be run inside breeze where all providers are installed. |
| |
| Usage: |
| breeze run python dev/registry/extract_parameters.py |
| |
| Output: |
| - dev/registry/modules.json (+ registry/src/_data/modules.json on host) |
| - dev/registry/output/versions/{provider_id}/{version}/parameters.json |
| - registry/src/_data/versions/{provider_id}/{version}/parameters.json |
| - dev/registry/runtime_modules.json (debug stats) |
| """ |
| |
| from __future__ import annotations |
| |
| import argparse |
| import concurrent.futures |
| import importlib |
| import inspect |
| import json |
| import logging |
| import re |
| import sys |
| import typing |
| from collections import defaultdict |
| from dataclasses import dataclass |
| from datetime import datetime, timezone |
| from pathlib import Path |
| |
| import yaml |
| from extract_metadata import fetch_provider_inventory, read_inventory |
| from registry_contract_models import validate_modules_catalog, validate_provider_parameters |
| from registry_tools.types import BASE_CLASS_IMPORTS, CLASS_LEVEL_SECTIONS, MODULE_LEVEL_SECTIONS |
| |
# Repository root, resolved relative to this file (dev/registry/extract_parameters.py).
AIRFLOW_ROOT = Path(__file__).parent.parent.parent
SCRIPT_DIR = Path(__file__).parent
# Root of the in-repo provider packages; each provider dir holds a provider.yaml.
PROVIDERS_DIR = AIRFLOW_ROOT / "providers"

# Locations searched, in order, for providers.json when --providers-json is not given.
PROVIDERS_JSON_CANDIDATES = [
    SCRIPT_DIR / "providers.json",
    AIRFLOW_ROOT / "registry" / "src" / "_data" / "providers.json",
]

# Inside breeze, write to dev/registry/output/ (mounted).
# On host, also write to the registry data directory.
OUTPUT_DIRS = [
    SCRIPT_DIR / "output",
    AIRFLOW_ROOT / "registry" / "src" / "_data",
]
| |
| |
@dataclass
class Module:
    """A discovered provider module (operator, hook, sensor, etc.).

    Mirrors the 11 fields emitted per entry in modules.json
    (see make_entry in discover_classes_from_provider).
    """

    id: str  # Unique ID: "{provider_id}-{module_name}-{ClassName}"
    name: str  # Class name (e.g., SnowflakeOperator)
    type: str  # operator, hook, sensor, trigger, transfer, etc.
    import_path: str  # Full import path to the class
    module_path: str  # Module file path
    short_description: str  # First docstring line, or a generated fallback
    docs_url: str  # Link into the provider's published docs
    source_url: str  # GitHub source link (with #L line anchor when known)
    category: str  # Slugified integration name (see get_category)
    provider_id: str  # Short provider ID, e.g. "amazon"
    provider_name: str  # Human-readable provider name
| |
| |
def get_category(integration_name: str) -> str:
    """Slugify an integration name into a category ID."""
    # Lowercase, turn spaces into dashes, drop parens, then strip anything
    # that is not a lowercase letter, digit, or dash.
    dashed = integration_name.lower().replace(" ", "-").replace("(", "").replace(")", "")
    return re.sub(r"[^a-z0-9-]", "", dashed)
| |
| |
| def format_annotation(annotation: type, _depth: int = 0) -> str | None: |
| """Convert a type annotation to a human-readable string.""" |
| if _depth > 5: |
| return str(annotation) |
| |
| if annotation is inspect.Parameter.empty: |
| return None |
| |
| if annotation is type(None): |
| return "None" |
| |
| origin = getattr(annotation, "__origin__", None) |
| |
| # typing.Union (includes X | Y on 3.10+) |
| if origin is typing.Union: |
| args = typing.get_args(annotation) |
| parts = [format_annotation(a, _depth + 1) for a in args] |
| return " | ".join(p for p in parts if p) |
| |
| if origin is not None: |
| args = typing.get_args(annotation) |
| origin_name = getattr(origin, "__name__", str(origin)) |
| if args: |
| arg_strs = [format_annotation(a, _depth + 1) or "Any" for a in args] |
| return f"{origin_name}[{', '.join(arg_strs)}]" |
| return origin_name |
| |
| if hasattr(annotation, "__name__"): |
| return annotation.__name__ |
| |
| s = str(annotation) |
| s = re.sub(r"\btyping\.", "", s) |
| s = re.sub(r"\bcollections\.abc\.", "", s) |
| return s |
| |
| |
def format_default(default: object) -> object:
    """Convert a default value to a JSON-serializable representation."""
    # Missing defaults and explicit None both serialize as null.
    if default is inspect.Parameter.empty or default is None:
        return None

    # JSON scalars pass through untouched.
    if isinstance(default, (str, int, float, bool)):
        return default

    # Containers pass through only when they round-trip through JSON.
    if isinstance(default, (list, tuple, dict)):
        try:
            json.dumps(default)
        except (TypeError, ValueError):
            pass
        else:
            return default

    # Everything else becomes a string representation.
    try:
        return str(default)
    except Exception:
        return repr(default)
| |
| |
def get_params_from_signature(sig: inspect.Signature, qualified_origin: str) -> dict[str, dict]:
    """Convert an inspect.Signature into parameter metadata keyed by param name."""
    metadata: dict[str, dict] = {}

    for param_name, param in sig.parameters.items():
        # Drop implicit receivers, private parameters, and *args/**kwargs catch-alls.
        skip = (
            param_name in ("self", "cls")
            or param_name.startswith("_")
            or param.kind in (param.VAR_KEYWORD, param.VAR_POSITIONAL)
        )
        if skip:
            continue

        metadata[param_name] = {
            "name": param_name,
            "type": format_annotation(param.annotation),
            "default": format_default(param.default),
            "required": param.default is inspect.Parameter.empty,
            "origin": qualified_origin,
        }

    return metadata
| |
| |
def get_params_from_class(cls: type) -> dict[str, dict]:
    """
    Extract all __init__ parameters by walking the MRO in reverse
    so child class overrides parent for the same parameter name.
    Records the full qualified origin (module.ClassName) for each param.
    """
    merged: dict[str, dict] = {}

    for ancestor in reversed(cls.__mro__):
        if ancestor is object:
            continue

        # Only consider an __init__ defined on this class itself, not inherited.
        ctor = ancestor.__dict__.get("__init__")
        if ctor is None:
            continue

        try:
            ctor_sig = inspect.signature(ctor)
        except (TypeError, ValueError):
            # Some C-implemented constructors are not introspectable.
            continue

        origin = f"{ancestor.__module__}.{ancestor.__qualname__}"
        merged.update(get_params_from_signature(ctor_sig, origin))

    return merged
| |
| |
def get_mro_chain(cls: type) -> list[str]:
    """Return the full MRO as a list of qualified class names (object excluded)."""
    chain: list[str] = []
    for ancestor in cls.__mro__:
        if ancestor is not object:
            chain.append(f"{ancestor.__module__}.{ancestor.__qualname__}")
    return chain
| |
| |
def parse_param_descriptions(doc: str) -> dict[str, str]:
    """Parse ``:param name: description`` entries from a docstring."""
    found: dict[str, str] = {}
    if not doc:
        return found

    # A description runs until the next :field:, a ".." directive, or the end.
    pattern = r":param\s+(\w+):\s*(.+?)(?=\n\s*:\w|\n\s*\.\.|$)"
    for match in re.finditer(pattern, doc, re.DOTALL):
        param_name = match.group(1)
        # Collapse internal whitespace (including newlines) to single spaces.
        collapsed = re.sub(r"\s+", " ", match.group(2).strip())
        # First occurrence wins.
        found.setdefault(param_name, collapsed)

    return found
| |
| |
def parse_docstring_params(cls: type) -> dict[str, str]:
    """
    Parse :param name: description from class and ancestor docstrings.
    Child class descriptions take priority.
    """
    merged: dict[str, str] = {}

    # cls.__mro__ goes child -> parent, so earlier entries win via setdefault.
    for ancestor in cls.__mro__:
        if ancestor is object:
            continue
        doc = getattr(ancestor, "__doc__", None) or ""
        for param_name, desc in parse_param_descriptions(doc).items():
            merged.setdefault(param_name, desc)

    return merged
| |
| |
def extract_class_params(cls: type) -> tuple[list[str], list[dict]]:
    """
    Extract parameter list for a class, merging signature + docstrings.
    Only includes params originating from provider classes (airflow.providers.*).
    Returns (mro_chain, filtered_params).
    """
    all_params = get_params_from_class(cls)
    docs = parse_docstring_params(cls)

    # Attach a docstring description (or None) to every parameter.
    for param_name, meta in all_params.items():
        meta["description"] = docs.get(param_name)

    filtered = [
        meta for meta in all_params.values() if meta["origin"].startswith("airflow.providers.")
    ]
    return get_mro_chain(cls), filtered
| |
| |
def extract_callable_params(func: typing.Callable[..., typing.Any]) -> tuple[list[str], list[dict]]:
    """Extract parameter metadata from a callable signature and docstring."""
    try:
        sig = inspect.signature(func)
    except (TypeError, ValueError) as e:
        raise TypeError(f"Could not inspect callable signature: {e}") from e

    origin = f"{func.__module__}.{func.__qualname__}"
    param_meta = get_params_from_signature(sig, origin)
    docs = parse_param_descriptions(getattr(func, "__doc__", None) or "")

    for param_name, meta in param_meta.items():
        meta["description"] = docs.get(param_name)

    # Callables have no MRO, hence the empty first tuple element.
    filtered = [m for m in param_meta.values() if m["origin"].startswith("airflow.providers.")]
    return [], filtered
| |
| |
def extract_params(obj: object) -> tuple[list[str], list[dict]]:
    """Extract parameter metadata from either a class or a callable."""
    if inspect.isclass(obj):
        return extract_class_params(obj)
    if callable(obj):
        func = typing.cast("typing.Callable[..., typing.Any]", obj)
        return extract_callable_params(func)
    raise TypeError(f"Unsupported import type: {type(obj)!r}")
| |
| |
| def import_symbol(import_path: str) -> object | None: |
| """Import a symbol from its full dotted path.""" |
| parts = import_path.rsplit(".", 1) |
| if len(parts) != 2: |
| return None |
| |
| module_path, symbol_name = parts |
| try: |
| module = importlib.import_module(module_path) |
| return getattr(module, symbol_name, None) |
| except Exception as e: |
| print(f" WARN failed to import {import_path}: {e}") |
| return None |
| |
| |
def find_json(candidates: list[Path], name: str) -> Path:
    """Return the first existing path from candidates; exit the process if none exist."""
    for path in candidates:
        if path.exists():
            return path

    # Nothing found: explain where we looked, then abort.
    print(f"ERROR: {name} not found. Searched:")
    for path in candidates:
        print(f"  - {path}")
    print(f"\nCopy {name} to dev/registry/ or run extract_metadata.py first.")
    sys.exit(1)
| |
| |
# Module-level logger used for import/discovery warnings; CLI progress uses print().
log = logging.getLogger(__name__)
| |
| |
def load_base_classes() -> dict[str, type]:
    """Import base classes for issubclass checks.

    Returns a mapping of type name -> base class (e.g. "sensor" -> BaseSensorOperator).
    """
    loaded: dict[str, type] = {}
    for type_name, import_path in BASE_CLASS_IMPORTS:
        module_name, class_name = import_path.rsplit(".", 1)
        try:
            module = importlib.import_module(module_name)
            loaded[type_name] = getattr(module, class_name)
        except Exception:
            # Non-fatal: discovery simply skips the issubclass filter for this type.
            log.warning("Could not import base class %s", import_path)
    return loaded
| |
| |
| def _should_skip_class(name: str) -> bool: |
| """Return True if a class name should be excluded from discovery.""" |
| if name.startswith("_"): |
| return True |
| if name.startswith("Base"): |
| return True |
| if "Abstract" in name or "Mixin" in name: |
| return True |
| return False |
| |
| |
| def _get_first_docstring_line(cls: type) -> str | None: |
| """Return the first non-empty line of a class docstring, or None.""" |
| doc = getattr(cls, "__doc__", None) |
| if not doc: |
| return None |
| for line in doc.strip().splitlines(): |
| stripped = line.strip() |
| if stripped: |
| return stripped |
| return None |
| |
| |
| def _get_source_line(cls: type) -> int | None: |
| """Return the source line number for a class, or None if unavailable.""" |
| try: |
| return inspect.getsourcelines(cls)[1] |
| except (OSError, TypeError): |
| return None |
| |
| |
def discover_classes_from_provider(
    provider_yaml_path: Path,
    base_classes: dict[str, type],
    inventory: dict[str, str] | None = None,
    version: str = "",
) -> list[dict]:
    """Discover classes from a single provider by importing its modules at runtime.

    Reads the provider.yaml to find which modules/classes to inspect, imports them,
    and returns metadata for each discovered class with all 11 Module fields.

    :param provider_yaml_path: path to the provider's provider.yaml file
    :param base_classes: mapping of module type -> base class used for issubclass
        filtering; when a base class is absent, that filter is skipped
    :param inventory: optional Sphinx inventory mapping import path -> docs page path
    :param version: provider version used to pin GitHub source URLs ("" pins to main)
    """
    with open(provider_yaml_path) as f:
        provider_yaml = yaml.safe_load(f)

    # Short provider ID, e.g. "amazon", derived from the package name.
    provider_id = provider_yaml.get("package-name", "").replace("apache-airflow-providers-", "")
    if not provider_id:
        return []

    provider_name = provider_yaml.get("name", provider_id.replace("-", " ").title())
    provider_rel_path = provider_yaml_path.parent.relative_to(PROVIDERS_DIR)
    # Pin source links to the release tag when a version is known, else main.
    tag = f"providers-{provider_id}/{version}" if version else "main"
    base_docs_url = f"https://airflow.apache.org/docs/apache-airflow-providers-{provider_id}/stable"
    base_source_url = f"https://github.com/apache/airflow/blob/{tag}/providers/{provider_rel_path}/src"

    # Build integration-name lookup for module-level sections
    # Maps (section_name, module_path) -> integration_name
    # NOTE(review): integration_by_module is populated here but never read later
    # in this function — looks like dead code; confirm before removing.
    integration_by_module: dict[tuple[str, str], str] = {}
    for section_name in list(MODULE_LEVEL_SECTIONS) + ["bundles"]:
        for group in provider_yaml.get(section_name, []):
            integration = group.get("integration-name", "")
            for mp in group.get("python-modules", []):
                integration_by_module[(section_name, mp)] = integration

    def resolve_docs_url(full_class_path: str, module_path: str) -> str:
        """Look up docs URL from inventory, falling back to manual construction."""
        if inventory and full_class_path in inventory:
            return f"{base_docs_url}/{inventory[full_class_path]}"
        api_ref_path = module_path.replace(".", "/")
        return f"{base_docs_url}/_api/{api_ref_path}/index.html#{full_class_path}"

    def make_source_url(cls: type, module_path: str) -> str:
        """Construct a GitHub source URL with line number when available."""
        url = f"{base_source_url}/{module_path.replace('.', '/')}.py"
        line = _get_source_line(cls)
        if line:
            url += f"#L{line}"
        return url

    def make_entry(
        cls_or_obj: type,
        name: str,
        module_type: str,
        import_path: str,
        module_path: str,
        integration: str = "",
        category: str = "",
        transfer_desc: str | None = None,
    ) -> dict:
        """Build a full module entry dict with all 11 fields."""
        module_name = module_path.split(".")[-1]
        docstring = _get_first_docstring_line(cls_or_obj)
        # Prefer the real docstring, then a transfer description, then a generic label.
        short_desc = docstring or transfer_desc or f"{integration} {module_type}".strip()

        return {
            "id": f"{provider_id}-{module_name}-{name}",
            "name": name,
            "type": module_type,
            "import_path": import_path,
            "module_path": module_path,
            "short_description": short_desc,
            "docs_url": resolve_docs_url(import_path, module_path),
            "source_url": make_source_url(cls_or_obj, module_path),
            "category": category or get_category(integration),
            "provider_id": provider_id,
            "provider_name": provider_name,
        }

    discovered: list[dict] = []

    # --- Module-level sections (operators, hooks, sensors, triggers, bundles) ---
    for section_name, module_type in MODULE_LEVEL_SECTIONS.items():
        expected_base = base_classes.get(module_type)
        for group in provider_yaml.get(section_name, []):
            integration = group.get("integration-name", "")
            category = get_category(integration)
            for module_path in group.get("python-modules", []):
                try:
                    mod = importlib.import_module(module_path)
                except Exception:
                    log.warning("Could not import module %s", module_path)
                    continue

                for name, cls in inspect.getmembers(mod, inspect.isclass):
                    # Skip classes merely re-exported from other modules.
                    if cls.__module__ != mod.__name__:
                        continue
                    if _should_skip_class(name):
                        continue
                    # Filter by base class only when one was importable for this type.
                    if expected_base and not issubclass(cls, expected_base):
                        continue

                    discovered.append(
                        make_entry(
                            cls,
                            name,
                            module_type,
                            f"{module_path}.{name}",
                            module_path,
                            integration,
                            category,
                        )
                    )

    # --- Transfers (module-level, singular python-module key) ---
    transfer_base = base_classes.get("operator")
    for transfer in provider_yaml.get("transfers", []):
        module_path = transfer.get("python-module", "")
        if not module_path:
            continue
        source = transfer.get("source-integration-name", "")
        target = transfer.get("target-integration-name", "")
        transfer_desc = f"Transfer from {source} to {target}" if source and target else None
        # Transfers are categorized under their source integration.
        category = get_category(source) if source else ""

        try:
            mod = importlib.import_module(module_path)
        except Exception:
            log.warning("Could not import module %s", module_path)
            continue

        for name, cls in inspect.getmembers(mod, inspect.isclass):
            if cls.__module__ != mod.__name__:
                continue
            if _should_skip_class(name):
                continue
            if transfer_base and not issubclass(cls, transfer_base):
                continue

            discovered.append(
                make_entry(
                    cls,
                    name,
                    "transfer",
                    f"{module_path}.{name}",
                    module_path,
                    source,
                    category,
                    transfer_desc,
                )
            )

    # --- Class-level sections (notifications, secrets-backends, logging, executors) ---
    for section_name, module_type in CLASS_LEVEL_SECTIONS.items():
        for class_path in provider_yaml.get(section_name, []):
            if not class_path or not isinstance(class_path, str):
                continue
            parts = class_path.rsplit(".", 1)
            if len(parts) != 2:
                continue
            module_path, class_name = parts
            try:
                mod = importlib.import_module(module_path)
                candidate = getattr(mod, class_name, None)
            except Exception:
                log.warning("Could not import %s", class_path)
                continue
            if candidate is None or not inspect.isclass(candidate):
                log.warning("%s is not a class", class_path)
                continue
            cls = typing.cast("type[typing.Any]", candidate)

            # Use section name as category for class-level entries
            category_map = {
                "notifications": "notifications",
                "secrets-backends": "secrets",
                "logging": "logging",
                "executors": "executors",
            }

            discovered.append(
                make_entry(
                    cls,
                    class_name,
                    module_type,
                    class_path,
                    module_path,
                    category=category_map.get(section_name, section_name),
                )
            )

    # --- Task decorators (class-name key in each entry) ---
    for decorator in provider_yaml.get("task-decorators", []):
        class_path = decorator.get("class-name", "")
        decorator_name = decorator.get("name", "")
        if not class_path:
            continue
        parts = class_path.rsplit(".", 1)
        if len(parts) != 2:
            continue
        module_path, func_name = parts
        try:
            mod = importlib.import_module(module_path)
            obj = getattr(mod, func_name, None)
        except Exception:
            log.warning("Could not import %s", class_path)
            continue
        if obj is None:
            continue

        # Decorators are surfaced under their @task.<name> alias when one is given.
        display_name = f"@task.{decorator_name}" if decorator_name else func_name
        docstring = _get_first_docstring_line(obj) if hasattr(obj, "__doc__") else None
        short_desc = docstring or f"Task decorator for {decorator_name or func_name}"

        discovered.append(
            {
                "id": f"{provider_id}-decorator-{decorator_name or func_name}",
                "name": display_name,
                "type": "decorator",
                "import_path": class_path,
                "module_path": module_path,
                "short_description": short_desc,
                "docs_url": resolve_docs_url(class_path, module_path),
                # NOTE(review): unlike class entries, no #L line anchor is added here.
                "source_url": f"{base_source_url}/{module_path.replace('.', '/')}.py",
                "category": "decorators",
                "provider_id": provider_id,
                "provider_name": provider_name,
            }
        )

    return discovered
| |
| |
def _print_truncated(header: str, rows: list[str], limit: int = 20) -> None:
    """Print a comparison section: header, up to ``limit`` rows, and an omission note."""
    print(f"\n{header} ({len(rows)}):")
    for row in rows[:limit]:
        print(f"  {row}")
    if len(rows) > limit:
        print(f"  ... and {len(rows) - limit} more")


def compare_with_ast(
    runtime_classes: list[dict],
    modules_json_path: Path,
) -> dict:
    """Compare runtime-discovered classes against AST-produced modules.json.

    Returns a stats dict with counts of phantoms, misses, and type mismatches
    (phantom = present only in the AST catalog; miss = present only at runtime).

    :param runtime_classes: module dicts discovered at runtime; only entries with
        a non-empty "import_path" participate in the comparison
    :param modules_json_path: path to the AST-generated modules.json file
    """
    with open(modules_json_path) as f:
        ast_data = json.load(f)

    # Index both sides by import path; entries without one are ignored.
    ast_by_path: dict[str, dict] = {
        m["import_path"]: m for m in ast_data.get("modules", []) if m.get("import_path")
    }
    runtime_by_path: dict[str, dict] = {
        r["import_path"]: r for r in runtime_classes if r.get("import_path")
    }

    ast_paths = set(ast_by_path)
    runtime_paths = set(runtime_by_path)

    phantoms = sorted(ast_paths - runtime_paths)
    misses = sorted(runtime_paths - ast_paths)
    common = ast_paths & runtime_paths

    # Entries present on both sides but classified differently.
    type_mismatches = [
        {
            "import_path": path,
            "ast_type": ast_by_path[path].get("type", ""),
            "runtime_type": runtime_by_path[path].get("type", ""),
        }
        for path in sorted(common)
        if ast_by_path[path].get("type", "") != runtime_by_path[path].get("type", "")
    ]

    # Print comparison table
    print("\n" + "=" * 60)
    print("Runtime vs AST Comparison")
    print("=" * 60)
    print(f"  AST classes: {len(ast_paths)}")
    print(f"  Runtime classes: {len(runtime_paths)}")
    print(f"  In common: {len(common)}")
    print(f"  AST phantoms: {len(phantoms)} (in AST, not runtime)")
    print(f"  AST misses: {len(misses)} (in runtime, not AST)")
    print(f"  Type mismatches: {len(type_mismatches)}")

    if phantoms:
        _print_truncated(
            "AST Phantoms",
            [f"[{ast_by_path[p].get('type', '?')}] {p}" for p in phantoms],
        )

    if misses:
        _print_truncated(
            "AST Misses",
            [f"[{runtime_by_path[p].get('type', '?')}] {p}" for p in misses],
        )

    if type_mismatches:
        _print_truncated(
            "Type Mismatches",
            [
                f"{m['import_path']}: AST={m['ast_type']} Runtime={m['runtime_type']}"
                for m in type_mismatches
            ],
        )

    print("=" * 60)

    return {
        "ast_phantoms": len(phantoms),
        "ast_misses": len(misses),
        "type_mismatches": len(type_mismatches),
        "phantom_paths": phantoms,
        "miss_paths": misses,
        "mismatch_details": type_mismatches,
    }
| |
| |
def _extract_params_from_modules(
    modules: list[dict],
) -> tuple[dict[str, dict[str, dict]], dict[str, str], int, int, int]:
    """Extract parameters from a list of module dicts.

    Returns (provider_classes, provider_names, total_processed, total_failed, total_params).
    """
    provider_classes: dict[str, dict[str, dict]] = defaultdict(dict)
    provider_names: dict[str, str] = {}
    processed = failed = param_count = 0

    for index, entry in enumerate(modules, 1):
        import_path = entry.get("import_path", "")
        provider_id = entry.get("provider_id", "")

        # Entries without an import path or provider are unusable.
        if not (import_path and provider_id):
            continue

        provider_names[provider_id] = entry.get("provider_name", entry.get("provider_id", ""))

        obj = import_symbol(import_path)
        if obj is None:
            failed += 1
            continue

        try:
            mro, params = extract_params(obj)
        except Exception as e:
            print(f"  ERROR extracting params for {import_path}: {e}")
            failed += 1
            continue

        provider_classes[provider_id][import_path] = {
            "name": entry.get("name", ""),
            "type": entry.get("type", ""),
            "mro": mro,
            "parameters": params,
        }

        processed += 1
        param_count += len(params)

        # Periodic progress output (only reached on fully-processed entries).
        if index % 100 == 0:
            print(f"  Processed {index}/{len(modules)} modules...")

    return provider_classes, provider_names, processed, failed, param_count
| |
| |
def _write_parameter_files(
    provider_classes: dict[str, dict[str, dict]],
    provider_names: dict[str, str],
    provider_versions: dict[str, str],
    generated_at: str,
) -> None:
    """Write per-provider parameter JSON files."""
    for output_dir in OUTPUT_DIRS:
        # Skip output roots whose parent directory is absent (e.g. not mounted).
        if not output_dir.parent.exists():
            continue

        count = 0
        for provider_id, classes in provider_classes.items():
            release = provider_versions.get(provider_id)
            if not release:
                print(f"  WARN: no version found for {provider_id}, skipping")
                continue

            target_dir = output_dir / "versions" / provider_id / release
            target_dir.mkdir(parents=True, exist_ok=True)

            # Validate the payload against the registry contract before writing.
            payload = validate_provider_parameters(
                {
                    "provider_id": provider_id,
                    "provider_name": provider_names.get(provider_id, provider_id),
                    "version": release,
                    "generated_at": generated_at,
                    "classes": classes,
                }
            )
            with open(target_dir / "parameters.json", "w") as f:
                json.dump(payload, f, separators=(",", ":"))
            count += 1

        print(f"Wrote {count} provider parameter files to {output_dir}/versions/")
| |
| |
def _fetch_inventories(
    provider_ids: set[str],
    provider_yamls: dict[str, dict],
) -> dict[str, dict[str, str]]:
    """Fetch Sphinx inventory files in parallel for all providers."""
    # Resolve each provider's PyPI package name from its provider.yaml.
    package_names = {
        pid: provider_yamls.get(pid, {}).get("package-name", f"apache-airflow-providers-{pid}")
        for pid in provider_ids
    }

    def _fetch_and_parse(pid: str) -> tuple[str, dict[str, str] | None]:
        inv_path = fetch_provider_inventory(package_names[pid])
        if not inv_path:
            return pid, None
        try:
            return pid, read_inventory(inv_path)
        except Exception as e:
            # A bad inventory just means we fall back to constructed docs URLs.
            print(f"  Warning: Could not parse inventory for {pid}: {e}")
            return pid, None

    results: dict[str, dict[str, str]] = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
        pending = [executor.submit(_fetch_and_parse, pid) for pid in provider_ids]
        for future in concurrent.futures.as_completed(pending):
            pid, inventory = future.result()
            if inventory:
                results[pid] = inventory

    return results
| |
| |
def main():
    """CLI entry point: parse args, load provider versions, run discovery."""
    parser = argparse.ArgumentParser(description="Extract provider parameters and modules")
    parser.add_argument(
        "--provider",
        default=None,
        help="Only process this provider ID (e.g. 'amazon'). Skips modules.json write.",
    )
    parser.add_argument(
        "--providers-json",
        default=None,
        help="Path to providers.json (overrides default search paths).",
    )
    args = parser.parse_args()

    print("Airflow Registry Parameter & Module Extractor")
    print("=" * 50)

    # Explicit --providers-json wins; otherwise search the default candidates.
    providers_json_path = (
        Path(args.providers_json)
        if args.providers_json
        else find_json(PROVIDERS_JSON_CANDIDATES, "providers.json")
    )
    with open(providers_json_path) as f:
        providers_data = json.load(f)

    provider_versions: dict[str, str] = {
        p["id"]: p["version"] for p in providers_data.get("providers", [])
    }

    generated_at = datetime.now(timezone.utc).isoformat()
    _main_discover(provider_versions, generated_at, only_provider=args.provider)

    print("\nDone!")
| |
| |
def _main_discover(
    provider_versions: dict[str, str],
    generated_at: str,
    only_provider: str | None = None,
) -> None:
    """Runtime discovery: find classes from provider.yaml files, produce modules.json and parameters.

    When only_provider is set, only that provider is scanned and modules.json is NOT written
    (it would be incomplete). This enables parallel backfills since the only output is
    the per-provider parameters.json file.

    :param provider_versions: mapping of provider ID -> released version string
    :param generated_at: ISO timestamp stamped into the output files
    :param only_provider: optional provider ID to restrict the scan to
    """
    provider_yaml_paths = sorted(PROVIDERS_DIR.rglob("provider.yaml"))
    print(f"Found {len(provider_yaml_paths)} provider.yaml files")

    base_classes = load_base_classes()
    print(f"Loaded {len(base_classes)} base classes: {', '.join(sorted(base_classes))}")

    # Load all provider.yaml data and map provider_id -> yaml dict / path
    provider_yamls_by_id: dict[str, dict] = {}
    provider_paths_by_id: dict[str, Path] = {}
    for yaml_path in provider_yaml_paths:
        with open(yaml_path) as f:
            py = yaml.safe_load(f)
        pid = py.get("package-name", "").replace("apache-airflow-providers-", "")
        if pid:
            provider_yamls_by_id[pid] = py
            provider_paths_by_id[pid] = yaml_path

    # Filter to single provider if requested
    if only_provider:
        if only_provider not in provider_paths_by_id:
            print(f"ERROR: provider '{only_provider}' not found in provider.yaml files")
            sys.exit(1)
        provider_paths_by_id = {only_provider: provider_paths_by_id[only_provider]}
        provider_yamls_by_id = {only_provider: provider_yamls_by_id[only_provider]}
        print(f"Filtering to provider: {only_provider}")

    # Fetch Sphinx inventories in parallel
    print("Fetching Sphinx inventory files ...")
    inventories = _fetch_inventories(set(provider_yamls_by_id), provider_yamls_by_id)
    print(f"  {len(inventories)}/{len(provider_yamls_by_id)} inventories loaded")

    all_discovered: list[dict] = []
    providers_seen: set[str] = set()

    # Discover classes per provider, in deterministic (sorted) order.
    for pid, yaml_path in sorted(provider_paths_by_id.items()):
        version = provider_versions.get(pid, "")
        discovered = discover_classes_from_provider(
            yaml_path,
            base_classes,
            inventory=inventories.get(pid),
            version=version,
        )
        all_discovered.extend(discovered)
        for d in discovered:
            providers_seen.add(d["provider_id"])

    print(f"\nDiscovered {len(all_discovered)} classes from {len(providers_seen)} providers")

    # Deduplicate by ID (first occurrence wins, preserving discovery order)
    seen_ids: set[str] = set()
    unique_modules: list[dict] = []
    for m in all_discovered:
        mid = m["id"]
        if mid not in seen_ids:
            seen_ids.add(mid)
            unique_modules.append(m)
    all_discovered = unique_modules
    print(f"Deduplicated to {len(all_discovered)} unique modules")

    # Write modules.json only when doing a full build (no --provider filter).
    # With --provider, the output would be incomplete and would clobber the
    # full modules.json from a previous build.
    if not only_provider:
        # Validate against the registry contract before writing.
        modules_json = validate_modules_catalog({"modules": all_discovered})
        output_dirs = [SCRIPT_DIR, AIRFLOW_ROOT / "registry" / "src" / "_data"]
        for out_dir in output_dirs:
            if not out_dir.parent.exists():
                continue
            out_dir.mkdir(parents=True, exist_ok=True)
            with open(out_dir / "modules.json", "w") as f:
                json.dump(modules_json, f, indent=2)
            print(f"Wrote {len(all_discovered)} modules to {out_dir / 'modules.json'}")

        # Write runtime_modules.json (debug/stats file)
        runtime_output = {
            "generated_at": generated_at,
            "discovery_method": "runtime",
            "stats": {
                "total_classes": len(all_discovered),
                "total_providers": len(providers_seen),
            },
            "classes": all_discovered,
        }
        runtime_json_path = SCRIPT_DIR / "runtime_modules.json"
        with open(runtime_json_path, "w") as f:
            json.dump(runtime_output, f, indent=2)
        print(f"Wrote {runtime_json_path}")
    else:
        print("Skipping modules.json write (--provider mode)")

    # Extract parameters
    print("\nExtracting parameters from runtime-discovered classes...")
    provider_classes, provider_names, total_processed, total_failed, total_params = (
        _extract_params_from_modules(all_discovered)
    )

    print(f"\nProcessed {total_processed} classes, {total_failed} failed imports")
    print(f"Extracted {total_params} total parameters")
    print(f"Across {len(provider_classes)} providers")

    _write_parameter_files(provider_classes, provider_names, provider_versions, generated_at)
| |
| |
if __name__ == "__main__":
    # main() returns None, so "or 0" yields a success exit code.
    sys.exit(main() or 0)