blob: ad54dfebdb33e14e0e66e74089a6ee727f847109 [file] [log] [blame]
#!/usr/bin/env python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# /// script
# requires-python = ">=3.10,<3.11"
# dependencies = [
# "rich>=13.6.0",
# "python-dateutil>=2.9.0",
# ]
# ///
from __future__ import annotations
import ast
import sys
from datetime import date
from functools import cache
from pathlib import Path
from typing import Any
from dateutil.relativedelta import relativedelta
from rich.console import Console
# Shared rich console used for the multi-line "planned_removal_date" report;
# configured with the "standard" color system and a fixed width of 200 columns.
console = Console(color_system="standard", width=200)
# Fully qualified warning classes accepted as the ``category=`` of a
# deprecation decorator, keyed by the file group the checked file belongs to.
allowed_warnings: dict[str, tuple[str, ...]] = dict(
    providers=("airflow.exceptions.AirflowProviderDeprecationWarning",),
)

# Import paths (module components followed by the decorator name) of the
# decorators this check understands.
compatible_decorators: frozenset[tuple[str, ...]] = frozenset(
    (
        # PEP 702 decorators
        ("warnings", "deprecated"),
        ("typing_extensions", "deprecated"),
        # `Deprecated` package decorators
        ("deprecated", "deprecated"),
        ("deprecated", "classic", "deprecated"),
        # Experimental Google Provider's decorator with structured deprecation details
        ("airflow", "providers", "google", "common", "deprecated", "deprecated"),
    )
)

# Path prefixes whose deprecations must also declare a planned removal date,
# mapped to the warning class names the requirement applies to.
end_of_life_deprecation_warnings: dict[str, tuple[str, ...]] = {
    "airflow/providers/google/": ("AirflowProviderDeprecationWarning",)
}
def is_file_under_eol_deprecation(file_path: str, warning_class: str) -> bool:
    """Tell whether *warning_class* needs a planned removal date in *file_path*.

    Only the first entry of ``end_of_life_deprecation_warnings`` whose prefix
    matches the start of *file_path* is consulted.

    :param file_path: path of the file being checked
    :param warning_class: name of the warning class given as ``category``
    :return: ``True`` if the first matching prefix lists *warning_class*,
        ``False`` if it does not or if no prefix matches
    """
    first_match = next(
        (
            classes
            for prefix, classes in end_of_life_deprecation_warnings.items()
            if file_path.startswith(prefix)
        ),
        None,
    )
    return first_match is not None and warning_class in first_match
def validate_end_of_life_deprecation_warnings(file_path: str, decorator: Any, warning_class: str) -> int:
    """Require ``planned_removal_date`` on decorators subject to end-of-life rules.

    Nothing is checked unless ``is_file_under_eol_deprecation`` accepts the
    file/warning pair. When the argument is missing, a message with a suggested
    date is printed through the shared ``console``.

    :param file_path: path of the file being checked (used for the rule lookup
        and in the message)
    :param decorator: decorator AST node; its ``lineno`` is reported
    :param warning_class: name of the warning class given as ``category``
    :return: ``1`` if the argument is missing, otherwise ``0``
    """
    if not is_file_under_eol_deprecation(file_path=file_path, warning_class=warning_class):
        return 0
    if get_decorator_argument(decorator, "planned_removal_date"):
        return 0

    # Suggest a date six months ahead; unless that already falls on the first
    # of a month, use the first day of the month seven months ahead instead.
    suggested = date.today() + relativedelta(months=6)
    if suggested.day > 1:
        suggested = (date.today() + relativedelta(months=7)).replace(day=1)
    suggested_text: str = suggested.strftime("%B %d, %Y")

    console.print(
        f"{file_path}:{decorator.lineno}: "
        "The 'planned_removal_date' parameter is missing in the `@deprecated(...)` call.\n"
        "Please provide the date in the format 'Month DD, YYYY' "
        "after which the deprecated object should be removed.\n"
        "The recommended date is at least six months ahead: "
        f"'planned_removal_date=\"{suggested_text}\"'"
    )
    return 1
