blob: 8a9252d76441f4086c1678948b7f6b42331ae16c [file] [log] [blame]
#!/usr/bin/env python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# /// script
# requires-python = ">=3.10,<3.11"
# dependencies = [
# "libcst>=1.8.1",
# ]
# ///
from __future__ import annotations
import ast
import glob
import itertools
import os
import sys
from collections.abc import Iterator
import libcst as cst
from libcst.codemod import CodemodContext
from libcst.codemod.visitors import AddImportsVisitor
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, os.pardir))
DEFERRABLE_DOC = (
"https://github.com/apache/airflow/blob/main/airflow-core/docs/"
"authoring-and-scheduling/deferring.rst#writing-deferrable-operators"
)
class DefaultDeferrableVisitor(ast.NodeVisitor):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, *kwargs)
self.error_linenos: list[int] = []
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
if node.name == "__init__":
args = node.args
arguments = reversed([*args.args, *args.posonlyargs, *args.kwonlyargs])
defaults = reversed([*args.defaults, *args.kw_defaults])
for argument, default in itertools.zip_longest(arguments, defaults):
# argument is not deferrable
if argument is None or argument.arg != "deferrable":
continue
# argument is deferrable, but comes with no default value
if default is None:
self.error_linenos.append(argument.lineno)
continue
# argument is deferrable, but the default value is not valid
if not _is_valid_deferrable_default(default):
self.error_linenos.append(default.lineno)
return node
class DefaultDeferrableTransformer(cst.CSTTransformer):
def leave_Param(self, original_node: cst.Param, updated_node: cst.Param) -> cst.Param:
if original_node.name.value == "deferrable":
expected_default_cst = cst.parse_expression(
'conf.getboolean("operators", "default_deferrable", fallback=False)'
)
if updated_node.default and updated_node.default.deep_equals(expected_default_cst):
return updated_node
return updated_node.with_changes(default=expected_default_cst)
return updated_node
def _is_valid_deferrable_default(default: ast.AST) -> bool:
"""Check whether default is 'conf.getboolean("operators", "default_deferrable", fallback=False)'"""
return ast.unparse(default) == "conf.getboolean('operators', 'default_deferrable', fallback=False)"
def iter_check_deferrable_default_errors(module_filename: str) -> Iterator[str]:
ast_tree = ast.parse(open(module_filename).read())
visitor = DefaultDeferrableVisitor()
visitor.visit(ast_tree)
# We check the module using the ast once and then fix it through cst if needed.
# The primary reason we don't do it all through cst is performance.
if visitor.error_linenos:
_fix_invalid_deferrable_default_value(module_filename)
yield from (f"{module_filename}:{lineno}" for lineno in visitor.error_linenos)
def _fix_invalid_deferrable_default_value(module_filename: str) -> None:
context = CodemodContext(filename=module_filename)
AddImportsVisitor.add_needed_import(context, "airflow.configuration", "conf")
transformer = DefaultDeferrableTransformer()
source_cst_tree = cst.parse_module(open(module_filename).read())
modified_cst_tree = AddImportsVisitor(context).transform_module(source_cst_tree.visit(transformer))
if not source_cst_tree.deep_equals(modified_cst_tree):
with open(module_filename, "w") as writer:
writer.write(modified_cst_tree.code)
def main() -> int:
modules = itertools.chain(
glob.glob(f"{ROOT_DIR}/**/sensors/**.py", recursive=True),
glob.glob(f"{ROOT_DIR}/**/operators/**.py", recursive=True),
)
errors = [error for module in modules for error in iter_check_deferrable_default_errors(module)]
if errors:
print("Incorrect deferrable default values detected at:")
for error in errors:
print(f" {error}")
print(
"""Please set the default value of deferrable to """
""""conf.getboolean("operators", "default_deferrable", fallback=False)"\n"""
f"See: {DEFERRABLE_DOC}\n"
)
return len(errors)
if __name__ == "__main__":
sys.exit(main())