blob: aa01e56065e405dfc98ad1879996cf63eacf80bb [file] [log] [blame]
#!/usr/bin/env python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import ast
import sys
from functools import lru_cache
from pathlib import Path
# Fully qualified warning classes accepted as ``category=`` for each file group.
# The group key ("airflow" or "providers") is selected from the path of the checked file.
allowed_warnings: dict[str, tuple[str, ...]] = dict(
    airflow=(
        "airflow.exceptions.RemovedInAirflow3Warning",
        "airflow.utils.context.AirflowContextDeprecationWarning",
    ),
    providers=("airflow.exceptions.AirflowProviderDeprecationWarning",),
)
# Import paths of the deprecation decorators this check understands.
# Each entry is the dotted path split into components; the last component is the decorator itself.
compatible_decorators: frozenset[tuple[str, ...]] = frozenset(
    {
        # Decorators specified by PEP 702
        ("warnings", "deprecated"),
        ("typing_extensions", "deprecated"),
        # Decorators shipped by the `Deprecated` package
        ("deprecated", "deprecated"),
        ("deprecated", "classic", "deprecated"),
    }
)
@lru_cache(maxsize=None)
def allowed_group_warnings(group: str) -> tuple[str, tuple[str, ...]]:
    """
    Build the "expected ..." message fragment for a file group.

    :param group: key of ``allowed_warnings`` (a ``KeyError`` is raised for an unknown key).
    :return: a pair of the human-readable expectation text and the tuple of
        warning class paths allowed for the group.
    """
    categories = allowed_warnings[group]
    if len(categories) != 1:
        message = f"expected one of {', '.join(categories)} types"
    else:
        message = f"expected {categories[0]} type"
    return message, categories
def built_import_from(import_from: ast.ImportFrom) -> list[str]:
    """
    Collect the local names under which a ``from X import Y`` statement exposes known decorators.

    The statement is matched against every path in ``compatible_decorators``. The module part
    has to equal a leading portion of the path and the imported name has to be the next path
    component; any components left over are appended to the (possibly aliased) local name.
    For example ``from deprecated import classic as c`` yields ``"c.deprecated"``.

    :param import_from: the ``from ... import ...`` node to inspect.
    :return: dotted names that, when used as a decorator in the same module, refer to one
        of the known deprecation decorators. Empty when nothing matches.
    """
    result: list[str] = []
    # A relative import (``from .warnings import deprecated``) targets a sibling module of the
    # checked file, never the stdlib / third-party module of the same name, so it cannot match.
    if import_from.level:
        return result
    module_name = import_from.module
    if not module_name:
        return result

    imports_levels = module_name.count(".") + 1
    for import_path in compatible_decorators:
        # The module must be a strict prefix: at least one component is needed for the imported name.
        if imports_levels >= len(import_path):
            continue
        if module_name != ".".join(import_path[:imports_levels]):
            continue
        # Components that still follow the imported name (empty when the decorator itself is imported).
        remaining = import_path[imports_levels + 1 :]
        for name in import_from.names:
            if name.name != import_path[imports_levels]:
                continue
            alias = name.asname or name.name
            result.append(".".join([alias, *remaining]))
    return result
def built_import(import_clause: ast.Import) -> list[str]:
    """
    Collect the dotted names under which an ``import X`` statement exposes known decorators.

    Every imported module is compared with the leading components of each path in
    ``compatible_decorators``; on a match the remaining path components are appended to the
    alias (or to the module name when no alias is given). For example ``import warnings as w``
    yields ``"w.deprecated"``.

    :param import_clause: the ``import ...`` node to inspect.
    :return: dotted decorator names made available by the statement. Empty when nothing matches.
    """
    found: list[str] = []
    for imported in import_clause.names:
        dotted = imported.name
        depth = dotted.count(".") + 1
        local_name = imported.asname or dotted
        for path in compatible_decorators:
            if depth > len(path) or dotted != ".".join(path[:depth]):
                continue
            tail = path[depth:]
            found.append(".".join([local_name, *tail]) if tail else local_name)
    return found
def found_compatible_decorators(mod: ast.Module) -> tuple[str, ...]:
    """
    Return the known deprecation decorator names imported at the top level of a module.

    Only statements directly in ``mod.body`` are inspected; imports nested inside functions,
    classes or conditionals are not visited.

    :param mod: parsed module.
    :return: sorted, de-duplicated tuple of dotted decorator names.
    """
    names: set[str] = set()
    for stmt in mod.body:
        if isinstance(stmt, ast.ImportFrom):
            names.update(built_import_from(stmt))
        elif isinstance(stmt, ast.Import):
            names.update(built_import(stmt))
    return tuple(sorted(names))
