| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| from __future__ import annotations |
| |
| import ast |
| import importlib.util |
| import inspect |
| import itertools |
| import pathlib |
| import sys |
| import warnings |
| |
| import yaml |
| from rich.console import Console |
| |
| try: |
| from yaml import CSafeLoader as SafeLoader |
| except ImportError: |
| from yaml import SafeLoader # type: ignore |
| |
# Rich console used to report results; fixed width of 400 and the "standard" color system.
console = Console(width=400, color_system="standard")
# Third ancestor directory of this file (``parents[0]`` is the directory containing it).
# NOTE(review): presumably the repository root - confirm against the script's location.
ROOT_DIR = pathlib.Path(__file__).resolve().parents[2]

# Recursive search for every ``provider.yaml`` below ``ROOT_DIR/airflow/providers``.
# ``rglob`` returns a generator, so this can be iterated only once.
provider_files_pattern = pathlib.Path(ROOT_DIR, "airflow", "providers").rglob("provider.yaml")
# Messages describing invalid template fields; appended to by ``iter_check_template_fields``
# and printed by the ``__main__`` block.
errors: list[str] = []

# Keys looked up in each provider's YAML data to find the entries listing python modules.
OPERATORS: list[str] = ["sensors", "operators"]
# Lowercase suffixes a class name must end with (case-insensitively) to be checked.
CLASS_IDENTIFIERS: list[str] = ["sensor", "operator"]

# Class attribute names whose values are collected as templated field names.
TEMPLATE_TYPES: list[str] = ["template_fields"]
| |
| |
class InstanceFieldExtractor(ast.NodeVisitor):
    """
    AST visitor that collects the names of attributes being assigned to.

    Functions whose name is not ``__init__`` are skipped entirely, so assignments inside other
    methods are ignored. Assignments located elsewhere in the visited tree (inside ``__init__``
    or, for instance, directly in a class body) are inspected, and the attribute name of each
    attribute target (``self.x = ...``, ``self.x: int = ...``, ``self.a, self.b = ...``) is
    appended to ``instance_fields``.
    """

    def __init__(self):
        # Kept for backward compatibility; nothing in this class reads or updates it.
        self.current_class = None
        # Attribute names in the order they are encountered; duplicates are kept.
        self.instance_fields = []

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        """Descend only into functions named ``__init__``; any other function is not traversed."""
        if node.name == "__init__":
            self.generic_visit(node)
        return node

    def visit_Assign(self, node: ast.Assign) -> ast.Assign:
        """Record the attribute names bound by a plain assignment, including unpacking targets."""
        for target in node.targets:
            self.instance_fields.extend(self._attribute_names(target))
        return node

    def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign:
        """Record the attribute name bound by an annotated assignment such as ``self.x: int = 1``."""
        if isinstance(node.target, ast.Attribute):
            self.instance_fields.append(node.target.attr)
        return node

    @staticmethod
    def _attribute_names(target: ast.expr) -> list[str]:
        """Return the attribute names bound by an assignment target, in source order."""
        if isinstance(target, ast.Attribute):
            return [target.attr]
        # ``self.a, self.b = ...`` / ``[self.a, self.b] = ...`` wrap their targets in a Tuple/List node.
        if isinstance(target, (ast.Tuple, ast.List)):
            names: list[str] = []
            for element in target.elts:
                names.extend(InstanceFieldExtractor._attribute_names(element))
            return names
        # ``*self.rest`` inside an unpacking target.
        if isinstance(target, ast.Starred):
            return InstanceFieldExtractor._attribute_names(target.value)
        return []
| |
| |
def get_template_fields_and_class_instance_fields(cls):
    """
    Collect templated field names and assigned attribute names across the MRO of ``cls``.

    Every class in ``cls.__mro__`` whose ``__init__`` is not ``object.__init__`` contributes:

    1. the values of the attributes named in ``TEMPLATE_TYPES`` that are declared directly in
       that class' own ``__dict__``;
    2. the attribute names found by running ``InstanceFieldExtractor`` over the class' source.

    :param cls: the class to inspect
    :return: a ``(all_template_fields, class_instance_fields)`` tuple of lists
    """
    import textwrap

    all_template_fields = []
    class_instance_fields = []

    for current_class in cls.__mro__:
        # Skip classes that simply use ``object.__init__`` (e.g. ``object`` itself).
        if current_class.__init__ is object.__init__:
            continue

        # Look only at this class' own namespace; inherited declarations are picked up
        # when the declaring class is reached later in the MRO.
        class_namespace = current_class.__dict__
        for template_type in TEMPLATE_TYPES:
            fields = class_namespace.get(template_type)
            if fields:
                all_template_fields.extend(fields)

        # ``inspect.getsource`` keeps the original indentation, which makes ``ast.parse``
        # fail with IndentationError for classes not defined at column 0 - dedent first.
        source = textwrap.dedent(inspect.getsource(current_class))
        visitor = InstanceFieldExtractor()
        visitor.visit(ast.parse(source))
        class_instance_fields.extend(visitor.instance_fields)
    return all_template_fields, class_instance_fields
