blob: 8a3ad52e1f6f96dc7e707426cc8d511e33df8a78 [file] [log] [blame]
#!/usr/bin/env python3
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Module to update db migration information in Airflow
"""
from __future__ import annotations
import os
import re
import textwrap
from pathlib import Path
from typing import TYPE_CHECKING, Iterable
from alembic.script import ScriptDirectory
from rich.console import Console
from tabulate import tabulate
from airflow import __version__ as airflow_version
from airflow.utils.db import _get_alembic_config
if TYPE_CHECKING:
from alembic.script import Script
# Shared rich console used for the progress messages printed by this script.
console = Console(color_system="standard", width=400)

# Keep only the leading "X.Y.Z" part of the imported Airflow version string.
airflow_version = re.match(r"(\d+\.\d+\.\d+).*", airflow_version)[1]  # type: ignore

# Resolved third ancestor directory of this file (``parents[2]``).
project_root = Path(__file__).parents[2].resolve()
def replace_text_between(file: Path, start: str, end: str, replacement_text: str):
    """
    Replace the text located between two marker strings in ``file``.

    The file is rewritten as: everything before the first ``start`` marker, the
    ``start`` marker itself, ``replacement_text``, the ``end`` marker, and then
    everything that follows the first ``end`` marker found after ``start``.

    :param file: path of the file to rewrite in place
    :param start: marker opening the replaced region (kept in the output)
    :param end: marker closing the replaced region (kept in the output)
    :param replacement_text: text written between the two markers
    :raises ValueError: if ``start`` is not present, or ``end`` does not occur
        after it; the file is not modified in that case
    """
    original_text = file.read_text()

    # ``partition`` returns an empty separator when the marker is absent, which lets us
    # fail loudly instead of writing a corrupted file.
    leading_text, found_start, remainder = original_text.partition(start)
    if not found_start:
        raise ValueError(f"Start marker {start!r} not found in {file}")

    # Look for the end marker only after the start marker, and keep *everything* after it.
    _, found_end, trailing_text = remainder.partition(end)
    if not found_end:
        raise ValueError(f"End marker {end!r} not found after the start marker in {file}")

    file.write_text(leading_text + start + replacement_text + end + trailing_text)
def wrap_backticks(val):
    """
    Surround ``val`` with double backticks.

    A tuple or list is rendered item by item, each item wrapped separately and the
    items joined with a comma followed by a newline; any other value is wrapped as a whole.
    """
    if not isinstance(val, (tuple, list)):
        return f"``{val}``"
    return ",\n".join(f"``{item}``" for item in val)
def update_doc(file, data):
    """
    Rewrite the auto-generated table section of ``file`` from ``data``.

    The table is rendered with ``tabulate`` in ``grid`` format and placed between the
    "Beginning of auto-generated table" and "End of auto-generated table" markers.

    :param file: document whose marked section is replaced
    :param data: rows handed to ``tabulate`` as ``tabular_data``
    """
    column_titles = {
        "revision": "Revision ID",
        "down_revision": "Revises ID",
        "version": "Airflow Version",
        "description": "Description",
    }
    table = tabulate(
        tabular_data=data,
        headers=column_titles,
        tablefmt="grid",
        stralign="left",
        disable_numparse=True,
    )
    replace_text_between(
        file=file,
        start=" .. Beginning of auto-generated table\n",
        end=" .. End of auto-generated table\n",
        replacement_text="\n" + table + "\n\n",
    )
def has_version(content):
    """Tell whether some line of ``content`` begins with an ``airflow_version =`` assignment."""
    version_line = re.search(r"^airflow_version\s*=.*", content, flags=re.MULTILINE)
    return version_line is not None
def insert_version(old_content, file):
    """
    Write ``old_content`` to ``file`` with an ``airflow_version`` assignment added.

    The assignment is put on a new line directly below every line that starts
    with ``depends_on``; the value is the module-level ``airflow_version``.

    :param old_content: current source text of the migration module
    :param file: path the updated text is written to
    """

    def _append_version_line(depends_on_line):
        return f"{depends_on_line.group(1)}\nairflow_version = '{airflow_version}'"

    updated = re.sub(r"(^depends_on.*)", _append_version_line, old_content, flags=re.MULTILINE)
    file.write_text(updated)
def revision_suffix(rev: Script):
    """
    Return the label appended to a revision id in the generated table.

    The flags are checked in a fixed order (head, base, merge point, branch point) and
    the label of the first one that is set wins; an empty string is returned when
    none of them is set.
    """
    labels_by_flag = (
        ("is_head", " (head)"),
        ("is_base", " (base)"),
        ("is_merge_point", " (merge_point)"),
        ("is_branch_point", " (branch_point)"),
    )
    for flag_name, label in labels_by_flag:
        if getattr(rev, flag_name):
            return label
    return ""
def ensure_airflow_version(revisions: Iterable[Script]):
    """
    Make sure the source of every migration module contains an ``airflow_version`` line.

    Each module file is read, and those for which ``has_version`` reports no such
    line are rewritten through ``insert_version``.

    :param revisions: migration scripts to check
    """
    for revision in revisions:
        module_file = revision.module.__file__
        if TYPE_CHECKING:  # For mypy
            assert module_file is not None
        migration_path = Path(module_file)
        source = migration_path.read_text()
        if has_version(source):
            continue
        insert_version(source, migration_path)
def get_revisions() -> Iterable[Script]:
    """
    Lazily yield the migration scripts of the alembic script directory.

    The directory is built from the configuration returned by ``_get_alembic_config``
    and the scripts are produced in the order given by ``walk_revisions``.
    """
    script_directory = ScriptDirectory.from_config(_get_alembic_config())
    yield from script_directory.walk_revisions()
def update_docs(revisions: Iterable[Script]):
    """
    Regenerate the migrations reference table from ``revisions``.

    One row is produced per revision: its id (with a head/base/merge/branch label),
    the id(s) it revises, the ``airflow_version`` declared by its module and its
    doc text wrapped at 60 columns. The rows are then written into
    ``docs/apache-airflow/migrations-ref.rst`` under the project root.

    :param revisions: migration scripts, in the order the rows should appear
    """
    doc_data = [
        {
            "revision": wrap_backticks(rev.revision) + revision_suffix(rev),
            "down_revision": wrap_backticks(rev.down_revision),
            "version": wrap_backticks(rev.module.airflow_version),  # type: ignore
            "description": textwrap.fill(rev.doc, width=60),
        }
        for rev in revisions
    ]
    migrations_ref = project_root / "docs" / "apache-airflow" / "migrations-ref.rst"
    update_doc(file=migrations_ref, data=doc_data)
def ensure_mod_prefix(mod_name, idx, version):
    """
    Build the standardized file name for a migration module.

    The result is ``<NNNN>_<version tokens>_<rest>`` joined with underscores, where
    ``NNNN`` is ``idx + 1`` zero-padded to four digits.

    :param mod_name: current file name of the migration module
    :param idx: zero-based position of the migration in the revision sequence
    :param version: iterable of version tokens (strings), e.g. ``["2", "5", "0"]``
    :return: the file name the module should have
    """
    parts = [f"{idx + 1:04}", *version]

    standardized = re.match(r"([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)_(.+)", mod_name)
    if standardized:
        # previously standardized file, rebuild the name: drop the four numeric
        # prefix tokens and keep what follows them
        parts.append(standardized.group(5))
        return "_".join(parts)

    # new migration file, standard format: drop the leading lowercase-alphanumeric
    # token and keep what follows it
    fresh = re.match(r"([a-z0-9]+)_(.+)", mod_name)
    if fresh:
        parts.append(fresh.group(2))
    else:
        # Name matches neither pattern: keep it whole so the description and the
        # file extension are never silently dropped from the new name.
        parts.append(mod_name)
    return "_".join(parts)
def ensure_filenames_are_sorted(revisions):
    """
    Rename migration files so their names carry the expected sequence/version prefix.

    The expected name of every module is computed with ``ensure_mod_prefix`` from its
    position in ``revisions`` and the first three tokens of its ``airflow_version``.
    While walking the revisions, branch points and merge points are tracked; if a
    branch is still open at the end, the script exits with a hint on how to merge the
    heads and no file is renamed.

    :param revisions: migration scripts in the order that defines the numbering
    :raises SystemExit: when a branch point was seen without a later merge point
    """
    pending_renames = []
    open_branch = False
    heads = []
    for position, revision in enumerate(revisions):
        current_path = Path(revision.module.__file__)
        # only first 3 tokens
        version_tokens = revision.module.airflow_version.split(".")[:3]
        expected_name = ensure_mod_prefix(current_path.name, position, version_tokens)
        if expected_name != current_path.name:
            pending_renames.append((current_path, Path(current_path.parent, expected_name)))

        # The merge check must run before the branch check: a merge point closes the
        # branch opened earlier, and the same revision may open a new one.
        if open_branch and revision.is_merge_point:
            open_branch = False
        if revision.is_branch_point:
            open_branch = True
        elif revision.is_head:
            heads.append(revision.revision)

    if open_branch:
        short_heads = ", ".join(head[:4] for head in heads)
        all_heads = " ".join(heads)
        alembic_command = f"alembic merge -m 'merge heads {short_heads}' {all_heads}"
        raise SystemExit(
            "You have multiple alembic heads; please merge them with by running `alembic merge` command under "
            '"airflow" directory (where alembic.ini located) and re-run pre-commit. '
            f"It should fail once more before succeeding.\nhint: `{alembic_command}`"
        )

    # Renames are applied only after the whole sequence has been checked.
    for source_path, target_path in pending_renames:
        os.rename(source_path, target_path)
if __name__ == "__main__":
    console.print("[bright_blue]Updating migration reference")
    revisions = list(get_revisions())
    revisions.reverse()

    console.print("[bright_blue]Making sure airflow version updated")
    ensure_airflow_version(revisions=revisions)

    # Revisions are loaded again after each step that may rewrite or rename migration files.
    revisions = list(get_revisions())
    revisions.reverse()

    console.print("[bright_blue]Making sure filenames are sorted")
    ensure_filenames_are_sorted(revisions=revisions)

    # The documentation table uses the order produced by get_revisions() (not reversed).
    revisions = list(get_revisions())

    console.print("[bright_blue]Updating documentation")
    update_docs(revisions=revisions)

    console.print("[green]Migrations OK")