blob: 0ac6ae5c15c676246abbd0ad0a583cff4a3117a4 [file] [log] [blame]
#!/usr/bin/env python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# /// script
# requires-python = ">=3.10,<3.11"
# dependencies = [
# "rich>=13.6.0",
# ]
# ///
from __future__ import annotations
import ast
import sys
from typing import Any
from rich.console import Console
# Shared rich console through which this script prints all of its output
# (fixed "standard" colour system, 200-column width).
console = Console(color_system="standard", width=200)
# Suffix that a base-class name must end with for a class to be treated as an operator.
BASE_OPERATOR_CLASS_NAME = "BaseOperator"
def _is_operator(class_node: ast.ClassDef) -> bool:
    """
    Tell whether a class definition looks like an operator.

    A class qualifies when at least one of its directly listed bases is a plain name
    (``ast.Name``) whose identifier ends with ``BASE_OPERATOR_CLASS_NAME``. Bases spelled
    in any other form (for example a dotted ``module.BaseOperator`` reference) are not matched.

    TODO: Enhance this function to work with nested inheritance trees through dynamic imports.

    :param class_node: The class node to check.
    :return: True if the class definition is of an operator, False otherwise.
    """
    return any(
        isinstance(parent, ast.Name) and parent.id.endswith(BASE_OPERATOR_CLASS_NAME)
        for parent in class_node.bases
    )
def _extract_template_fields(class_node: ast.ClassDef) -> list[str]:
    """
    Extract the names listed in a class's ``template_fields`` declaration.

    The class body is scanned for the first plain (``template_fields = ...``) or annotated
    (``template_fields: X = ...``) assignment whose value is a tuple or list literal.
    Only constant elements of that literal are collected; anything else (e.g. a starred
    ``*Parent.template_fields`` expansion) is skipped because it cannot be resolved statically.

    :param class_node: The class node representing the class for which template fields need to be extracted.
    :return: A list of template fields extracted from the class node, or an empty list when the class
        has no statically readable ``template_fields`` declaration.
    """
    for class_item in class_node.body:
        if isinstance(class_item, ast.Assign):
            targets = class_item.targets
        elif isinstance(class_item, ast.AnnAssign):
            targets = [class_item.target]
        else:
            continue

        declares_template_fields = any(
            isinstance(target, ast.Name) and target.id == "template_fields" for target in targets
        )
        # A bare annotation (``template_fields: Sequence[str]``) has ``value is None`` and is skipped
        # here, as is any value that is not a literal sequence.
        if declares_template_fields and isinstance(class_item.value, (ast.Tuple, ast.List)):
            return [str(elt.value) for elt in class_item.value.elts if isinstance(elt, ast.Constant)]
    return []
def _handle_parent_constructor_kwargs(
    template_fields: list[str],
    ctor_stmt: ast.stmt,
    missing_assignments: list[str],
    invalid_assignments: list[str],
) -> list[str]:
    """
    Account for template fields that are forwarded to the parent class via a ``super()`` call.

    A statement is treated as a parent call when it is an expression statement of the form
    ``super(...).<method>(...)``. Every explicitly named keyword argument of that call is
    considered assigned (it is assumed that the parent class stores it correctly). A keyword
    that names a template field but whose value is not a variable of the same name
    (``field=field``) is additionally recorded in ``invalid_assignments``.

    Any other kind of statement leaves ``missing_assignments`` untouched.

    TODO: Enhance this function to work with nested inheritance trees through dynamic imports.

    :param template_fields: List of template fields to be assigned.
    :param ctor_stmt: AST node representing the constructor statement.
    :param missing_assignments: List of template fields that have not been assigned a value.
    :param invalid_assignments: List of template fields that have been assigned incorrectly;
        extended in place.
    :return: List of template fields that are still missing assignments, in their original order.
    """
    if not isinstance(ctor_stmt, ast.Expr):
        return missing_assignments

    call = ctor_stmt.value
    is_super_call = (
        isinstance(call, ast.Call)
        and isinstance(call.func, ast.Attribute)
        and isinstance(call.func.value, ast.Call)
        and isinstance(call.func.value.func, ast.Name)
        and call.func.value.func.id == "super"
    )
    if not is_super_call:
        return missing_assignments

    passed_kwargs: set[str] = set()
    for keyword in call.keywords:
        if keyword.arg is None:
            # ``**kwargs`` expansion: carries no statically known argument name.
            continue
        passed_kwargs.add(keyword.arg)
        forwards_same_name = isinstance(keyword.value, ast.Name) and keyword.value.id == keyword.arg
        if keyword.arg in template_fields and not forwards_same_name:
            invalid_assignments.append(keyword.arg)

    # ``dict.fromkeys`` de-duplicates while keeping declaration order, so the report printed
    # by the caller is stable between runs (a plain set difference is not).
    return [field for field in dict.fromkeys(missing_assignments) if field not in passed_kwargs]