def get_decorator_argument(decorator: ast.Call, argument_name: str) -> ast.keyword | None:
    """Return the keyword argument called *argument_name* of a decorator call.

    :param decorator: decorator AST node. It is normally an ``ast.Call``, but a
        bare decorator (``@deprecated`` without parentheses) is an ``ast.Name``
        or ``ast.Attribute`` node and is treated as having no arguments.
    :param argument_name: name of the keyword argument to look up
    :return: the first matching ``ast.keyword`` node, or ``None`` if there is none
    """
    # Only ``ast.Call`` nodes have ``keywords``; default to an empty sequence so
    # that bare decorators yield ``None`` instead of raising AttributeError.
    for keyword in getattr(decorator, "keywords", ()):
        # ``**kwargs`` entries have ``arg`` set to None and therefore never match.
        if keyword.arg == argument_name:
            return keyword
    return None
@cache
def allowed_group_warnings(group: str) -> tuple[str, tuple[str, ...]]:
    """Describe the warning classes allowed for *group*.

    Results are memoized per group via ``functools.cache``.

    :param group: key into ``allowed_warnings``
    :return: a pair of (human-readable expectation, tuple of allowed classes)
    :raises KeyError: if *group* is not a key of ``allowed_warnings``
    """
    warning_types = allowed_warnings[group]
    if len(warning_types) == 1:
        description = f"expected {warning_types[0]} type"
    else:
        description = f"expected one of {', '.join(warning_types)} types"
    return description, warning_types
def built_import_from(import_from: ast.ImportFrom) -> list[str]:
    """List the names under which a ``from ... import ...`` exposes known decorators.

    The statement's module is compared against the leading components of each
    entry in ``compatible_decorators``; an imported name matches when it equals
    the component that follows the module part.

    :param import_from: the ``from ... import ...`` AST node
    :return: dotted names (honouring ``as`` aliases) by which the matched
        decorators can be referenced; empty if the node has no module name
    """
    found: list[str] = []
    module_name = import_from.module
    if not module_name:
        return found

    depth = module_name.count(".") + 1
    for candidate in compatible_decorators:
        # The candidate needs at least one component after the module part.
        if len(candidate) <= depth:
            continue
        if ".".join(candidate[:depth]) != module_name:
            continue
        wanted = candidate[depth]
        # Components still needed after the imported name to reach the decorator.
        tail = candidate[depth + 1 :]
        for imported in import_from.names:
            if imported.name != wanted:
                continue
            local_name: str = imported.asname or imported.name
            found.append(".".join([local_name, *tail]))
    return found
def built_import(import_clause: ast.Import) -> list[str]:
    """List the names under which an ``import ...`` statement exposes known decorators.

    Each imported module is compared against the leading components of every
    entry in ``compatible_decorators``.

    :param import_clause: the ``import ...`` AST node
    :return: dotted names by which the matched decorators can be referenced,
        starting with the ``as`` alias if present, otherwise the module name
    """
    found = []
    for imported in import_clause.names:
        module_name = imported.name
        depth = module_name.count(".") + 1
        for candidate in compatible_decorators:
            if depth > len(candidate):
                continue
            if ".".join(candidate[:depth]) != module_name:
                continue
            local_name: str = imported.asname or module_name
            # Components still needed after the module to reach the decorator.
            tail = candidate[depth:]
            found.append(".".join([local_name, *tail]))
    return found
def found_compatible_decorators(mod: ast.Module) -> tuple[str, ...]:
    """Collect the names by which known decorators are imported in *mod*.

    Only statements directly in the module body are inspected.

    :param mod: parsed module
    :return: sorted, de-duplicated tuple of dotted decorator names
    """
    names: set[str] = set()
    for statement in mod.body:
        if isinstance(statement, ast.ImportFrom):
            names.update(built_import_from(statement))
        elif isinstance(statement, ast.Import):
            names.update(built_import(statement))
    return tuple(sorted(names))
def resolve_name(obj: ast.Attribute | ast.Name) -> str:
    """Render a name or attribute chain as a dotted string, e.g. ``a.b.c``.

    :param obj: an ``ast.Name`` or an ``ast.Attribute`` chain ending in an ``ast.Name``
    :return: the dotted name
    :raises SystemExit: if the chain contains any other kind of node
    """
    # Attribute chains nest from the right, so parts are collected last-to-first.
    parts: list[str] = []
    node = obj
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value  # type: ignore[assignment]
    if not isinstance(node, ast.Name):
        msg = f"Expected to got ast.Name or ast.Attribute but got {type(node).__name__!r}."
        raise SystemExit(msg)
    parts.append(node.id)
    return ".".join(reversed(parts))
