| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| import json |
| import os |
| import sys |
| from ast import Import, ImportFrom, NodeVisitor, parse |
| from collections import defaultdict |
| from os.path import dirname, sep |
| from typing import Dict, List, Optional, Tuple |
| |
# Put this script's parent directory (presumably where setup.py lives) on the
# import path BEFORE importing from ``setup``. Appending it after the import
# has no effect on that import.
sys.path.append(os.path.join(dirname(__file__), os.pardir))

from setup import PROVIDERS_REQUIREMENTS  # noqa: E402  # isort: skip
| |
| |
# Path fragments (built with the platform path separator) that mark a file as
# belonging to the provider sources or to the provider tests.
AIRFLOW_PROVIDERS_FILE_PREFIX = sep.join(("airflow", "providers", ""))
AIRFLOW_TESTS_PROVIDERS_FILE_PREFIX = sep.join(("tests", "providers", ""))
# Dotted module prefix that marks an import as a provider import.
AIRFLOW_PROVIDERS_IMPORT_PREFIX = "airflow.providers."

# Messages collected while checking files; they are reported at the end of the run.
infos: List[str] = list()  # informational messages
warnings: List[str] = list()  # warnings (do not fail the check)
errors: List[str] = list()  # errors (make the script exit with status 1)

# Provider name -> names of other providers it imports (may contain duplicates;
# they are de-duplicated and sorted when reported).
dependencies: Dict[str, List[str]] = defaultdict(list)
| |
| |
def find_provider(provider_elements: List[str]) -> Optional[str]:
    """
    Return the shortest dotted prefix of ``provider_elements`` that names a known provider.

    Candidates are formed by joining the first one, two, three, ... elements
    with ``.``. Each candidate is looked up among the keys of
    PROVIDERS_REQUIREMENTS (imported from setup.py).

    :param provider_elements: path or import components that follow the providers package
    :return: the first matching provider name, or None if no prefix matches
    """
    known_providers = PROVIDERS_REQUIREMENTS.keys()
    for length in range(1, len(provider_elements) + 1):
        candidate = ".".join(provider_elements[:length])
        if candidate in known_providers:
            return candidate
    return None
| |
| |
def get_provider_from_file_name(file_name: str) -> Optional[str]:
    """
    Determine which provider a source or test file belongs to.

    Side effects:
    - A file outside both providers trees is recorded in ``errors``.
    - A file inside a tree that matches no provider is recorded in ``infos``
      if it is an ``__init__.py``, and in ``warnings`` otherwise.

    :param file_name: path of the file being checked
    :return: provider name, or None if the file is outside the providers trees
        or no provider could be matched
    """
    in_providers_tree = (
        AIRFLOW_PROVIDERS_FILE_PREFIX in file_name
        or AIRFLOW_TESTS_PROVIDERS_FILE_PREFIX in file_name
    )
    if not in_providers_tree:
        # Only files from the providers packages are expected here.
        errors.append(f"Wrong file not in the providers package = {file_name}")
        return None
    # Drop the leading "<airflow|tests>" and "providers" path components.
    path_elements = get_file_suffix(file_name).split(sep)[2:]
    provider = find_provider(path_elements)
    if provider:
        return provider
    if file_name.endswith("__init__.py"):
        infos.append(f"Skipped file = {file_name}")
    else:
        warnings.append(f"Provider not found for path = {file_name}")
    return provider
| |
| |
def get_file_suffix(file_name: str) -> Optional[str]:
    """
    Return the tail of ``file_name`` starting at the providers package prefix.

    The ``airflow/providers/`` prefix is tried before ``tests/providers/``.
    For each prefix the first occurrence in ``file_name`` is used.

    :param file_name: path of a file
    :return: the part of ``file_name`` beginning with the first matching prefix,
        or None if neither prefix occurs in it
    """
    for prefix in (AIRFLOW_PROVIDERS_FILE_PREFIX, AIRFLOW_TESTS_PROVIDERS_FILE_PREFIX):
        # A single find() both tests for the prefix and locates it.
        position = file_name.find(prefix)
        if position != -1:
            return file_name[position:]
    return None
| |
| |
def get_provider_from_import(import_name: str) -> Optional[str]:
    """
    Retrieves provider name from an import name.

    Imports that do not contain ``airflow.providers.`` are ignored silently.
    Provider imports that match no known provider are recorded in ``warnings``.

    :param import_name: fully qualified name of the import
    :return: provider name or None if no provider could be found
    """
    position = import_name.find(AIRFLOW_PROVIDERS_IMPORT_PREFIX)
    if position == -1:
        # skip silently - we expect non-providers imports
        return None
    # Drop the leading "airflow" and "providers" components.
    import_elements = import_name[position:].split(".")[2:]
    provider = find_provider(import_elements)
    if not provider:
        warnings.append(f"Provider not found for import = {import_name}")
    return provider