def _handle_constructor_statement(
    template_fields: list[str],
    ctor_stmt: ast.stmt,
    missing_assignments: list[str],
    invalid_assignments: list[str],
) -> list[str]:
    """
    Handle a single constructor statement.

    The following actions are performed:
    1. Removing assigned fields of template_fields from missing_assignments.
    2. Detecting invalid assignments of template fields and adding them to invalid_assignments.

    Three statement shapes are recognised:
    * ``self.field = value`` (including chained targets); for ``value_a or value_b`` only the
      first operand is inspected, so ``self.field = field or default`` counts as valid;
    * ``self.a, self.b = a, b`` where both sides are tuple literals, compared pairwise;
    * ``self.field: T = value``.

    Attributes that are not template fields are ignored in all three shapes.

    :param template_fields: List of template fields.
    :param ctor_stmt: Constructor statement (for example, self.field_name = param_name)
    :param missing_assignments: List of missing assignments.
    :param invalid_assignments: List of invalid assignments; extended in place.
    :return: List of missing assignments after handling the assigned targets, in their original order.
    """
    assigned_template_fields: list[str] = []

    if isinstance(ctor_stmt, ast.Assign):
        first_target = ctor_stmt.targets[0]
        if isinstance(first_target, ast.Attribute):
            value = ctor_stmt.value
            if isinstance(value, ast.BoolOp) and isinstance(value.op, ast.Or):
                # ``self.x = x or default``: only the first operand has to be the parameter.
                value = value.values[0]
            for target in ctor_stmt.targets:
                if isinstance(target, ast.Attribute) and target.attr in template_fields:
                    _handle_assigned_field(assigned_template_fields, invalid_assignments, target, value)
        elif isinstance(first_target, ast.Tuple) and isinstance(ctor_stmt.value, ast.Tuple):
            for target, value in zip(first_target.elts, ctor_stmt.value.elts):
                # Only template fields are validated; without this filter every unrelated
                # attribute in a tuple assignment would be reported as an invalid assignment.
                if isinstance(target, ast.Attribute) and target.attr in template_fields:
                    _handle_assigned_field(assigned_template_fields, invalid_assignments, target, value)
    elif isinstance(ctor_stmt, ast.AnnAssign):
        if isinstance(ctor_stmt.target, ast.Attribute) and ctor_stmt.target.attr in template_fields:
            _handle_assigned_field(
                assigned_template_fields, invalid_assignments, ctor_stmt.target, ctor_stmt.value
            )

    # Keep declaration order (and drop duplicates) so the caller's report is stable between runs.
    return [field for field in dict.fromkeys(missing_assignments) if field not in assigned_template_fields]
def _handle_assigned_field(
    assigned_template_fields: list[str], invalid_assignments: list[str], target: ast.Attribute, value: Any
) -> None:
    """
    Classify one attribute assignment as valid or invalid and record it.

    The assignment is valid only when the assigned value is a bare variable whose name equals
    the attribute name (``self.field = field``); every other value makes it invalid.

    :param assigned_template_fields: A list to store the valid assigned fields.
    :param invalid_assignments: A list to store the invalid assignments.
    :param target: The target field.
    :param value: The value of the field.
    """
    field_name = target.attr
    is_same_named_variable = isinstance(value, ast.Name) and value.id == field_name
    # Exactly one of the two lists receives the field name.
    destination = assigned_template_fields if is_same_named_variable else invalid_assignments
    destination.append(field_name)
def _check_constructor_template_fields(class_node: ast.ClassDef, template_fields: list[str]) -> int:
    """
    Inspect a class's ``__init__`` for template fields that are missing or assigned incorrectly.

    Each top-level statement of every ``__init__`` found directly in the class body is run
    through the parent-constructor check and the plain-assignment check. Missing fields are
    reported only when the class defines a constructor; without one it is assumed that the
    parent's constructor assigns the template fields correctly.

    TODO: Enhance this function to work with nested inheritance trees through dynamic imports.

    :param class_node: the AST node representing the class definition
    :param template_fields: a list of template fields
    :return: the number of invalid template fields found
    """
    class_name = class_node.name
    still_missing = template_fields.copy()
    wrongly_assigned: list[str] = []
    has_constructor = False

    constructors = (
        member
        for member in class_node.body
        if isinstance(member, ast.FunctionDef) and member.name == "__init__"
    )
    for constructor in constructors:
        has_constructor = True
        for statement in constructor.body:
            # Both handlers see every statement; each narrows the list of missing fields.
            still_missing = _handle_parent_constructor_kwargs(
                template_fields, statement, still_missing, wrongly_assigned
            )
            still_missing = _handle_constructor_statement(
                template_fields, statement, still_missing, wrongly_assigned
            )

    error_total = 0

    if has_constructor and still_missing:
        error_total += len(still_missing)
        console.print(
            f"{class_name}'s constructor lacks direct assignments for "
            f"instance members corresponding to the following template fields "
            f"(i.e., self.field_name = field_name or super.__init__(field_name=field_name, ...) ):"
        )
        console.print(f"[red]{still_missing}[/red]")

    if wrongly_assigned:
        error_total += len(wrongly_assigned)
        console.print(
            f"{class_name}'s constructor contains invalid assignments to the following instance "
            f"members that should be corresponding to template fields "
            f"(i.e., self.field_name = field_name):"
        )
        console.print(f"[red]{[f'self.{entry}' for entry in wrongly_assigned]}[/red]")

    return error_total
def main() -> int:
    """
    Check missing or invalid template fields in constructors of providers' operators.

    Every path given on the command line is parsed, and each operator class found anywhere
    in the resulting syntax tree is checked.

    :return: The total number of errors found.
    """
    err = 0
    for path in sys.argv[1:]:
        console.print(f"[yellow]{path}[/yellow]")
        # Read raw bytes so that ``ast.parse`` applies the source file's own encoding
        # declaration (UTF-8 by default) instead of the locale's default encoding, and close
        # the file deterministically. ``filename`` makes syntax errors point at the real file.
        with open(path, "rb") as source_file:
            tree = ast.parse(source_file.read(), filename=path)
        for node in ast.walk(tree):
            if isinstance(node, ast.ClassDef) and _is_operator(class_node=node):
                template_fields = _extract_template_fields(node)
                err += _check_constructor_template_fields(node, template_fields)
    return err
if __name__ == "__main__":
    # The process exit status is whatever main() returns (the total number of errors found).
    raise SystemExit(main())