def resolve_decorator_name(obj: ast.Call | ast.Attribute | ast.Name) -> str:
    """Return the dotted name of a decorator, whether or not it is called.

    :param obj: decorator node; for an ``ast.Call`` the called expression is resolved
    :return: the dotted name produced by ``resolve_name``
    """
    if isinstance(obj, ast.Call):
        target = obj.func
    else:
        target = obj
    return resolve_name(target)  # type: ignore[arg-type]
def check_decorators(mod: ast.Module, file: str, file_group: str) -> int:
    """Check the ``category`` argument of every known deprecation decorator in *mod*.

    A decorator is checked only if its dotted name is one of the names returned
    by ``found_compatible_decorators`` for this module. An error is printed and
    counted when:

    * the ``category`` keyword is missing,
    * ``category`` is a name that is not one of the warning classes allowed for
      *file_group* (the name must equal an allowed fully qualified name or be a
      dot-separated suffix of it),
    * ``category`` is a literal constant.

    For a name category, ``validate_end_of_life_deprecation_warnings`` is also
    run and its error count added.

    :param mod: parsed module
    :param file: path of the module, used in messages and for end-of-life rules
    :param file_group: key into ``allowed_warnings``
    :return: number of errors found
    :raises KeyError: from ``allowed_group_warnings`` if a known decorator is
        found and *file_group* has no entry in ``allowed_warnings``
    """
    if not (decorators_names := found_compatible_decorators(mod)):
        # There are no expected decorators into module, exit early
        return 0

    errors = 0
    for node in ast.walk(mod):
        # Nodes without a ``decorator_list`` attribute contribute nothing.
        for decorator in getattr(node, "decorator_list", ()):
            decorator_name = resolve_decorator_name(decorator)
            if decorator_name not in decorators_names:
                continue

            expected_types, warns_types = allowed_group_warnings(file_group)
            category_keyword = get_decorator_argument(decorator, "category")
            if category_keyword is None:
                errors += 1
                print(
                    f"{file}:{decorator.lineno}: Missing `category` keyword on decorator @{decorator_name}, "
                    f"{expected_types}"
                )
                continue

            category_value_ast = category_keyword.value
            if isinstance(category_value_ast, (ast.Name, ast.Attribute)):
                category_value = resolve_name(category_value_ast)
                # Compare on dotted-name boundaries. A plain ``endswith`` would
                # accept e.g. ``DeprecationWarning`` because it is a textual
                # suffix of ``...AirflowProviderDeprecationWarning``.
                if not any(
                    cv == category_value or cv.endswith(f".{category_value}") for cv in warns_types
                ):
                    errors += 1
                    print(
                        f"{file}:{category_keyword.lineno}: category={category_value}, but {expected_types}"
                    )
                errors += validate_end_of_life_deprecation_warnings(
                    file_path=file, decorator=decorator, warning_class=category_value
                )
            elif isinstance(category_value_ast, ast.Constant):
                errors += 1
                print(
                    f"{file}:{category_keyword.lineno}: "
                    f"category=Literal[{category_value_ast.value!r}], but {expected_types}"
                )
    return errors
def check_file(file: str) -> int:
    """Parse *file* and count deprecation-decorator errors in it.

    Files whose POSIX path does not start with ``airflow`` are skipped without
    being read. Paths starting with ``airflow/providers`` are checked as the
    "providers" group, all others as the "airflow" group.

    :param file: path of the file to check
    :return: number of errors reported by ``check_decorators``; ``0`` if skipped
    """
    path = Path(file)
    posix_path = path.as_posix()
    if not posix_path.startswith("airflow"):
        # Not expected file, exit early
        return 0

    # NOTE(review): ``allowed_warnings`` only defines the "providers" group, so
    # an "airflow"-group file that uses a known decorator would raise KeyError
    # in ``allowed_group_warnings`` - confirm whether such files are ever passed.
    if posix_path.startswith("airflow/providers"):
        file_group = "providers"
    else:
        file_group = "airflow"

    tree = ast.parse(path.read_text("utf-8"), file)
    return check_decorators(tree, file, file_group)
def main(*args: str) -> int:
    """Check every file named in *args* and report the total number of errors.

    :param args: argv-style arguments; the first element is skipped and the
        rest are treated as file paths
    :return: ``0`` if no errors were found, otherwise ``1``
    """
    total = 0
    for path in args[1:]:
        total += check_file(path)
    if total:
        print(f"Found {total} error{'s' if total > 1 else ''}.")
        return 1
    return 0
# Script entry point: pass the full argv (program name included) to main() and
# exit with the status it returns.
if __name__ == "__main__":
    raise SystemExit(main(*sys.argv))