blob: 1d78b9d7242e2ebddf3a1fed6fdfe96c559a3429 [file] [log] [blame]
#!/usr/bin/env python
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# /// script
# requires-python = ">=3.10,<3.11"
# dependencies = [
# "rich>=13.6.0",
# ]
# ///
from __future__ import annotations
import ast
import pathlib
import re
import sys
import typing
sys.path.insert(0, str(pathlib.Path(__file__).parent.resolve())) # make sure common_prek_utils is imported
from common_prek_utils import AIRFLOW_CORE_ROOT_PATH, AIRFLOW_TASK_SDK_SOURCES_PATH
# Root of the ``airflow.sdk`` package inside the Task SDK sources.
_SDK_PKG = AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk"

# The three sources whose template-variable lists are cross-checked against each other.
TASKRUNNER_PY = _SDK_PKG / "execution_time" / "task_runner.py"
CONTEXT_HINT = _SDK_PKG / "definitions" / "context.py"
TEMPLATES_REF_RST = AIRFLOW_CORE_ROOT_PATH / "docs" / "templates-ref.rst"

# These are only conditionally set, so they are removed from the canonical key set
# instead of being reported as missing.
IGNORE = {
    "ds",
    "ds_nodash",
    "ts",
    "ts_nodash",
    "ts_nodash_with_tz",
    "logical_date",
    "data_interval_start",
    "data_interval_end",
    "prev_data_interval_start_success",
    "prev_data_interval_end_success",
}
def _iter_template_context_keys_from_original_return() -> typing.Iterator[str]:
ti_mod = ast.parse(TASKRUNNER_PY.read_text("utf-8"), str(TASKRUNNER_PY))
# Locate the RuntimeTaskInstance class definition
runtime_task_instance_class = next(
node
for node in ast.iter_child_nodes(ti_mod)
if isinstance(node, ast.ClassDef) and node.name == "RuntimeTaskInstance"
)
# Locate the get_template_context method in RuntimeTaskInstance
fn_get_template_context = next(
node
for node in ast.iter_child_nodes(runtime_task_instance_class)
if isinstance(node, ast.FunctionDef) and node.name == "get_template_context"
)
# Helper function to extract keys from a dictionary node
def extract_keys_from_dict(node: ast.Dict) -> typing.Iterator[str]:
for key in node.keys:
if not isinstance(key, ast.Constant) or not isinstance(key.value, str):
raise ValueError("Key in dictionary is not a string literal")
yield key.value
# Extract keys from the main `context` dictionary assignment
context_assignment: ast.AnnAssign = next(
stmt
for stmt in fn_get_template_context.body
if isinstance(stmt, ast.AnnAssign)
and isinstance(stmt.target, ast.Attribute)
and isinstance(stmt.target.value, ast.Name)
and stmt.target.value.id == "self"
and stmt.target.attr == "_cached_template_context"
)
if not isinstance(context_assignment.value, ast.BoolOp):
raise TypeError("Expected a BoolOp like 'self._cached_template_context or {...}'.")
context_assignment_op = context_assignment.value
_, context_assignment_value = context_assignment_op.values
if not isinstance(context_assignment_value, ast.Dict):
raise ValueError("'context' is not assigned a dictionary literal")
yield from extract_keys_from_dict(context_assignment_value)
# Handle keys added conditionally in `if from_server`
for stmt in fn_get_template_context.body:
if isinstance(stmt, ast.If) and isinstance(stmt.test, ast.Name) and stmt.test.id == "from_server":
for sub_stmt in stmt.body:
# Get keys from `context_from_server` assignment
if (
isinstance(sub_stmt, ast.AnnAssign)
and isinstance(sub_stmt.target, ast.Name)
and isinstance(sub_stmt.value, ast.Dict)
and sub_stmt.target.id == "context_from_server"
):
yield from extract_keys_from_dict(sub_stmt.value)
def _iter_template_context_keys_from_type_hints() -> typing.Iterator[str]:
context_mod = ast.parse(CONTEXT_HINT.read_text("utf-8"), str(CONTEXT_HINT))
cls_context = next(
node
for node in ast.iter_child_nodes(context_mod)
if isinstance(node, ast.ClassDef) and node.name == "Context"
)
for stmt in cls_context.body:
if isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Constant):
# Skip docstring
continue
if not isinstance(stmt, ast.AnnAssign) or not isinstance(stmt.target, ast.Name):
raise ValueError("key in 'Context' hint is not an annotated assignment")
yield stmt.target.id
def _iter_template_context_keys_from_documentation() -> typing.Iterator[str]:
# We can use docutils to actually parse, but regex is good enough for now.
# This should find names in the "Variable" and "Deprecated Variable" tables.
content = TEMPLATES_REF_RST.read_text("utf-8")
for match in re.finditer(r"^``{{ (?P<name>\w+)(?P<subname>\.\w+)* }}`` ", content, re.MULTILINE):
yield match.group("name")
def _compare_keys(retn_keys: set[str], hint_keys: set[str], docs_keys: set[str]) -> int:
# Added by PythonOperator and commonly used.
# Not listed in templates-ref (but in operator docs).
retn_keys.add("templates_dict")
docs_keys.add("templates_dict")
# Compat shim for task-sdk, not actually designed for user use
retn_keys.add("expanded_ti_count")
# TODO: These are the keys that are yet to be ported over to the Task SDK.
retn_keys.add("test_mode")
# Only present in callbacks. Not listed in templates-ref (that doc is for task execution).
retn_keys.update(("exception", "reason", "try_number"))
docs_keys.update(("exception", "reason", "try_number"))
# Airflow 3 added:
retn_keys.update(("start_date", "task_reschedule_count"))
check_candidates = [
("get_template_context()", retn_keys),
("Context type hint", hint_keys),
("templates-ref", docs_keys),
]
canonical_keys = set.union(*(s for _, s in check_candidates)) - IGNORE
def _check_one(identifier: str, keys: set[str]) -> int:
if missing := canonical_keys.difference(keys):
print("Missing template variables from", f"{identifier}:", ", ".join(sorted(missing)))
return len(missing)
return sum(_check_one(identifier, keys) for identifier, keys in check_candidates)
def main() -> str | int | None:
    """
    Collect template keys from the task runner, the ``Context`` hint and the docs, then compare them.

    :return: number of missing template variables across all sources (used as the exit status).
    """
    # Arguments are evaluated left to right, so the sources are read in the same order as before.
    return _compare_keys(
        set(_iter_template_context_keys_from_original_return()),
        set(_iter_template_context_keys_from_type_hints()),
        set(_iter_template_context_keys_from_documentation()),
    )
if __name__ == "__main__":
sys.exit(main())