blob: 3559e523cf09d374c8ac2ad9021c01a427de48db [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import ast
import difflib
import os
import shlex
import shutil
import subprocess
import sys
import textwrap
import time
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
from typing import Any
# Base directory for every path below: three levels above the directory that contains this file.
AIRFLOW_ROOT_PATH = Path(__file__).parents[3].resolve()
AIRFLOW_CORE_ROOT_PATH = AIRFLOW_ROOT_PATH / "airflow-core"
AIRFLOW_CORE_SOURCES_PATH = AIRFLOW_CORE_ROOT_PATH / "src"
AIRFLOW_BREEZE_SOURCES_PATH = AIRFLOW_ROOT_PATH / "dev" / "breeze"
AIRFLOW_PROVIDERS_ROOT_PATH = AIRFLOW_ROOT_PATH / "providers"
AIRFLOW_TASK_SDK_ROOT_PATH = AIRFLOW_ROOT_PATH / "task-sdk"
AIRFLOW_TASK_SDK_SOURCES_PATH = AIRFLOW_TASK_SDK_ROOT_PATH / "src"
# Second-level paths under which we want to have sub-packages; add new ones here.
KNOWN_SECOND_LEVEL_PATHS = ["apache", "atlassian", "common", "cncf", "dbt", "microsoft"]
# Default value of the ``python_version`` argument of ``run_command_via_breeze_shell``.
DEFAULT_PYTHON_MAJOR_MINOR_VERSION = "3.10"
# Read once at import time; None when the environment variable is not set.
GITHUB_TOKEN: str | None = os.environ.get("GITHUB_TOKEN")
# ``rich`` is optional: when it cannot be imported ``console`` is None, so code that uses
# ``console`` has to check it before calling ``console.print``.
try:
    from rich.console import Console

    console = Console(width=400, color_system="standard")
except ImportError:
    console = None # type: ignore[assignment]
@contextmanager
def temporary_tsc_project(
    tsconfig_path: Path, files: list[str]
) -> Generator[_TemporaryFileWrapper, None, None]:
    """
    Create a temporary tsconfig file that extends the main tsconfig.json file.

    This is needed to run TypeScript compiler with specific files included only. The temporary
    file is created in the directory of ``tsconfig_path`` (its ``extends`` entry is relative),
    is flushed before it is handed to the caller, and is closed - and therefore deleted - as soon
    as the ``with`` block of the caller ends, whether it ends normally or with an exception.

    :param tsconfig_path: path of the existing tsconfig file to extend
    :param files: entries to put in the ``include`` list of the temporary config
    :raises RuntimeError: when ``tsconfig_path`` does not exist
    """
    import json

    if not tsconfig_path.exists():
        raise RuntimeError(f"Cannot find {tsconfig_path}")
    # json.dumps escapes quotes and backslashes in file names; with the default separators the
    # output for ordinary names is the same text the hand-written template produced.
    content = json.dumps({"extends": f"./{tsconfig_path.name}", "include": [str(file) for file in files]})
    with NamedTemporaryFile(
        mode="wt", suffix=".json", dir=tsconfig_path.parent, delete=True
    ) as temp_tsconfig_path:
        if console:
            console.print(f"[magenta]Creating temporary tsconfig.json at {temp_tsconfig_path.name}[/]")
            console.print(content)
        else:
            print(f"Creating temporary tsconfig.json at {temp_tsconfig_path.name}", file=sys.stderr)
            print(content, file=sys.stderr)
        temp_tsconfig_path.write(content)
        # Flush so that a process opening the file by name sees the complete content.
        temp_tsconfig_path.flush()
        yield temp_tsconfig_path
def run_command(*args, **kwargs) -> None:
    """
    Run a command with ``subprocess.check_call`` and report what was run and how long it took.

    All positional and keyword arguments are passed unchanged to ``subprocess.check_call``; the
    first positional argument is the command. Nothing is returned.

    :raises subprocess.CalledProcessError: when the command exits with a non-zero code
    """
    command = args[0]
    if isinstance(command, str):
        # A command given as one string must not be quoted character by character.
        cmd = command
    else:
        cmd = " ".join(shlex.quote(str(arg)) for arg in command)
    # ``cwd=None`` means "current directory" for subprocess, so show the real directory then.
    cwd = kwargs.get("cwd") or os.getcwd()
    text = f"Running command: `{cmd}` in directory: `{cwd}`"
    banner = "#" * min(len(text), 200)
    if console:
        console.print(f"[magenta]{text}[/]")
    else:
        print(banner, file=sys.stderr)
        print(text, file=sys.stderr)
        print(banner, file=sys.stderr)
    # The monotonic clock is not affected by system clock adjustments while the command runs.
    time_start = time.monotonic()
    subprocess.check_call(*args, **kwargs)
    elapsed = time.monotonic() - time_start
    if console:
        console.print(f"[green]After {text}[/]")
        console.print(f"[green]Command finished in {elapsed:.2f} seconds[/]")
    else:
        # Every line of the fallback report goes to stderr, so the banner stays in one stream.
        print(banner, file=sys.stderr)
        print(f"After {text}", file=sys.stderr)
        print(f"Command finished in {elapsed:.2f} seconds", file=sys.stderr)
        print(banner, file=sys.stderr)
def read_airflow_version() -> str:
    """
    Read ``__version__`` from ``airflow/__init__.py`` without importing the package.

    The file is parsed with ``ast`` and the value of the first top-level assignment to
    ``__version__`` is evaluated as a literal.

    :raises RuntimeError: when no top-level assignment to ``__version__`` is found
    """
    init_file = AIRFLOW_CORE_SOURCES_PATH / "airflow" / "__init__.py"
    ast_obj = ast.parse(init_file.read_text())
    for node in ast_obj.body:
        if not isinstance(node, ast.Assign):
            continue
        # Targets may be attributes, subscripts or tuples, which have no ``id`` - check the type
        # of every target instead of assuming the first one is a plain name.
        if any(isinstance(target, ast.Name) and target.id == "__version__" for target in node.targets):
            return ast.literal_eval(node.value)
    raise RuntimeError("Couldn't find __version__ in AST")
# File that ``_read_global_constants_assignment`` parses with ``ast`` to read constants.
GLOBAL_CONSTANTS_PATH = (
    AIRFLOW_ROOT_PATH / "dev" / "breeze" / "src" / "airflow_breeze" / "global_constants.py"
)
def _read_global_constants_assignment(name: str) -> Any:
    """
    Return the literal value assigned to ``name`` at the top level of global_constants.py.

    The file is only parsed, never imported; the first matching assignment wins.

    :param name: name of the assigned variable to look for
    :raises RuntimeError: when no top-level assignment to ``name`` exists
    """
    module = ast.parse(GLOBAL_CONSTANTS_PATH.read_text())
    assignments = (statement for statement in module.body if isinstance(statement, ast.Assign))
    for assignment in assignments:
        assigns_name = any(
            isinstance(target, ast.Name) and target.id == name for target in assignment.targets
        )
        if assigns_name:
            return ast.literal_eval(assignment.value)
    raise RuntimeError(f"{name} not found in global_constants.py")
def read_allowed_kubernetes_versions() -> list[str]:
    """Parse ALLOWED_KUBERNETES_VERSIONS from global_constants.py (single source of truth).

    Returns versions without the ``v`` prefix, e.g. ``["1.30.13", "1.31.12", ...]``.
    """
    versions: list[str] = _read_global_constants_assignment("ALLOWED_KUBERNETES_VERSIONS")
    # removeprefix drops exactly one leading "v"; lstrip("v") would strip every leading "v".
    return [version.removeprefix("v") for version in versions]
def read_default_python_major_minor_version_for_images() -> str:
    """
    Return DEFAULT_PYTHON_MAJOR_MINOR_VERSION_FOR_IMAGES as defined in global_constants.py.

    :raises RuntimeError: when the constant is missing or its value is not a string
    """
    constant_name = "DEFAULT_PYTHON_MAJOR_MINOR_VERSION_FOR_IMAGES"
    value = _read_global_constants_assignment(constant_name)
    if isinstance(value, str):
        return value
    raise RuntimeError(
        "DEFAULT_PYTHON_MAJOR_MINOR_VERSION_FOR_IMAGES in global_constants.py "
        f"must be a string, got {type(value).__name__}"
    )
def pre_process_mypy_files(files: list[str]) -> list[str]:
    """Filter the list of files passed to mypy.

    * Files whose path ends with ``conftest.py`` or ``__init__.py`` are always removed.
    * When the ``DEFAULT_BRANCH`` environment variable is set to a non-empty value other than
      ``main``, files whose path starts with ``providers`` are removed as well.
    """
    skipped_endings = ("conftest.py", "__init__.py")
    candidates = [path for path in files if not path.endswith(skipped_endings)]
    branch = os.environ.get("DEFAULT_BRANCH")
    if branch and branch != "main":
        return [path for path in candidates if not path.startswith("providers")]
    return candidates
def insert_documentation(
file_path: Path,
content: list[str],
header: str,
footer: str,
add_comment: bool = False,
extra_information: str | None = None,
) -> bool:
found = False
old_content = file_path.read_text()
lines = old_content.splitlines(keepends=True)
replacing = False
result: list[str] = []
for line in lines:
if line.strip().startswith(header.strip()):
replacing = True
found = True
result.append(line)
if add_comment:
result.extend(["# " + line if line != "\n" else "#\n" for line in content])
else:
result.extend(content)
if line.strip().startswith(footer.strip()):
replacing = False
if not replacing:
result.append(line)
new_content = "".join(result)
if not found:
print(f"Header {header} not found in {file_path}")
sys.exit(1)
if new_content != old_content:
file_path.write_text(new_content)
console.print(f"Updated {file_path} with {extra_information or 'generated documentation'}")
return True
return False
def initialize_breeze_prek(name: str, file: str) -> None:
    """
    Check that a breeze-based prek hook script can run, exiting the process when it cannot.

    * Raises ``SystemExit`` with an explanation when ``name`` is neither ``__main__`` nor
      ``__mp_main__``, i.e. when the calling script is used as a module.
    * Exits with code 0 when the ``SKIP_BREEZE_PREK_HOOKS`` environment variable is non-empty.
    * Exits with code 1 when no ``breeze`` executable is found on the PATH.

    :param name: the ``__name__`` of the calling script
    :param file: path of the calling script, used only in the error message
    """
    if name not in ("__main__", "__mp_main__"):
        raise SystemExit(
            "This file is intended to be executed as an executable program. You cannot use it as a module. "
            f"To run this script, run the ./{file} command"
        )
    if os.environ.get("SKIP_BREEZE_PREK_HOOKS"):
        # ``console`` is None when rich is not installed - fall back to plain print.
        if console:
            console.print("[yellow]Skipping breeze prek hooks as SKIP_BREEZE_PREK_HOOKS is set")
        else:
            print("Skipping breeze prek hooks as SKIP_BREEZE_PREK_HOOKS is set", file=sys.stderr)
        sys.exit(0)
    if shutil.which("breeze") is None:
        if console:
            console.print(
                "[red]The `breeze` command is not on path.[/]\n\n"
                "[yellow]Please install breeze.\n"
                "You can use uv with `uv tool install -e ./dev/breeze` or "
                "`pipx install -e ./dev/breeze`.\n"
                "It will install breeze from Airflow sources "
                "(make sure you run `pipx ensurepath` if you use pipx)[/]\n\n"
                "[bright_blue]You can also set SKIP_BREEZE_PREK_HOOKS env variable to non-empty "
                "value to skip all breeze tests."
            )
        else:
            print(
                "The `breeze` command is not on path.\n\n"
                "Please install breeze.\n"
                "You can use uv with `uv tool install -e ./dev/breeze` or "
                "`pipx install -e ./dev/breeze`.\n"
                "It will install breeze from Airflow sources "
                "(make sure you run `pipx ensurepath` if you use pipx)\n\n"
                "You can also set SKIP_BREEZE_PREK_HOOKS env variable to non-empty "
                "value to skip all breeze tests.",
                file=sys.stderr,
            )
        sys.exit(1)
def run_command_via_breeze_shell(
    cmd: list[str],
    python_version: str = DEFAULT_PYTHON_MAJOR_MINOR_VERSION,
    backend: str = "none",
    executor: str = "LocalExecutor",
    extra_env: dict[str, str] | None = None,
    project_name: str = "prek",
    skip_environment_initialization: bool = True,
    warn_image_upgrade_needed: bool = False,
    enable_pseudo_terminal: bool = False,
    **other_popen_kwargs,
) -> subprocess.CompletedProcess:
    """
    Run ``cmd`` through ``breeze shell`` and then bring the docker compose project down.

    The command is shell-quoted, joined into one string and passed as the last argument of
    ``breeze shell``. A non-zero exit code does not raise; inspect the returned object.

    :param cmd: command and arguments to run
    :param python_version: value of ``--python``
    :param backend: value of ``--backend``
    :param executor: value of ``--executor``
    :param extra_env: environment variables added on top of (and overriding) the current
        environment and the fixed breeze settings set below
    :param project_name: value of ``--project-name`` for both breeze and ``docker compose``;
        the option is left out when empty
    :param skip_environment_initialization: adds ``--skip-environment-initialization``
    :param warn_image_upgrade_needed: adds ``--warn-image-upgrade-needed``
    :param enable_pseudo_terminal: selects ``--tty enabled`` instead of ``--tty disabled``
    :param other_popen_kwargs: extra keyword arguments passed to ``subprocess.run``
    :return: the ``CompletedProcess`` of the ``breeze shell`` invocation
    """
    project_args = ["--project-name", project_name] if project_name else []
    # Note: The terminal is disabled - because prek is run inside git without a pseudo-terminal
    tty_mode = "enabled" if enable_pseudo_terminal else "disabled"

    breeze_invocation: list[str] = ["breeze", "shell"]
    breeze_invocation += ["--python", python_version]
    breeze_invocation += ["--backend", backend]
    breeze_invocation += ["--executor", executor]
    breeze_invocation += ["--quiet", "--restart", "--skip-image-upgrade-check"]
    breeze_invocation += ["--tty", tty_mode]
    if warn_image_upgrade_needed:
        breeze_invocation += ["--warn-image-upgrade-needed"]
    if skip_environment_initialization:
        breeze_invocation += ["--skip-environment-initialization"]
    breeze_invocation += project_args
    breeze_invocation.append(" ".join(shlex.quote(part) for part in cmd))

    # Later entries win: fixed breeze settings override the inherited environment and
    # ``extra_env`` overrides both.
    breeze_env = dict(os.environ)
    breeze_env.update(
        {
            "SKIP_BREEZE_SELF_UPGRADE_CHECK": "true",
            "SKIP_GROUP_OUTPUT": "true",
            "SKIP_SAVING_CHOICES": "true",
            "ANSWER": "no",
        }
    )
    breeze_env.update(extra_env or {})

    verbose = os.environ.get("VERBOSE_COMMANDS") or os.environ.get("CI") == "true"
    if verbose:
        rendered = " ".join(shlex.quote(part) for part in breeze_invocation)
        if console:
            console.print(f"[magenta]Running command: {rendered}[/]")
            console.print("[magenta]With environment:[/]")
            console.print(breeze_env)
        else:
            print(f"Running command: {rendered}")
            print("With environment:")
            print(breeze_env)

    completed = subprocess.run(
        breeze_invocation,
        check=False,
        text=True,
        **other_popen_kwargs,
        env=breeze_env,
    )

    # Stop remaining containers
    down_command = ["docker", "compose", "--progress", "quiet", *project_args, "down", "--remove-orphans"]
    subprocess.run(down_command, check=False, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    return completed
class ConsoleDiff(difflib.Differ):
    """``difflib.Differ`` variant that wraps added and removed lines in colour markup."""

    def _dump(self, tag, x, lo, hi):
        """Yield ``x[lo]`` .. ``x[hi - 1]`` prefixed with ``tag``: ``+`` in green, ``-`` in red."""
        colour = {"+": "green", "-": "red"}.get(tag)
        for index in range(lo, hi):
            plain = f"{tag} {x[index]}"
            yield f"[{colour}]{plain}[/]" if colour else plain
def check_list_sorted(the_list: list[str], message: str, errors: list[str]) -> bool:
    """
    Check that ``the_list`` is sorted and has no duplicates, and report the outcome.

    The list is compared with ``sorted(set(the_list))``. When they differ, a diff between the
    two is printed and an error description is appended to ``errors``.

    :param the_list: list to check
    :param message: label identifying the list in the output and in the error entry
    :param errors: list that receives one entry when the check fails (modified in place)
    :return: True when the list is sorted and unique, False otherwise
    """
    sorted_list = sorted(set(the_list))
    if the_list == sorted_list:
        # ``console`` is None when rich is not installed - fall back to plain print.
        if console:
            console.print(f"{message} is [green]ok[/]")
            console.print(the_list)
            console.print()
        else:
            print(f"{message} is ok", file=sys.stderr)
            print(the_list, file=sys.stderr)
            print(file=sys.stderr)
        return True
    if console:
        console.print(f"{message} [red]NOK[/]")
        console.print(textwrap.indent("\n".join(ConsoleDiff().compare(the_list, sorted_list)), " " * 4))
        console.print()
    else:
        print(f"{message} NOK", file=sys.stderr)
        # Plain Differ here: ConsoleDiff would leave raw colour markup in the printed text.
        plain_diff = difflib.Differ().compare(the_list, sorted_list)
        print(textwrap.indent("\n".join(plain_diff), " " * 4), file=sys.stderr)
        print(file=sys.stderr)
    errors.append(f"ERROR in {message}. The elements are not sorted/unique.")
    return False
def validate_cmd_result(cmd_result, include_ci_env_check=False):
    """
    Print a hint when the command failed and exit the process with the command's return code.

    This function never returns: it always ends with ``sys.exit(cmd_result.returncode)``.

    :param cmd_result: object with a ``returncode`` attribute, e.g. ``subprocess.CompletedProcess``
    :param include_ci_env_check: when True the hint recommends rebuilding the image with upgraded
        dependencies and is printed only if the ``CI`` environment variable is not ``true``;
        when False a shorter hint is printed on every failure
    """
    failed = cmd_result.returncode != 0
    if failed and include_ci_env_check:
        if os.environ.get("CI") != "true":
            # Use the module default instead of repeating the Python version literal.
            rebuild_command = (
                f"breeze ci-image build --python {DEFAULT_PYTHON_MAJOR_MINOR_VERSION} "
                "--upgrade-to-newer-dependencies"
            )
            if console:
                console.print(
                    "\n[yellow]If you see strange stacktraces above, especially about missing imports "
                    "run this command:[/]\n"
                )
                console.print(f"[magenta]{rebuild_command}[/]\n")
            else:
                print(
                    "\nIf you see strange stacktraces above, especially about missing imports "
                    f"run this command:\n{rebuild_command}\n"
                )
    elif failed:
        hint = (
            "\nIf you see strange stacktraces above, "
            f"run `breeze ci-image build --python {DEFAULT_PYTHON_MAJOR_MINOR_VERSION}` and try again."
        )
        if console:
            # "yellow" is a colour rich knows; the console in this file defines no "warning" style.
            console.print(f"[yellow]{hint}[/]")
        else:
            print(hint)
    sys.exit(cmd_result.returncode)
def get_provider_id_from_path(file_path: Path) -> str | None:
"""
Get the provider id from the path of the file it belongs to.
"""
for parent in file_path.parents:
# This works fine for both new and old providers structure - because we moved provider.yaml to
# the top-level of the provider and this code finding "providers" will find the "providers" package
# in old structure and "providers" directory in new structure - in both cases we can determine
# the provider id from the relative folders
if (parent / "provider.yaml").exists():
for providers_root_candidate in parent.parents:
if providers_root_candidate.name == "providers":
return parent.relative_to(providers_root_candidate).as_posix().replace("/", ".")
return None
return None
def get_provider_base_dir_from_path(file_path: Path) -> Path | None:
"""
Get the provider base dir (where provider.yaml is) from the path of the file it belongs to.
"""
for parent in file_path.parents:
if (parent / "provider.yaml").exists():
return parent
return None
def get_all_provider_ids(exclude_suspended_providers: bool = False) -> list[str]:
    """
    Collect the ids of all providers found under the providers root directory.

    Every ``provider.yaml`` below the providers root is considered, except those located under
    its ``src`` sub-directory.

    :param exclude_suspended_providers: when True, providers whose ``provider.yaml`` has
        ``state`` equal to ``suspended`` are left out
    :return: provider ids in the order the files are discovered
    """
    skipped_root = AIRFLOW_PROVIDERS_ROOT_PATH / "src"
    provider_ids: list[str] = []
    for yaml_path in AIRFLOW_PROVIDERS_ROOT_PATH.rglob("provider.yaml"):
        if yaml_path.is_relative_to(skipped_root):
            continue
        if exclude_suspended_providers:
            # Imported here so that yaml is required only when the state has to be read.
            import yaml

            state = yaml.safe_load(yaml_path.read_text()).get("state")
            if state == "suspended":
                continue
        if provider_id := get_provider_id_from_path(yaml_path):
            provider_ids.append(provider_id)
    return provider_ids
def get_all_provider_yaml_files() -> list[Path]:
    """
    Return the paths of all ``provider.yaml`` files found under the providers root directory.

    Files located under the ``src`` sub-directory of the providers root are skipped.
    """
    skipped_root = AIRFLOW_PROVIDERS_ROOT_PATH / "src"
    return [
        yaml_path
        for yaml_path in AIRFLOW_PROVIDERS_ROOT_PATH.rglob("provider.yaml")
        if not yaml_path.is_relative_to(skipped_root)
    ]
def get_all_provider_info_dicts() -> dict[str, dict]:
    """
    Load the ``provider.yaml`` content of every provider that is not suspended.

    :return: mapping of provider id to the parsed yaml content; the id is the directory of the
        ``provider.yaml`` file relative to the providers root, with separators replaced by dots
    """
    active_providers: dict[str, dict] = {}
    for yaml_path in get_all_provider_yaml_files():
        import yaml

        info = yaml.safe_load(yaml_path.read_text())
        if info["state"] == "suspended":
            continue
        relative_dir = yaml_path.parent.relative_to(AIRFLOW_PROVIDERS_ROOT_PATH)
        active_providers[str(relative_dir).replace(os.sep, ".")] = info
    return active_providers
def get_imports_from_file(file_path: Path, *, only_top_level: bool) -> list[str]:
    """
    Return the names imported by a Python file.

    ``import a.b as c`` contributes ``a.b``; ``from m import n as k`` contributes ``m.n`` (just
    ``n`` when there is no module part). Imports from ``__future__`` are skipped.

    For following code:
        from __future__ import annotations
        import os
        from collections import defaultdict
        import numpy as np
        from pandas import DataFrame as DF
        def inner():
            import json
            from pathlib import Path, PurePath

    When only_top_level = True then returns
        ['os', 'collections.defaultdict', 'numpy', 'pandas.DataFrame']
    When only_top_level = False then returns
        ['os', 'collections.defaultdict', 'numpy', 'pandas.DataFrame', 'json', 'pathlib.Path', 'pathlib.PurePath']

    :param file_path: file to parse
    :param only_top_level: when True only direct children of the module are inspected, otherwise
        the whole syntax tree is walked
    """
    tree = ast.parse(file_path.read_text(), file_path.name)
    candidates = ast.iter_child_nodes(tree) if only_top_level else ast.walk(tree)
    collected: list[str] = []
    for statement in candidates:
        if isinstance(statement, ast.Import):
            collected.extend(alias.name for alias in statement.names)
            continue
        if not isinstance(statement, ast.ImportFrom) or statement.module == "__future__":
            continue
        prefix = f"{statement.module}." if statement.module else ""
        collected.extend(prefix + alias.name for alias in statement.names)
    return collected
def get_remote_for_main() -> str:
    """
    Return the remote name to use when fetching main.

    The first remote listed by ``git remote -v`` whose URL contains ``apache/airflow`` is
    preferred; in every other case (including a failing git command) ``origin`` is returned.
    """
    listing = subprocess.run(["git", "remote", "-v"], capture_output=True, text=True, check=False)
    if listing.returncode != 0:
        return "origin"
    for entry in listing.stdout.splitlines():
        fields = entry.split()
        if len(fields) < 2:
            continue
        remote_name, remote_url = fields[:2]
        if "apache/airflow" in remote_url:
            return remote_name
    return "origin"
def retrieve_gh_token(*, token: str | None = None, description: str, scopes: str) -> str:
if token:
return token
if GITHUB_TOKEN:
return GITHUB_TOKEN
output = subprocess.check_output(["gh", "auth", "token"])
token = output.decode().strip()
if not token:
if not console:
raise RuntimeError("Please add rich to your script dependencies and run it again")
console.print(
"[red]GITHUB_TOKEN environment variable is not set. "
"This might lead to failures on rate limits.[/]\n"
"You can fix that by installing `gh` and running `gh auth login` or "
f"set it to a valid GitHub token with {scopes} scope. "
f"You can create one by clicking the URL:\n\n"
f"https://github.com/settings/tokens/new?scopes={scopes}&description={description}\n\n"
"Once you have the token you can prepend prek command with GITHUB_TOKEN='<your token>' or"
"set it in your environment with export GITHUB_TOKEN='<your token>'\n\n"
)
sys.exit(1)
return token
def parse_operations(
    operations_file: Path, exclude_operation_classes: set, exclude_methods: set
) -> dict[str, list[str]]:
    """
    Parse airflowctl operations file and return a mapping of CLI group names to subcommands.

    Every class whose name ends with ``Operations`` (and is not excluded) becomes a group named
    after the class with ``Operations`` removed and lower-cased. Every function defined directly
    in the class body becomes a subcommand with underscores replaced by dashes, unless its name
    is excluded or starts with an underscore.

    :param operations_file: Python source file to parse
    :param exclude_operation_classes: class names to ignore
    :param exclude_methods: method names to ignore
    """
    with open(operations_file) as source:
        module = ast.parse(source.read(), filename=str(operations_file))
    commands: dict[str, list[str]] = {}
    for node in ast.walk(module):
        if not isinstance(node, ast.ClassDef):
            continue
        class_name = node.name
        if not class_name.endswith("Operations") or class_name in exclude_operation_classes:
            continue
        commands[class_name.replace("Operations", "").lower()] = [
            member.name.replace("_", "-")
            for member in node.body
            if isinstance(member, ast.FunctionDef)
            and member.name not in exclude_methods
            and not member.name.startswith("_")
        ]
    return commands