| |
| |
class ImportFinder(NodeVisitor):
    """
    AST visitor that collects the fully qualified names of everything a module imports.

    ``import a.b`` contributes ``a.b`` and ``from a.b import c`` contributes ``a.b.c``.
    ``from __future__ import ...`` statements are ignored. Imports nested inside
    functions, classes or conditionals are collected as well, because the default
    ``generic_visit`` descends into them.
    """

    def __init__(self, filename: str) -> None:
        # Collected import names, in the order the visitor encounters them.
        self.imports: List[str] = []
        # Name of the file being visited; not used by the visitor itself.
        self.filename = filename
        # Neither list is written anywhere in this class. Both are initialised
        # as empty lists; handled_import_exception used to hold the
        # ``typing.List[str]`` object itself.
        self.handled_import_exception: List[str] = []
        self.tried_imports: List[str] = []

    def process_import(self, import_name: str) -> None:
        """Record one fully qualified import name."""
        self.imports.append(import_name)

    def get_import_name_from_import_from(self, node: ImportFrom) -> List[str]:
        """
        Retrieves the fully qualified import names from a "from" import.

        NOTE: the relative-import level (leading dots) is not taken into account.
        ``from . import x`` yields ``x`` and ``from .m import x`` yields ``m.x``.

        :param node: the ``ImportFrom`` AST node
        :return: one ``module.name`` (or bare ``name``) entry per imported alias
        """
        import_names: List[str] = []
        for alias in node.names:
            name = alias.name
            fullname = f'{node.module}.{name}' if node.module else name
            import_names.append(fullname)
        return import_names

    def visit_Import(self, node: Import) -> None:
        for alias in node.names:
            self.process_import(alias.name)

    def visit_ImportFrom(self, node: ImportFrom) -> None:
        # __future__ imports are compiler directives, not dependencies.
        if node.module == '__future__':
            return
        for fullname in self.get_import_name_from_import_from(node):
            self.process_import(fullname)
| |
| |
def get_imports_from_file(file_name: str) -> List[str]:
    """
    Parse a Python source file and list everything it imports.

    If reading or parsing fails, a message naming the file is printed to
    stderr and the original exception is re-raised.

    :param file_name: name of the file
    :return: fully qualified import names, in the order ImportFinder collects them
    """
    try:
        with open(file_name, encoding="utf-8") as source_file:
            source_code = source_file.read()
        tree = parse(source_code, file_name)
    except Exception:
        print(f"Error when opening file {file_name}", file=sys.stderr)
        raise
    finder = ImportFinder(file_name)
    finder.visit(tree)
    return finder.imports
| |
| |
def check_if_different_provider_used(file_name: str) -> None:
    """
    Record in ``dependencies`` every other provider that ``file_name`` imports.

    Files whose own provider cannot be determined are skipped. Imports of the
    file's own provider, and imports that belong to no provider, are ignored.

    :param file_name: path of the file to check
    """
    own_provider = get_provider_from_file_name(file_name)
    if own_provider:
        for imported_name in get_imports_from_file(file_name):
            imported_provider = get_provider_from_import(imported_name)
            if not imported_provider or imported_provider == own_provider:
                continue
            dependencies[own_provider].append(imported_provider)
