blob: 90bcbda5678a42f19db5c74c7e47cbdbaeaf1e5c [file] [log] [blame]
#!/usr/bin/env python
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# /// script
# requires-python = ">=3.10,<3.11"
# dependencies = [
# "packaging>=25",
# "pyyaml>=6.0.2",
# "rich>=13.6.0",
# ]
# ///
from __future__ import annotations
import ast
import pathlib
import sys
from collections.abc import Iterable
import yaml
from packaging.specifiers import SpecifierSet
from rich.console import Console
sys.path.insert(0, str(pathlib.Path(__file__).parent.resolve()))
from common_prek_utils import get_provider_base_dir_from_path
console = Console(color_system="standard", width=200)
COMMON_SQL_PROVIDER_NAME: str = "apache-airflow-providers-common-sql"
COMMON_SQL_PROVIDER_MIN_COMPATIBLE_VERSIONS: str = "1.9.1"
COMMON_SQL_PROVIDER_LATEST_INCOMPATIBLE_VERSION: str = "1.9.0"
MAKE_COMMON_METHOD_NAME: str = "_make_common_data_structure"
def get_classes(file_path: pathlib.Path) -> Iterable[ast.ClassDef]:
"""Return a list of class declared in the given python file."""
module = ast.parse(file_path.read_text("utf-8"), filename=file_path.as_posix())
for node in ast.walk(module):
if isinstance(node, ast.ClassDef):
yield node
def is_subclass_of_dbapihook(node: ast.ClassDef) -> bool:
"""Return the subclass's name of a given class definition."""
for base in node.bases:
if isinstance(base, ast.Name) and base.id == "DbApiHook":
return True
return False
def has_make_common_data_structure_method(node: ast.ClassDef) -> bool:
"""Return True if the given class implements `_make_common_data_structure` method."""
for body_element in node.body:
if isinstance(body_element, ast.FunctionDef) and (body_element.name == MAKE_COMMON_METHOD_NAME):
return True
return False
def get_yaml_content(file_path: pathlib.Path) -> dict:
"""Load content of a yaml files."""
with open(file_path) as file:
return yaml.safe_load(file)
def get_common_sql_constraints(provider_metadata: dict) -> str | None:
"""Return the version constraints of `apache-airflow-providers-common-sql`."""
dependencies: list[str] = provider_metadata.get("dependencies", [])
for dependency in dependencies:
if dependency.startswith(COMMON_SQL_PROVIDER_NAME):
return dependency[len(COMMON_SQL_PROVIDER_NAME) :]
return None
def do_version_satisfies_constraints(
version: str,
max_incompatible_version=COMMON_SQL_PROVIDER_LATEST_INCOMPATIBLE_VERSION,
) -> bool:
"""Check if the `version_string` is constrained to at least >= 1.8.1."""
constraints: list[str] = [constraint.strip() for constraint in version.split(",")]
specifier_set = SpecifierSet(",".join(constraints))
return not specifier_set.contains(max_incompatible_version)
def check_sql_providers_dependency():
error_count: int = 0
for file_passed in sys.argv[1:]:
path = pathlib.Path(file_passed)
if not file_passed.startswith("providers/"):
continue
for clazz in get_classes(path):
if is_subclass_of_dbapihook(node=clazz) and has_make_common_data_structure_method(node=clazz):
provider_yaml_path: pathlib.Path = (
get_provider_base_dir_from_path(file_path=path) / "provider.yaml"
)
provider_metadata: dict = get_yaml_content(file_path=provider_yaml_path)
if version_constraint := get_common_sql_constraints(provider_metadata=provider_metadata):
if not do_version_satisfies_constraints(version=version_constraint):
error_count += 1
console.print(
f"\n[yellow]Provider {provider_metadata['name']} must have "
f"'{COMMON_SQL_PROVIDER_NAME}>={COMMON_SQL_PROVIDER_MIN_COMPATIBLE_VERSIONS}' as "
f"dependency, because `{clazz.name}` overrides the "
f"`{MAKE_COMMON_METHOD_NAME}` method."
)
if error_count:
console.print(
f"The `{MAKE_COMMON_METHOD_NAME}` method was introduced in {COMMON_SQL_PROVIDER_NAME} "
f"{COMMON_SQL_PROVIDER_MIN_COMPATIBLE_VERSIONS}. You cannot rely on an older version of this "
"provider to override this method."
)
sys.exit(1)
if __name__ == "__main__":
sys.exit(check_sql_providers_dependency())