def resolve_name(obj: ast.Attribute | ast.Name) -> str:
    """
    Turn a ``Name`` / ``Attribute`` chain into its dotted source form (``a.b.c``).

    :param obj: the expression node to resolve.
    :return: dotted name.
    :raises SystemExit: when the chain contains anything other than ``ast.Attribute``
        nodes terminated by an ``ast.Name``.
    """
    # Attributes are visited from the outermost (right-most) one inwards.
    parts: list[str] = []
    node: ast.expr = obj
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if not isinstance(node, ast.Name):
        msg = f"Expected to got ast.Name or ast.Attribute but got {type(node).__name__!r}."
        raise SystemExit(msg)
    parts.append(node.id)
    return ".".join(reversed(parts))
def resolve_decorator_name(obj: ast.Call | ast.Attribute | ast.Name) -> str:
    """
    Return the dotted name of a decorator expression.

    For a called decorator (``@deco(...)``) the name of the callee is resolved,
    otherwise the expression itself is resolved.
    """
    if isinstance(obj, ast.Call):
        return resolve_name(obj.func)  # type: ignore[arg-type]
    return resolve_name(obj)
def check_decorators(mod: ast.Module, file: str, file_group: str) -> int:
    """
    Verify the ``category`` argument of every known deprecation decorator used in a module.

    A decorator is checked only when its dotted name is one of those returned by
    ``found_compatible_decorators`` for the module. For each such decorator an error is
    printed (and counted) when:

    * no ``category=`` keyword is given, including a bare decorator without parentheses;
    * ``category`` is a name / attribute chain that is neither equal to an allowed warning
      path nor a dot-separated trailing part of one (``RemovedInAirflow3Warning`` or
      ``exceptions.RemovedInAirflow3Warning`` are accepted, ``DeprecationWarning`` is not);
    * ``category`` is a literal constant.

    Any other kind of ``category`` expression is not validated.

    :param mod: parsed module.
    :param file: file name used as the prefix of reported errors.
    :param file_group: key of ``allowed_warnings`` selecting the accepted warning types.
    :return: number of errors found.
    :raises KeyError: when ``file_group`` is unknown and the module imports a known decorator.
    """
    if not (decorators_names := found_compatible_decorators(mod)):
        # None of the known decorators is imported by the module, exit early
        return 0

    # Loop-invariant: both values depend on the file group only.
    expected_types, warns_types = allowed_group_warnings(file_group)
    errors = 0
    for node in ast.walk(mod):
        # Only function and class definitions carry a ``decorator_list``.
        for decorator in getattr(node, "decorator_list", ()):
            decorator_name = resolve_decorator_name(decorator)
            if decorator_name not in decorators_names:
                continue

            # A bare decorator (``@deprecated``) is a Name/Attribute node and has no keywords at all.
            keywords = decorator.keywords if isinstance(decorator, ast.Call) else []
            category_keyword: ast.keyword | None = next(
                (keyword for keyword in keywords if keyword.arg == "category"), None
            )
            if category_keyword is None:
                errors += 1
                print(
                    f"{file}:{decorator.lineno}: Missing `category` keyword on decorator @{decorator_name}, "
                    f"{expected_types}"
                )
                continue

            category_value_ast = category_keyword.value
            if isinstance(category_value_ast, (ast.Name, ast.Attribute)):
                category_value = resolve_name(category_value_ast)
                # Match whole dotted components only: a plain ``endswith`` would let built-in
                # ``DeprecationWarning`` pass as a suffix of ``...AirflowProviderDeprecationWarning``.
                if not any(
                    allowed == category_value or allowed.endswith(f".{category_value}")
                    for allowed in warns_types
                ):
                    errors += 1
                    print(
                        f"{file}:{category_keyword.lineno}: "
                        f"category={category_value}, but {expected_types}"
                    )
            elif isinstance(category_value_ast, ast.Constant):
                errors += 1
                print(
                    f"{file}:{category_keyword.lineno}: "
                    f"category=Literal[{category_value_ast.value!r}], but {expected_types}"
                )
    return errors
def check_file(file: str) -> int:
    """
    Run the decorator check on a single file.

    Files whose POSIX-style path does not start with ``airflow`` are skipped. Paths starting
    with ``airflow/providers`` are checked against the "providers" group, everything else
    against the "airflow" group.

    :param file: path of the file to check.
    :return: number of errors found in the file (``0`` for skipped files).
    """
    path = Path(file)
    posix_path = path.as_posix()
    if not posix_path.startswith("airflow"):
        # Not expected file, exit early
        return 0

    if posix_path.startswith("airflow/providers"):
        file_group = "providers"
    else:
        file_group = "airflow"

    tree = ast.parse(path.read_text("utf-8"), file)
    return check_decorators(tree, file, file_group)
def main(*args: str) -> int:
    """
    Check every file passed on the command line and report the total number of errors.

    :param args: full ``argv``; the first item (the program name) is ignored.
    :return: ``0`` when no errors were found, ``1`` otherwise.
    """
    errors = 0
    for file in args[1:]:
        errors += check_file(file)

    if errors:
        print(f"Found {errors} error{'s' if errors > 1 else ''}.")
        return 1
    return 0
if __name__ == "__main__":
    # Propagate the result of main() as the process exit status.
    raise SystemExit(main(*sys.argv))