| #!/usr/bin/env python |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """Migration script to port an Apache Airflow metadata DB to another DB engine. |
| |
| Call it with `python migrate_script.py --extract` on the source environment |
| to dump the configured airflow metadata database to a SQLite file called |
| `migration.db`. Then you can copy this file to the target environment |
| (or re-configure your database backend) and run `python migrate_script.py --restore`. |
| |
| Note that it is common sense probably that this script is assuming that schedulers, |
| workers, triggerer and webserver are stopped while running this script. It is also |
| advised to halt all running jobs and DAG runs. |
| |
| Database schema must match. |
| |
| Note that this script is made for Airflow 2.7.3 and is provided without any warranty. |
| """ |
| |
| # TODO test 2.7.3 |
| # TODO docs |
| from __future__ import annotations |
| |
| import argparse |
| import logging |
| import os |
| import subprocess |
| |
| from sqlalchemy import create_engine, delete, select, text |
| from sqlalchemy.orm import Session |
| |
| from airflow.auth.managers.fab.models import Action, Permission, RegisterUser, Resource, Role, User |
| from airflow.jobs.job import Job |
| from airflow.models.connection import Connection |
| from airflow.models.dag import DagModel, DagOwnerAttributes, DagTag |
| from airflow.models.dagcode import DagCode |
| from airflow.models.dagpickle import DagPickle |
| from airflow.models.dagrun import DagRun, DagRunNote |
| from airflow.models.dagwarning import DagWarning |
| from airflow.models.dataset import ( |
| DagScheduleDatasetReference, |
| DatasetDagRunQueue, |
| DatasetEvent, |
| DatasetModel, |
| TaskOutletDatasetReference, |
| ) |
| from airflow.models.db_callback_request import DbCallbackRequest |
| from airflow.models.errors import ImportError |
| from airflow.models.log import Log |
| from airflow.models.pool import Pool |
| from airflow.models.renderedtifields import RenderedTaskInstanceFields |
| from airflow.models.serialized_dag import SerializedDagModel |
| from airflow.models.slamiss import SlaMiss |
| from airflow.models.taskfail import TaskFail |
| from airflow.models.taskinstance import TaskInstance, TaskInstanceNote |
| from airflow.models.tasklog import LogTemplate |
| from airflow.models.taskmap import TaskMap |
| from airflow.models.taskreschedule import TaskReschedule |
| from airflow.models.trigger import Trigger |
| from airflow.models.variable import Variable |
| from airflow.models.xcom import BaseXCom, XCom |
| from airflow.settings import engine |
| |
| # configuration variables |
| airflow_db_url = engine.url |
| temp_db_url = "sqlite:///migration.db" |
| supported_db_versions = [ |
| # see https://airflow.apache.org/docs/apache-airflow/stable/migrations-ref.html |
| "405de8318b3a" # Airflow 2.7.3 |
| ] |
| |
| # initialise logging |
| logging.basicConfig(filename="migration.log", level=logging.DEBUG) |
| |
| |
| def copy_airflow_tables(source_engine, target_engine): |
| objects_to_migrate = [ |
| Action, |
| Resource, |
| Role, |
| User, |
| Permission, |
| RegisterUser, |
| Connection, |
| DagModel, |
| DagCode, |
| DagOwnerAttributes, |
| DagPickle, |
| DagTag, |
| Job, |
| LogTemplate, |
| DagRun, |
| DagRunNote, |
| DagWarning, |
| DatasetModel, |
| DagScheduleDatasetReference, |
| TaskOutletDatasetReference, |
| DatasetDagRunQueue, |
| DatasetEvent, |
| DbCallbackRequest, |
| ImportError, |
| Log, |
| Pool, |
| Trigger, |
| RenderedTaskInstanceFields, |
| SerializedDagModel, |
| SlaMiss, |
| TaskInstance, |
| TaskInstanceNote, |
| TaskFail, |
| TaskMap, |
| TaskReschedule, |
| Variable, |
| BaseXCom, |
| "ab_user_role", # besides the ORM objects some table which demand cleaning are listed |
| "ab_permission_view", |
| "ab_permission_view_role", |
| ] |
| source_session = Session(bind=source_engine) |
| target_session = Session(bind=target_engine) |
| target_session.autoflush = False |
| dialect_name = target_session.bind.dialect.name |
| quote = "`" if dialect_name == "mysql" else '"' |
| |
| # check that source DB is a supported version |
| db_version = target_session.scalar("SELECT * FROM alembic_version") |
| if db_version not in supported_db_versions: |
| raise ValueError(f"Unsupported Airflow Schema version {db_version}") |
| source_version = source_session.scalar("SELECT * FROM alembic_version") |
| if source_version != db_version: |
| raise ValueError( |
| f"Database schema must match. Source is {source_version}, destination is {db_version}." |
| ) |
| |
| # Deserialization fails, but we want to transfer the blob as original anyway, mock serialization away |
| def deserialize_mock(self: XCom): |
| return self.value |
| |
| BaseXCom.orm_deserialize_value = deserialize_mock |
| |
| # Step 1 - delete any leftovers, ensure all tables to be migrated are empty - use reverse order |
| for clz in reversed(objects_to_migrate): |
| if isinstance(clz, str): |
| logging.info("Cleaning table %s", clz) |
| target_session.execute(f"DELETE FROM {quote}{clz}{quote}") |
| else: |
| logging.info("Cleaning table %s", clz.__tablename__) |
| if clz == User: |
| # The user has a self-constraint, need to delete in batches not to violate cross-dependencies |
| continue_delete = True |
| while continue_delete: |
| filter_uids = set() |
| for uid in target_session.execute( |
| text("SELECT changed_by_fk FROM ab_user WHERE changed_by_fk IS NOT NULL") |
| ).fetchall(): |
| filter_uids.add(uid) |
| for uid in target_session.execute( |
| text("SELECT created_by_fk FROM ab_user WHERE created_by_fk IS NOT NULL") |
| ).fetchall(): |
| filter_uids.add(uid) |
| uid_list = ",".join(str(uid[0]) for uid in filter_uids) |
| if uid_list: |
| continue_delete = ( |
| target_session.execute( |
| f"DELETE FROM ab_user WHERE id NOT IN ({uid_list})" |
| ).rowcount |
| > 0 |
| ) |
| else: |
| continue_delete = target_session.execute(delete(clz)).rowcount > 0 |
| else: |
| target_session.execute(delete(clz)) |
| # Step 2 - copy all data over, use only ORM mapped tables |
| for clz in objects_to_migrate: |
| count = 0 |
| if not isinstance(clz, str): |
| logging.info("Migration of %s started", clz.__tablename__) |
| for item in source_session.scalars(select(clz)).unique(): |
| target_session.merge(item) |
| target_session.flush() |
| count += 1 |
| if count % 100 == 0: |
| logging.info("Migration of chunk finished, %i migrated", count) |
| logging.info("Migration of %s finished with %i rows", clz.__tablename__, count) |
| target_session.commit() |
| # Step 3 - update sequences to ensure new records continue with valid IDs auto-generated |
| for clz in objects_to_migrate: |
| count = 0 |
| if not isinstance(clz, str) and "id" in clz.__dict__: |
| logging.info("Resetting sequence value for %s", clz.__tablename__) |
| max = target_session.scalar(f"SELECT MAX(id) FROM {quote}{clz.__tablename__}{quote}") |
| if max: |
| if dialect_name == "postgresql": |
| target_session.execute( |
| f"ALTER SEQUENCE {quote}{clz.__tablename__}_id_seq{quote} RESTART WITH {max+1}" |
| ) |
| elif dialect_name == "sqlite": |
| pass # nothing to be done for sqlite |
| elif dialect_name == "mysql": |
| target_session.execute( |
| f"ALTER TABLE {quote}{clz.__tablename__}{quote} AUTO_INCREMENT = {max+1}" |
| ) |
| else: # e.g. "mssql" |
| raise Exception(f"Database type {dialect_name} not supported") |
| target_session.commit() |
| |
| |
| def main(extract: bool, restore: bool): |
| if extract == restore: |
| raise ValueError("Please specify what you want to do! Or use --help") |
| |
| if extract: |
| logging.info("Creating migration database %s", temp_db_url) |
| envs = os.environ.copy() |
| envs["AIRFLOW__DATABASE__SQL_ALCHEMY_CONN"] = temp_db_url |
| subprocess.check_call(args=["airflow", "db", "reset", "--yes"], env=envs) |
| |
| # source and target database |
| airflow_engine = create_engine(airflow_db_url) |
| temp_engine = create_engine(temp_db_url) |
| logging.info("Connection to databases established") |
| |
| if extract: |
| copy_airflow_tables(airflow_engine, temp_engine) |
| else: |
| copy_airflow_tables(temp_engine, airflow_engine) |
| logging.info("Migration completed!") |
| |
| |
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument( |
| "--extract", help="Extracts the current Airflow database to a SQLite file", action="store_true" |
| ) |
| parser.add_argument( |
| "--restore", help="Restores from a SQLite to the current Airflow database", action="store_true" |
| ) |
| args = parser.parse_args() |
| main(args.extract, args.restore) |