| |
| |
def load_yaml_data() -> dict:
    """
    Load every provider YAML file yielded by ``provider_files_pattern``.

    :return: mapping of each file's POSIX-style path relative to ``ROOT_DIR`` to its parsed content
    """
    # NOTE: ``provider_files_pattern`` is a one-shot generator, so it is exhausted by this call;
    # calling this function a second time would find no files.
    package_paths = sorted(str(path) for path in provider_files_pattern)
    result = {}
    for provider_yaml_path in package_paths:
        # Decode explicitly as UTF-8 rather than relying on the platform's default encoding.
        with open(provider_yaml_path, encoding="utf-8") as yaml_file:
            provider = yaml.load(yaml_file, SafeLoader)
        rel_path = pathlib.Path(provider_yaml_path).relative_to(ROOT_DIR).as_posix()
        result[rel_path] = provider
    return result
| |
| |
def get_providers_modules() -> list[str]:
    """
    Collect the ``python-modules`` values listed under each ``OPERATORS`` key of every provider YAML.

    :return: flat list of the collected values, in provider order and, within a provider,
        in ``OPERATORS`` order
    """
    providers = load_yaml_data()
    collected = []

    for provider_data in providers.values():
        for resource_type in OPERATORS:
            entries = provider_data.get(resource_type)
            # A missing or empty section contributes nothing.
            if not entries:
                continue
            for entry in entries:
                collected.extend(entry.get("python-modules"))

    return collected
| |
| |
def is_class_eligible(name: str) -> bool:
    """Return True when ``name``, compared case-insensitively, ends with one of ``CLASS_IDENTIFIERS``."""
    lowered = name.lower()
    return any(lowered.endswith(suffix) for suffix in CLASS_IDENTIFIERS)
| |
| |
def get_eligible_classes(all_classes):
    """
    Keep only the ``(name, class)`` pairs whose name is accepted by ``is_class_eligible``.

    :param all_classes: iterable of ``(name, class)`` pairs
    :return: list with the accepted pairs, in their original order
    """
    selected = []
    for class_name, class_obj in all_classes:
        if is_class_eligible(class_name):
            selected.append((class_name, class_obj))
    return selected
| |
| |
def iter_check_template_fields(module: str):
    """
    Validate the template fields of the eligible classes defined in ``module``.

    The module is imported (warnings emitted meanwhile are captured and discarded), its classes
    are filtered with ``get_eligible_classes``, and classes merely imported into the module from
    elsewhere are ignored. For every remaining class, each templated field that is not among the
    collected instance fields is reported by appending ``"<module>: <class>: <field>"`` to the
    module-level ``errors`` list.

    :param module: dotted name of the module to import and check
    """
    with warnings.catch_warnings(record=True):
        loaded = importlib.import_module(module)
        members = inspect.getmembers(loaded, inspect.isclass)

    for class_name, klass in get_eligible_classes(members):
        # Only check classes that this module itself defines.
        if klass.__module__ != module:
            continue

        declared, assigned = get_template_fields_and_class_instance_fields(klass)
        errors.extend(f"{module}: {class_name}: {field}" for field in declared if field not in assigned)
| |
| |
def get_module_name_from_path(pyfile: str) -> str:
    """
    Convert a ``/``-separated Python file path such as ``a/b/c.py`` to the dotted name ``a.b.c``.

    :param pyfile: path of a Python file, as passed on the command line
    :return: the path with a trailing ``.py`` removed and every ``/`` replaced by ``.``
    """
    # ``str.rstrip(".py")`` strips any trailing run of the characters ".", "p" and "y"
    # (``copy.py`` -> ``co``), so remove the exact suffix instead.
    if pyfile.endswith(".py"):
        pyfile = pyfile[: -len(".py")]
    return pyfile.replace("/", ".")


if __name__ == "__main__":
    provider_modules = get_providers_modules()

    if len(sys.argv) > 1:
        # File paths were given: validate only those that map to a known provider module.
        known_modules = set(provider_modules)
        modules_to_validate = [
            module_name
            for pyfile in sorted(sys.argv[1:])
            if (module_name := get_module_name_from_path(pyfile)) in known_modules
        ]
    else:
        modules_to_validate = provider_modules

    for module in modules_to_validate:
        iter_check_template_fields(module)

    if errors:
        console.print("[red]Found Invalid template fields:")
        for error in errors:
            console.print(f"[red]Error:[/] {error}")

        # Exit statuses outside the small range the OS supports may wrap around (256 errors
        # could be reported as success), so cap the value while keeping it non-zero.
        sys.exit(min(len(errors), 255))