blob: fa8696e8f72c12501f78143f1f6ce1c0fe30aed8 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This module took inspiration from the community maintenance dag
(https://github.com/teamclairvoyant/airflow-maintenance-dags/blob/4e5c7682a808082561d60cbc9cafaa477b0d8c65/db-cleanup/airflow-db-cleanup.py).
"""
from __future__ import annotations
import csv
import logging
import os
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any
from pendulum import DateTime
from sqlalchemy import and_, column, false, func, inspect, table, text
from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import Query, Session, aliased
from sqlalchemy.sql.expression import ClauseElement, Executable, tuple_
from airflow import AirflowException
from airflow.cli.simple_table import AirflowConsole
from airflow.models import Base
from airflow.utils import timezone
from airflow.utils.db import reflect_tables
from airflow.utils.helpers import ask_yesno
from airflow.utils.session import NEW_SESSION, provide_session
# Module-level logger. It is named after the module's import path (``__name__``) rather
# than its filesystem path, so it takes part in the dotted logger hierarchy and can be
# configured like any other module logger.
logger = logging.getLogger(__name__)

# Prefix shared by every archive table this module creates; it is used both to build new
# archive table names and to recognise existing archive tables.
ARCHIVE_TABLE_PREFIX = "_airflow_deleted__"
@dataclass
class _TableConfig:
    """
    Describe how one table is cleaned up.

    :param table_name: the table
    :param recency_column_name: date column to filter by
    :param extra_columns: any columns besides recency_column_name that we'll need in queries
    :param keep_last: whether the last record should be kept even if it's older than
        clean_before_timestamp
    :param keep_last_filters: the "keep last" functionality will preserve the most recent record
        in the table.  to ignore certain records even if they are the latest in the table, you can
        supply additional filters here (e.g. externally triggered dag runs)
    :param keep_last_group_by: if keeping the last record, can keep the last record for each group
    """

    table_name: str
    recency_column_name: str
    extra_columns: list[str] | None = None
    keep_last: bool = False
    keep_last_filters: Any | None = None
    keep_last_group_by: Any | None = None

    def __post_init__(self):
        # Derived attributes: lightweight SQL constructs built purely from the configured
        # names. The recency column is always the last column of the table construct.
        self.recency_column = column(self.recency_column_name)
        additional_columns = [column(name) for name in (self.extra_columns or [])]
        self.orm_model: Base = table(self.table_name, *additional_columns, self.recency_column)

    def __lt__(self, other):
        # Configs order by table name, which is what makes ``sorted()`` work on them.
        return self.table_name < other.table_name

    @property
    def readable_config(self):
        """Return a printable, string-based summary of this config as a dict."""
        rendered_filters = None
        if self.keep_last_filters:
            rendered_filters = [str(criterion) for criterion in self.keep_last_filters]
        return {
            "table": self.orm_model.name,
            "recency_column": str(self.recency_column),
            "keep_last": self.keep_last,
            "keep_last_filters": rendered_filters,
            "keep_last_group_by": str(self.keep_last_group_by),
        }
# Tables eligible for cleanup, each paired with the timestamp column that decides whether
# a row is old enough to purge.
config_list: list[_TableConfig] = [
    _TableConfig(table_name="job", recency_column_name="latest_heartbeat"),
    _TableConfig(table_name="dag", recency_column_name="last_parsed_time"),
    # dag_run is the only table using "keep last": for every dag_id, the most recent run
    # that was not externally triggered is preserved even if it is older than the cutoff.
    _TableConfig(
        table_name="dag_run",
        recency_column_name="start_date",
        extra_columns=["dag_id", "external_trigger"],
        keep_last=True,
        keep_last_filters=[column("external_trigger") == false()],
        keep_last_group_by=["dag_id"],
    ),
    _TableConfig(table_name="dataset_event", recency_column_name="timestamp"),
    _TableConfig(table_name="import_error", recency_column_name="timestamp"),
    _TableConfig(table_name="log", recency_column_name="dttm"),
    _TableConfig(table_name="sla_miss", recency_column_name="timestamp"),
    _TableConfig(table_name="task_fail", recency_column_name="start_date"),
    _TableConfig(table_name="task_instance", recency_column_name="start_date"),
    _TableConfig(table_name="task_reschedule", recency_column_name="start_date"),
    _TableConfig(table_name="xcom", recency_column_name="timestamp"),
    _TableConfig(table_name="callback_request", recency_column_name="created_at"),
    # NOTE(review): the celery_* tables presumably belong to the Celery result backend and
    # may not exist in every deployment -- confirm. run_cleanup skips configured tables
    # that are missing from the database.
    _TableConfig(table_name="celery_taskmeta", recency_column_name="date_done"),
    _TableConfig(table_name="celery_tasksetmeta", recency_column_name="date_done"),
]

# Lookup of configs keyed by the table construct's name. sorted() relies on
# _TableConfig.__lt__, so entries are inserted in table-name order.
config_dict: dict[str, _TableConfig] = {x.orm_model.name: x for x in sorted(config_list)}
def _check_for_rows(*, query: Query, print_rows=False):
    """
    Count the rows matched by ``query`` and report the count on stdout.

    :param query: query selecting the rows that meet the deletion criteria
    :param print_rows: if true, also print up to the first 100 matching rows
    :return: the number of rows matched by ``query``
    """
    row_count = query.count()
    print(f"Found {row_count} rows meeting deletion criteria.")
    if not print_rows:
        return row_count

    max_rows_to_print = 100
    if row_count > 0:
        print(f"Printing first {max_rows_to_print} rows.")
    logger.debug("print entities query: %s", query)
    for row in query.limit(max_rows_to_print):
        print(row.__dict__)
    return row_count
def _dump_table_to_file(*, target_table, file_path, export_format, session):
    """
    Write every row of ``target_table`` to ``file_path``.

    The first line written is the header row (the column names of the result), followed
    by one line per table row.

    :param target_table: name of the table to export. It is interpolated verbatim into
        the ``SELECT`` statement, so it must be a trusted identifier.
    :param file_path: destination file; created, or overwritten if it already exists
    :param export_format: output format; only ``"csv"`` is supported
    :param session: session used to run the query
    :raises AirflowException: if ``export_format`` is not supported
    """
    if export_format != "csv":
        raise AirflowException(f"Export format {export_format} is not supported.")

    # Run the query before opening the file: opening in "w" mode truncates, so a failing
    # query would otherwise wipe out a previous export and leave an empty file behind.
    cursor = session.execute(text(f"SELECT * FROM {target_table}"))

    # newline="" is required by the csv module; without it the writer's own line
    # terminators get translated again and produce blank lines on some platforms.
    with open(file_path, "w", newline="") as f:
        csv_writer = csv.writer(f)
        csv_writer.writerow(cursor.keys())
        csv_writer.writerows(cursor.fetchall())
def _do_delete(*, query, orm_model, skip_archive, session):
    """
    Move the rows selected by ``query`` into an archive table, then purge them.

    The rows are first copied into a new table named
    ``<ARCHIVE_TABLE_PREFIX><table>__<YYYYMMDDHHMMSS>`` and committed. Rows of the source
    table whose primary key appears in the archive table are then deleted and committed.
    When ``skip_archive`` is true the archive table is dropped again at the end.

    :param query: query selecting the rows to remove
    :param orm_model: table construct of the source table; only its ``name`` is used
    :param skip_archive: if true, drop the archive table once the rows are deleted
    :param session: session used for all statements
    """
    import re
    from datetime import datetime

    print("Performing Delete...")
    # using bulk delete
    # create a new table and copy the rows there
    # Keep only the digits of the ISO timestamp; the first 14 are YYYYMMDDHHMMSS.
    timestamp_str = re.sub(r"[^\d]", "", datetime.utcnow().isoformat())[:14]
    target_table_name = f"{ARCHIVE_TABLE_PREFIX}{orm_model.name}__{timestamp_str}"
    print(f"Moving data to table {target_table_name}")
    bind = session.get_bind()
    dialect_name = bind.dialect.name
    if dialect_name == "mysql":
        # MySQL with replication needs this split into two queries, so just do it for all MySQL
        # ERROR 1786 (HY000): Statement violates GTID consistency: CREATE TABLE ... SELECT.
        # Raw SQL is wrapped in text(), as everywhere else in this module; passing a bare
        # string to Session.execute is deprecated in SQLAlchemy 1.4 and rejected in 2.0.
        session.execute(text(f"CREATE TABLE {target_table_name} LIKE {orm_model.name}"))
        metadata = reflect_tables([target_table_name], session)
        target_table = metadata.tables[target_table_name]
        insert_stm = target_table.insert().from_select(target_table.c, query)
        logger.debug("insert statement:\n%s", insert_stm.compile())
        session.execute(insert_stm)
    else:
        stmt = CreateTableAs(target_table_name, query.selectable)
        logger.debug("ctas query:\n%s", stmt.compile())
        session.execute(stmt)
    session.commit()

    # delete the rows from the old table
    metadata = reflect_tables([orm_model.name, target_table_name], session)
    source_table = metadata.tables[orm_model.name]
    target_table = metadata.tables[target_table_name]
    logger.debug("rows moved; purging from %s", source_table.name)
    pk_cols = source_table.primary_key.columns
    if dialect_name == "sqlite":
        # SQLite gets a "primary key IN (subquery over the archive table)" delete.
        archived_keys = session.query(*[target_table.c[x.name] for x in pk_cols]).subquery()
        delete = source_table.delete().where(tuple_(*pk_cols).in_(archived_keys))
    else:
        # Other dialects match source rows to archive rows column by column on the
        # primary key.
        delete = source_table.delete().where(
            and_(*[col == target_table.c[col.name] for col in pk_cols])
        )
    logger.debug("delete statement:\n%s", delete.compile())
    session.execute(delete)
    session.commit()
    if skip_archive:
        # Pass the bind explicitly rather than depending on the reflected metadata
        # carrying an implicit one.
        target_table.drop(bind=bind)
    session.commit()
    print("Finished Performing Delete")
def _subquery_keep_last(*, recency_column, keep_last_filters, group_by_columns, max_date_colname, session):
    """
    Build the subquery that finds the most recent ``recency_column`` value per group.

    :param recency_column: column whose maximum is computed
    :param keep_last_filters: optional iterable of filter criteria applied before aggregating
    :param group_by_columns: columns selected alongside the maximum and grouped by
    :param max_date_colname: label given to the computed maximum
    :param session: session used to construct the query
    :return: a subquery named ``latest``
    """
    latest_per_group = session.query(
        *group_by_columns,
        func.max(recency_column).label(max_date_colname),
    )
    for criterion in [] if keep_last_filters is None else keep_last_filters:
        latest_per_group = latest_per_group.filter(criterion)
    if group_by_columns is not None:
        latest_per_group = latest_per_group.group_by(*group_by_columns)
    return latest_per_group.subquery(name="latest")
class CreateTableAs(Executable, ClauseElement):
    """
    Custom sqlalchemy clause element for CTAS operations.

    :param name: name of the table to create
    :param query: selectable whose result populates the new table
    """

    # This construct defines no cache key, so opt out of SQLAlchemy's compiled-statement
    # cache explicitly. Leaving the attribute unset makes SQLAlchemy 1.4+ emit a warning
    # for custom constructs.
    inherit_cache = False

    def __init__(self, name, query):
        self.name = name
        self.query = query
@compiles(CreateTableAs)
def _compile_create_table_as__other(element, compiler, **kw):
    """Render ``CreateTableAs`` as ``CREATE TABLE <name> AS <select>`` (default for all dialects)."""
    select_sql = compiler.process(element.query)
    return f"CREATE TABLE {element.name} AS {select_sql}"
@compiles(CreateTableAs, "mssql")
def _compile_create_table_as__mssql(element, compiler, **kw):
    """Render ``CreateTableAs`` for the mssql dialect as ``SELECT * INTO <name>`` from a CTE."""
    select_sql = compiler.process(element.query)
    return f"WITH cte AS ( {select_sql} ) SELECT * INTO {element.name} FROM cte"
def _build_query(
    *,
    orm_model,
    recency_column,
    keep_last,
    keep_last_filters,
    keep_last_group_by,
    clean_before_timestamp,
    session,
    **kwargs,
):
    """
    Build the query selecting the rows that should be purged from a table.

    The table is aliased as ``base`` and the query selects ``base.*`` for rows whose
    recency column is older than ``clean_before_timestamp``. With ``keep_last`` set, the
    table is outer-joined to the per-group "latest record" subquery, and only rows that
    found no match there (i.e. rows that are not the latest of their group) are kept.

    :param orm_model: table construct to query
    :param recency_column: column used to decide a row's age
    :param keep_last: whether the most recent record (per group) must be preserved
    :param keep_last_filters: extra criteria for the "latest record" subquery
    :param keep_last_group_by: column names defining the groups for ``keep_last``
    :param clean_before_timestamp: rows strictly older than this are selected
    :param session: session used to construct the query
    :param kwargs: ignored; lets callers pass a whole config dict
    :return: the resulting query
    """
    alias_name = "base"
    base = aliased(orm_model, name=alias_name)
    query = session.query(base).with_entities(text(f"{alias_name}.*"))
    base_recency_col = base.c[recency_column.name]
    criteria = [base_recency_col < clean_before_timestamp]

    if keep_last:
        max_date_col_name = "max_date_per_group"
        latest = _subquery_keep_last(
            recency_column=recency_column,
            keep_last_filters=keep_last_filters,
            group_by_columns=[column(name) for name in keep_last_group_by],
            max_date_colname=max_date_col_name,
            session=session,
        )
        # Join on every grouping column plus "this row carries the group's max date".
        join_on = [base.c[name] == latest.c[name] for name in keep_last_group_by]
        join_on.append(base_recency_col == column(max_date_col_name))
        query = query.select_from(base).outerjoin(latest, and_(*join_on))
        # Rows without a join partner are not the latest of their group.
        criteria.append(column(max_date_col_name).is_(None))

    return query.filter(and_(*criteria))
def _cleanup_table(
    *,
    orm_model,
    recency_column,
    keep_last,
    keep_last_filters,
    keep_last_group_by,
    clean_before_timestamp,
    dry_run=True,
    verbose=False,
    skip_archive=False,
    session,
    **kwargs,
):
    """
    Report, and unless this is a dry run delete, the old rows of a single table.

    :param orm_model: table construct to clean
    :param recency_column: column used to decide a row's age
    :param keep_last: whether the most recent record (per group) must be preserved
    :param keep_last_filters: extra criteria for the "keep last" logic
    :param keep_last_group_by: column names defining the groups for ``keep_last``
    :param clean_before_timestamp: rows older than this are purged
    :param dry_run: if true, only count the matching rows
    :param verbose: accepted for interface compatibility; not used here
    :param skip_archive: if true, the archive table is dropped after the delete
    :param session: session used for all statements
    :param kwargs: ignored; lets callers pass a whole config dict
    """
    print()
    if dry_run:
        print(f"Performing dry run for table {orm_model.name}")

    query_args = {
        "orm_model": orm_model,
        "recency_column": recency_column,
        "keep_last": keep_last,
        "keep_last_filters": keep_last_filters,
        "keep_last_group_by": keep_last_group_by,
        "clean_before_timestamp": clean_before_timestamp,
        "session": session,
    }
    query = _build_query(**query_args)
    logger.debug("old rows query:\n%s", query.selectable.compile())

    print(f"Checking table {orm_model.name}")
    num_rows = _check_for_rows(query=query, print_rows=False)

    should_delete = bool(num_rows) and not dry_run
    if should_delete:
        _do_delete(query=query, orm_model=orm_model, skip_archive=skip_archive, session=session)

    session.commit()
def _confirm_delete(*, date: DateTime, tables: list[str]):
    """
    Ask the user on stdin to confirm the purge.

    :param date: cutoff shown in the prompt
    :param tables: tables shown in the prompt; omitted from the text when empty
    :raises SystemExit: unless the user types exactly ``delete rows``
        (surrounding whitespace is ignored)
    """
    table_clause = ""
    if tables:
        table_clause = f" for tables {tables!r}"
    prompt = (
        f"You have requested that we purge all data prior to {date}{table_clause}.\n"
        "This is irreversible.  Consider backing up the tables first and / or doing a dry run "
        "with option --dry-run.\n"
        "Enter 'delete rows' (without quotes) to proceed."
    ).replace("irreversible.  Consider", "irreversible. Consider")
    print(prompt)
    if input().strip() != "delete rows":
        raise SystemExit("User did not confirm; exiting.")
def _confirm_drop_archives(*, tables: list[str]):
    """
    Ask the user on stdin to confirm dropping the given archive tables.

    With more than three tables only their count is shown up front, and the user is
    offered the chance to list them before confirming.

    :param tables: names of the archive tables about to be dropped
    :raises SystemExit: unless the user types exactly ``drop archived tables``
        (surrounding whitespace is ignored)
    """
    many_tables = len(tables) > 3
    # if length of tables is greater than 3, show the total count
    if many_tables:
        description = f"{len(tables)} archived tables prefixed with {ARCHIVE_TABLE_PREFIX}"
    else:
        description = f"the following archived tables {tables}"
    print(
        f"You have requested that we drop {description}.\n"
        "This is irreversible. Consider backing up the tables first \n"
    )
    if many_tables and ask_yesno("Show tables? (y/n): "):
        print(tables, "\n")
    reply = input("Enter 'drop archived tables' (without quotes) to proceed.\n").strip()
    if reply != "drop archived tables":
        raise SystemExit("User did not confirm; exiting.")
def _print_config(*, configs: dict[str, _TableConfig]):
    """
    Print the given table configs as a table on the console.

    :param configs: configs to display, keyed by table name
    """
    rows = [cfg.readable_config for cfg in configs.values()]
    console = AirflowConsole()
    console.print_as_table(data=rows)
@contextmanager
def _suppress_with_logging(table, session):
    """
    Suppress database errors raised inside the ``with`` body, logging them instead.

    Only ``OperationalError`` and ``ProgrammingError`` are swallowed; any other exception
    propagates. A warning naming the table is logged, the traceback is logged at debug
    level, and the session is rolled back if it is still active.

    :param table: table name, used in the log messages only
    :param session: session to roll back after a suppressed error
    """
    try:
        yield
    except (OperationalError, ProgrammingError):
        logger.warning("Encountered error when attempting to clean table '%s'. ", table)
        logger.debug("Traceback for table '%s'", table, exc_info=True)
        if not session.is_active:
            return
        logger.debug("Rolling back transaction")
        session.rollback()
def _effective_table_names(*, table_names: list[str] | None):
    """
    Resolve the requested table names against the known cleanup configs.

    :param table_names: requested tables; ``None`` or empty selects every configured table
    :return: tuple of (set of valid table names, dict of their configs keyed by name)
    :raises SystemExit: if none of the requested names is a configured table
    """
    requested = set(table_names or config_dict)
    selected_configs = {name: cfg for name, cfg in config_dict.items() if name in requested}
    selected_names = set(selected_configs)

    # Everything selected was requested, so any difference is a set of unknown names.
    unknown = requested - selected_names
    if unknown:
        logger.warning(
            "The following table(s) are not valid choices and will be skipped: %s", sorted(unknown)
        )
    if not selected_names:
        raise SystemExit("No tables selected for db cleanup. Please choose valid table names.")
    return selected_names, selected_configs
def _get_archived_table_names(table_names, session):
    """
    List the archive tables in the database that belong to the selected source tables.

    :param table_names: source tables of interest; ``None`` or empty selects every
        configured table
    :param session: session whose bind is inspected for existing tables
    :return: list of archive table names
    """
    inspector = inspect(session.bind)
    archive_tables = [
        name for name in inspector.get_table_names() if name.startswith(ARCHIVE_TABLE_PREFIX)
    ]
    effective_table_names, _ = _effective_table_names(table_names=table_names)
    # Archive table names embed the source table as "__<table>__"; keep only the archives
    # of the selected source tables.
    markers = [f"__{name}__" for name in effective_table_names]
    return [name for name in archive_tables if any(marker in name for marker in markers)]
@provide_session
def run_cleanup(
    *,
    clean_before_timestamp: DateTime,
    table_names: list[str] | None = None,
    dry_run: bool = False,
    verbose: bool = False,
    confirm: bool = True,
    skip_archive: bool = False,
    session: Session = NEW_SESSION,
):
    """
    Purges old records in airflow metadata database.

    The last non-externally-triggered dag run will always be kept in order to ensure
    continuity of scheduled dag runs.

    Where there are foreign key relationships, deletes will cascade, so that for
    example if you clean up old dag runs, the associated task instances will
    be deleted.

    :param clean_before_timestamp: The timestamp before which data should be purged
    :param table_names: Optional. List of table names to perform maintenance on.  If list not
        provided, will perform maintenance on all tables.
    :param dry_run: If true, print rows meeting deletion criteria
    :param verbose: If true, may provide more detailed output.
    :param confirm: Require user input to confirm before processing deletions.
    :param skip_archive: Set to True if you don't want the purged rows preserved in an archive table.
    :param session: Session representing connection to the metadata database.
    """
    clean_before_timestamp = timezone.coerce_datetime(clean_before_timestamp)
    effective_table_names, effective_config_dict = _effective_table_names(table_names=table_names)

    if dry_run:
        print("Performing dry run for db cleanup.")
    print(
        f"Data prior to {clean_before_timestamp} would be purged "
        f"from tables {effective_table_names} with the following config:\n"
    )
    _print_config(configs=effective_config_dict)

    if confirm and not dry_run:
        _confirm_delete(date=clean_before_timestamp, tables=sorted(effective_table_names))

    existing_tables = reflect_tables(tables=None, session=session).tables
    for table_name, table_config in effective_config_dict.items():
        if table_name in existing_tables:
            # A database error on one table is logged and rolled back, so the remaining
            # tables are still processed.
            with _suppress_with_logging(table_name, session):
                _cleanup_table(
                    clean_before_timestamp=clean_before_timestamp,
                    dry_run=dry_run,
                    verbose=verbose,
                    **vars(table_config),
                    skip_archive=skip_archive,
                    session=session,
                )
                session.commit()
        else:
            logger.warning("Table %s not found. Skipping.", table_name)
@provide_session
def export_archived_records(
    export_format,
    output_path,
    table_names=None,
    drop_archives=False,
    needs_confirm=True,
    session: Session = NEW_SESSION,
):
    """
    Export archived data to the given output path in the given format.

    Each archive table is written to ``<output_path>/<table>.<export_format>``.

    :param export_format: format passed on to the table dump; also used as file extension
    :param output_path: directory the export files are written into
    :param table_names: source tables whose archives should be exported; ``None`` selects
        every configured table
    :param drop_archives: if true, drop each archive table right after exporting it
    :param needs_confirm: if true, ask for confirmation before dropping anything
    :param session: session used for all statements
    """
    archived_table_names = _get_archived_table_names(table_names, session)
    # Confirmation is only requested when the user asked to drop archives and there
    # actually are archive tables to drop.
    if archived_table_names and drop_archives and needs_confirm:
        _confirm_drop_archives(tables=sorted(archived_table_names))

    export_count = dropped_count = 0
    for archived_table in archived_table_names:
        logger.info("Exporting table %s", archived_table)
        destination = os.path.join(output_path, f"{archived_table}.{export_format}")
        _dump_table_to_file(
            target_table=archived_table,
            file_path=destination,
            export_format=export_format,
            session=session,
        )
        export_count += 1
        if not drop_archives:
            continue
        logger.info("Dropping archived table %s", archived_table)
        session.execute(text(f"DROP TABLE {archived_table}"))
        dropped_count += 1
    logger.info("Total exported tables: %s, Total dropped tables: %s", export_count, dropped_count)
@provide_session
def drop_archived_tables(table_names, needs_confirm, session):
    """
    Drop archived tables.

    :param table_names: source tables whose archives should be dropped; ``None`` or empty
        selects every configured table
    :param needs_confirm: if true, ask for confirmation before dropping (only when there
        is at least one archive table)
    :param session: session used for all statements
    """
    archived_table_names = _get_archived_table_names(table_names, session)
    if archived_table_names and needs_confirm:
        _confirm_drop_archives(tables=sorted(archived_table_names))

    dropped_count = 0
    for archived_table in archived_table_names:
        logger.info("Dropping archived table %s", archived_table)
        drop_statement = text(f"DROP TABLE {archived_table}")
        session.execute(drop_statement)
        dropped_count += 1
    logger.info("Total dropped tables: %s", dropped_count)