blob: 2dbe702bdfa89da15430ddd2c30c2baa91f95a75 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import subprocess
from pathlib import Path
import libcst as cst
from datamodel_code_generator.format import CustomCodeFormatter
from libcst.helpers import parse_template_statement
# Absolute path of the directory three levels above this file (``parents[2]``).
# ``license_text`` uses it as the base for locating the license template.
AIRFLOW_ROOT_PATH = Path(__file__).parents[2].resolve()
def license_text() -> str:
    """
    Return the license template rendered as a block of ``#`` comment lines.

    Reads ``scripts/ci/license-templates/LICENSE.txt`` below ``AIRFLOW_ROOT_PATH``. Every
    non-empty line is prefixed with ``# ``, every empty line becomes a bare ``#``, and the
    joined result ends with a newline.

    :return: The commented-out license text.
    """
    template_path = AIRFLOW_ROOT_PATH / "scripts" / "ci" / "license-templates" / "LICENSE.txt"
    # Decode explicitly so the result does not depend on the locale's default encoding, and
    # avoid naming the local ``license``, which would shadow the interpreter's builtin.
    raw_license = template_path.read_text(encoding="utf-8")
    return "\n".join(f"# {line}" if line else "#" for line in raw_license.splitlines()) + "\n"
class CodeFormatter(CustomCodeFormatter):
    """Custom code formatter that post-processes generated datamodel source code."""

    def apply(self, code: str) -> str:
        """
        Prepend the license header, rewrite parts of the syntax tree and run ``ruff`` fixes.

        The rewrites are applied in this order:

        1. Drop the generated ``JsonValue`` class and import ``JsonValue`` from ``pydantic``.
        2. If the ``api_version`` formatter kwarg is set, inject an ``API_VERSION`` constant.
        3. Drop the generated ``Tasks`` class and replace references to it in ``tasks``
           annotations with ``tuple[str, int]``.

        :param code: Generated Python source to post-process.
        :return: The output of ``ruff check --fix-only`` run over the rewritten source.
        :raises subprocess.CalledProcessError: If ``ruff`` exits with a non-zero status.
        """
        code = license_text() + code

        def is_absolute_import_from(node: cst.ImportFrom, module: str) -> bool:
            """Return True for ``from <module> import ...`` but not ``from .<module> import ...``."""
            return not node.relative and isinstance(node.module, cst.Name) and node.module.value == module

        def with_imported_name(node: cst.ImportFrom, name: str) -> cst.ImportFrom:
            """Return ``node`` extended to import ``name`` unless it already makes it available."""
            # ``from x import *`` has no alias list to extend (and already pulls in public names).
            if isinstance(node.names, cst.ImportStar):
                return node
            for alias in node.names:
                # Only a plain, un-aliased import of the same name counts as "already imported".
                if alias.asname is None and isinstance(alias.name, cst.Name) and alias.name.value == name:
                    return node
            return node.with_changes(names=(*node.names, cst.ImportAlias(name=cst.Name(name))))

        def tuple_str_int() -> cst.Subscript:
            """Build a fresh ``tuple[str, int]`` expression node."""
            return cst.Subscript(
                value=cst.Name("tuple"),
                slice=[
                    cst.SubscriptElement(cst.Index(cst.Name("str"))),
                    cst.SubscriptElement(cst.Index(cst.Name("int"))),
                ],
            )

        class JsonValueNodeRemover(cst.CSTTransformer):
            """Swap "class JsonValue[RootValue]:" for the import from pydantic."""

            def leave_ImportFrom(
                self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
            ) -> cst.BaseSmallStatement | cst.FlattenSentinel[cst.BaseSmallStatement] | cst.RemovalSentinel:
                if is_absolute_import_from(updated_node, "pydantic"):
                    return with_imported_name(updated_node, "JsonValue")
                return super().leave_ImportFrom(original_node, updated_node)

            def leave_ClassDef(
                self, original_node: cst.ClassDef, updated_node: cst.ClassDef
            ) -> cst.BaseStatement | cst.FlattenSentinel[cst.BaseStatement] | cst.RemovalSentinel:
                if original_node.name.value == "JsonValue":
                    return cst.RemoveFromParent()
                return super().leave_ClassDef(original_node, updated_node)

        class VersionConstInjector(cst.CSTTransformer):
            """Insert ``API_VERSION: Final[str] = "<version>"`` before the first non-nested class."""

            def __init__(self, api_version: str) -> None:
                super().__init__()
                self.api_version = api_version
                # Set once the constant has been emitted so that it is only inserted a single time.
                self.handled = False
                # Number of class definitions currently being traversed (0 means outside any class).
                self.class_depth = 0

            def leave_ImportFrom(
                self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
            ) -> cst.BaseSmallStatement | cst.FlattenSentinel[cst.BaseSmallStatement] | cst.RemovalSentinel:
                # Ensure we have `from typing import Final`
                if is_absolute_import_from(updated_node, "typing"):
                    return with_imported_name(updated_node, "Final")
                return super().leave_ImportFrom(original_node, updated_node)

            def visit_ClassDef(self, node: cst.ClassDef) -> bool:
                self.class_depth += 1
                return True

            def leave_ClassDef(
                self, original_node: cst.ClassDef, updated_node: cst.ClassDef
            ) -> cst.BaseStatement | cst.FlattenSentinel[cst.BaseStatement] | cst.RemovalSentinel:
                self.class_depth -= 1
                # Inner classes are left before the class that contains them; injecting next to
                # one of those would place the constant inside the outer class body.
                if self.handled or self.class_depth:
                    return super().leave_ClassDef(original_node, updated_node)
                self.handled = True
                const = parse_template_statement(
                    "API_VERSION: Final[str] = {api_version}",
                    api_version=cst.SimpleString(f'"{self.api_version}"'),
                )
                return cst.FlattenSentinel([const, updated_node])

        class TasksReplacer(cst.CSTTransformer):
            """Replace every ``Tasks`` name with ``tuple[str, int]``."""

            def leave_Name(self, original_node: cst.Name, updated_node: cst.Name) -> cst.BaseExpression:
                if original_node.value == "Tasks":
                    return tuple_str_int()
                return updated_node

        class ModifyTasksAnnotation(cst.CSTTransformer):
            """
            Remove the ``Tasks`` class, which represents a tuple of (task_id, map_index).

            It is generated for `TISkippedDownstreamTasksStatePayload`; references to it in
            ``tasks`` annotations are rewritten to ``tuple[str, int]``.
            """

            def leave_ClassDef(
                self, original_node: cst.ClassDef, updated_node: cst.ClassDef
            ) -> cst.BaseStatement | cst.FlattenSentinel[cst.BaseStatement] | cst.RemovalSentinel:
                if original_node.name.value == "Tasks":
                    return cst.RemoveFromParent()
                return super().leave_ClassDef(original_node, updated_node)

            def leave_AnnAssign(
                self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign
            ) -> cst.AnnAssign | cst.RemovalSentinel:
                """
                Replaces `tasks: Annotated[list[str | Tasks], Field(title="Tasks")]`
                with `tasks: Annotated[list[str | tuple[str, int]], Field(title="Tasks")]`.

                Applies to every annotated assignment whose target is named ``tasks``; the
                enclosing class is not checked.
                """
                # Check if the target is 'tasks'
                if not isinstance(updated_node.target, cst.Name) or updated_node.target.value != "tasks":
                    return updated_node
                if not isinstance(updated_node.annotation, cst.Annotation):
                    return updated_node
                # Apply the transformation to the annotation part only
                new_annotation = updated_node.annotation.visit(TasksReplacer())
                return updated_node.with_changes(annotation=new_annotation)

        source_tree = cst.parse_module(code)
        modified_tree = source_tree.visit(JsonValueNodeRemover())
        if api_version := self.formatter_kwargs.get("api_version"):
            modified_tree = modified_tree.visit(VersionConstInjector(api_version))
        modified_tree = modified_tree.visit(ModifyTasksAnnotation())
        code = modified_tree.code

        # Let ruff clean up the rewritten module; it reads the source from stdin ("-").
        result = subprocess.check_output(
            ["ruff", "check", "--fix-only", "--unsafe-fixes", "--quiet", "--preview", "-"],
            input=code,
            text=True,
        )
        return result