| |
| |
def parse_arguments() -> Tuple[List[str], Optional[str], Optional[str]]:
    """
    Parse the command line.

    Prints usage and exits with status 2 when no files are given.

    :return: tuple of (files to check, provider dependencies JSON file name or None,
        documentation .rst file name or None)
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='Checks if dependencies between packages are handled correctly.'
    )
    parser.add_argument(
        "-f", "--provider-dependencies-file", help="Stores dependencies between providers in the file(.json)"
    )
    parser.add_argument(
        "-d", "--documentation-file", help="Updates package documentation in the file specified (.rst)"
    )
    parser.add_argument('files', nargs='*')
    args = parser.parse_args()

    if len(args.files) < 1:
        parser.print_usage()
        print()
        sys.exit(2)
    return args.files, args.provider_dependencies_file, args.documentation_file
| |
| |
# Indentation printed in front of every reported info/warning/error message.
PREFIX = " "

# Border line of the RST simple table. The first column is 26 "=" wide and is
# followed by one space, so the second column starts at character 27.
_TABLE_RULE = "=" * 26 + " " + "=" * 27

# Table framing written around the generated package dependency rows.
# "Package" is padded to 27 characters so that "Extras" lands in the second
# column, matching the ``{package:27}`` padding applied to the data rows.
HEADER = f"\n{_TABLE_RULE}\n{'Package':27}Extras\n{_TABLE_RULE}\n"
FOOTER = f"{_TABLE_RULE}\n\n"
| |
| |
def insert_documentation(deps_dict: Dict[str, List[str]], res: List[str]) -> None:
    """
    Append an RST table of provider dependencies to ``res``, modifying it in place.

    Each row is the package name padded to 27 characters, followed by its
    dependencies joined with commas.

    :param deps_dict: mapping of package name to the packages it depends on
    :param res: list of text lines (each ending with a newline) to extend
    """
    res.extend(HEADER.splitlines(keepends=True))
    res.extend(
        f"{package:27}{','.join(package_deps)}\n"
        for package, package_deps in deps_dict.items()
    )
    res.extend(FOOTER.splitlines(keepends=True))
| |
| |
# Marker lines delimiting the generated table in the documentation file.
_DOC_START_MARKER = " .. START PACKAGE DEPENDENCIES HERE"
_DOC_END_MARKER = " .. END PACKAGE DEPENDENCIES HERE"


def _print_messages(title: str, messages: List[str], noun: str) -> None:
    """Print ``messages`` under ``title`` with a total count; print nothing if there are none."""
    if not messages:
        return
    print(f"\n{title}\n")
    for message in messages:
        print(PREFIX + message)
    print(f"Total: {len(messages)} {noun}.")


def _write_dependencies_file(file_name: str, deps: Dict[str, List[str]]) -> None:
    """Write ``deps`` to ``file_name`` as JSON indented by 2 spaces, with a trailing newline."""
    with open(file_name, "w", encoding="utf-8") as providers_file:
        json.dump(deps, providers_file, indent=2)
        providers_file.write("\n")
    print()
    print(f"Written provider dependencies to the file {file_name}")
    print()


def _update_documentation(file_name: str, deps: Dict[str, List[str]]) -> None:
    """
    Regenerate the dependency table between the START/END marker lines of ``file_name``.

    The marker lines are kept and everything between them is replaced by a
    freshly generated table. If a START marker has no matching END marker, the
    file is left untouched and an error is recorded in ``errors``. Rewriting it
    in that case would silently drop every line after the START marker.
    """
    with open(file_name, encoding="utf-8") as documentation_file:
        text = documentation_file.readlines()
    replacing = False
    result: List[str] = []
    for line in text:
        if line.startswith(_DOC_START_MARKER):
            replacing = True
            result.append(line)
            insert_documentation(deps, result)
        if line.startswith(_DOC_END_MARKER):
            replacing = False
        if not replacing:
            result.append(line)
    if replacing:
        message = f"Missing '{_DOC_END_MARKER.strip()}' marker in {file_name}; file not updated"
        errors.append(message)
        print(PREFIX + message)
        return
    with open(file_name, "w", encoding="utf-8") as documentation_file:
        documentation_file.write("".join(result))
    print()
    print(f"Written package extras to the file {file_name}")
    print()


def main() -> None:
    """
    Check the given files, report collected messages, and optionally write outputs.

    Outputs are the provider dependencies JSON file and/or the documentation
    table. Exits with status 1 if any errors were recorded.
    """
    print()
    files, provider_dependencies_file_name, documentation_file_name = parse_arguments()
    for file_name in files:
        check_if_different_provider_used(file_name)
    print(f"Verified {len(files)} files.")
    _print_messages("Information messages:", infos, "information messages")
    _print_messages("Warnings!", warnings, "warnings")
    _print_messages("Errors!", errors, "errors")
    # Deterministic output: providers and their dependencies sorted, duplicates removed.
    unique_sorted_dependencies: Dict[str, List[str]] = {
        key: sorted(set(dependencies[key])) for key in sorted(dependencies)
    }
    if provider_dependencies_file_name:
        _write_dependencies_file(provider_dependencies_file_name, unique_sorted_dependencies)
    if documentation_file_name:
        _update_documentation(documentation_file_name, unique_sorted_dependencies)
    if errors:
        print()
        print("ERROR! Errors found during verification. Exiting!")
        print()
        sys.exit(1)
    print()
    print("Verification complete! Success!")
    print()


if __name__ == '__main__':
    main()