| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| from __future__ import annotations |
| |
| import collections.abc |
| import contextlib |
| import enum |
| import itertools |
| import json |
| import logging |
| import os |
| import sys |
| import time |
| import warnings |
| from dataclasses import dataclass |
| from tempfile import gettempdir |
| from typing import ( |
| TYPE_CHECKING, |
| Any, |
| Callable, |
| Generator, |
| Iterable, |
| Iterator, |
| Protocol, |
| Sequence, |
| TypeVar, |
| overload, |
| ) |
| |
| import attrs |
| from sqlalchemy import ( |
| Table, |
| and_, |
| column, |
| delete, |
| exc, |
| func, |
| inspect, |
| literal, |
| or_, |
| select, |
| table, |
| text, |
| tuple_, |
| ) |
| |
| import airflow |
| from airflow import settings |
| from airflow.configuration import conf |
| from airflow.exceptions import AirflowException |
| from airflow.models import import_all_models |
| from airflow.utils import helpers |
| |
| # TODO: remove create_session once we decide to break backward compatibility |
| from airflow.utils.session import NEW_SESSION, create_session, provide_session # noqa: F401 |
| from airflow.utils.task_instance_session import get_current_task_instance_session |
| |
| if TYPE_CHECKING: |
| from alembic.runtime.environment import EnvironmentContext |
| from alembic.script import ScriptDirectory |
| from sqlalchemy.engine import Row |
| from sqlalchemy.orm import Query, Session |
| from sqlalchemy.sql.elements import ClauseElement, TextClause |
| from sqlalchemy.sql.selectable import Select |
| |
| from airflow.models.connection import Connection |
| from airflow.typing_compat import Self |
| |
| # TODO: Import this from sqlalchemy.orm instead when switching to SQLA 2. |
| # https://docs.sqlalchemy.org/en/20/orm/mapping_api.html#sqlalchemy.orm.MappedClassProtocol |
    class MappedClassProtocol(Protocol):
        """Protocol for SQLAlchemy model base."""
| |
| __tablename__: str |
| |
| |
| T = TypeVar("T") |
| |
| log = logging.getLogger(__name__) |
| |
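# Maps each released Airflow version to the Alembic migration head shipped with it,
# so a version string (e.g. "2.8.1") can be resolved to a concrete revision id.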
| _REVISION_HEADS_MAP = { |
| "2.0.0": "e959f08ac86c", |
| "2.0.1": "82b7c48c147f", |
| "2.0.2": "2e42bb497a22", |
| "2.1.0": "a13f7613ad25", |
| "2.1.3": "97cdd93827b8", |
| "2.1.4": "ccde3e26fe78", |
| "2.2.0": "7b2661a43ba3", |
| "2.2.3": "be2bfac3da23", |
| "2.2.4": "587bdf053233", |
| "2.3.0": "b1b348e02d07", |
| "2.3.1": "1de7bc13c950", |
| "2.3.2": "3c94c427fdf6", |
| "2.3.3": "f5fcbda3e651", |
| "2.4.0": "ecb43d2a1842", |
| "2.4.2": "b0d31815b5a6", |
| "2.4.3": "e07f49787c9d", |
| "2.5.0": "290244fb8b83", |
| "2.6.0": "98ae134e6fff", |
| "2.6.2": "c804e5c76e3e", |
| "2.7.0": "405de8318b3a", |
| "2.8.0": "10b52ebd31f7", |
| "2.8.1": "88344c1d9134", |
| "2.9.0": "1949afb29106", |
| "2.9.2": "686269002441", |
| "2.10.0": "677fdbb7fc54", |
| } |
| |
| |
| def _format_airflow_moved_table_name(source_table, version, category): |
| return "__".join([settings.AIRFLOW_MOVED_TABLE_PREFIX, version.replace(".", "_"), category, source_table]) |
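
# For illustration: with the default settings.AIRFLOW_MOVED_TABLE_PREFIX of
# "_airflow_moved", _format_airflow_moved_table_name("task_instance", "2.2", "dangling")
# produces "_airflow_moved__2_2__dangling__task_instance".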
| |
| |
| @provide_session |
| def merge_conn(conn: Connection, session: Session = NEW_SESSION): |
| """Add new Connection.""" |
| if not session.scalar(select(1).where(conn.__class__.conn_id == conn.conn_id)): |
| session.add(conn) |
| session.commit() |
| |
| |
| @provide_session |
| def add_default_pool_if_not_exists(session: Session = NEW_SESSION): |
| """Add default pool if it does not exist.""" |
| from airflow.models.pool import Pool |
| |
| if not Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session): |
| default_pool = Pool( |
| pool=Pool.DEFAULT_POOL_NAME, |
| slots=conf.getint(section="core", key="default_pool_task_slot_count"), |
| description="Default pool", |
| include_deferred=False, |
| ) |
| session.add(default_pool) |
| session.commit() |
| |
| |
| @provide_session |
| def create_default_connections(session: Session = NEW_SESSION): |
| """Create default Airflow connections.""" |
| from airflow.models.connection import Connection |
| |
| merge_conn( |
| Connection( |
| conn_id="airflow_db", |
| conn_type="mysql", |
| host="mysql", |
| login="root", |
| password="", |
| schema="airflow", |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="athena_default", |
| conn_type="athena", |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="aws_default", |
| conn_type="aws", |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="azure_batch_default", |
| conn_type="azure_batch", |
| login="<ACCOUNT_NAME>", |
| password="", |
| extra="""{"account_url": "<ACCOUNT_URL>"}""", |
| ) |
| ) |
| merge_conn( |
| Connection( |
| conn_id="azure_cosmos_default", |
| conn_type="azure_cosmos", |
| extra='{"database_name": "<DATABASE_NAME>", "collection_name": "<COLLECTION_NAME>" }', |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="azure_data_explorer_default", |
| conn_type="azure_data_explorer", |
| host="https://<CLUSTER>.kusto.windows.net", |
| extra="""{"auth_method": "<AAD_APP | AAD_APP_CERT | AAD_CREDS | AAD_DEVICE>", |
| "tenant": "<TENANT ID>", "certificate": "<APPLICATION PEM CERTIFICATE>", |
| "thumbprint": "<APPLICATION CERTIFICATE THUMBPRINT>"}""", |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="azure_data_lake_default", |
| conn_type="azure_data_lake", |
| extra='{"tenant": "<TENANT>", "account_name": "<ACCOUNTNAME>" }', |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="azure_default", |
| conn_type="azure", |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="cassandra_default", |
| conn_type="cassandra", |
| host="cassandra", |
| port=9042, |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="databricks_default", |
| conn_type="databricks", |
| host="localhost", |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="dingding_default", |
| conn_type="http", |
| host="", |
| password="", |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="drill_default", |
| conn_type="drill", |
| host="localhost", |
| port=8047, |
| extra='{"dialect_driver": "drill+sadrill", "storage_plugin": "dfs"}', |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="druid_broker_default", |
| conn_type="druid", |
| host="druid-broker", |
| port=8082, |
| extra='{"endpoint": "druid/v2/sql"}', |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="druid_ingest_default", |
| conn_type="druid", |
| host="druid-overlord", |
| port=8081, |
| extra='{"endpoint": "druid/indexer/v1/task"}', |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="elasticsearch_default", |
| conn_type="elasticsearch", |
| host="localhost", |
| schema="http", |
| port=9200, |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="emr_default", |
| conn_type="emr", |
| extra=""" |
| { "Name": "default_job_flow_name", |
| "LogUri": "s3://my-emr-log-bucket/default_job_flow_location", |
| "ReleaseLabel": "emr-4.6.0", |
| "Instances": { |
| "Ec2KeyName": "mykey", |
| "Ec2SubnetId": "somesubnet", |
| "InstanceGroups": [ |
| { |
| "Name": "Master nodes", |
| "Market": "ON_DEMAND", |
| "InstanceRole": "MASTER", |
| "InstanceType": "r3.2xlarge", |
| "InstanceCount": 1 |
| }, |
| { |
| "Name": "Core nodes", |
| "Market": "ON_DEMAND", |
| "InstanceRole": "CORE", |
| "InstanceType": "r3.2xlarge", |
| "InstanceCount": 1 |
| } |
| ], |
| "TerminationProtected": false, |
| "KeepJobFlowAliveWhenNoSteps": false |
| }, |
| "Applications":[ |
| { "Name": "Spark" } |
| ], |
| "VisibleToAllUsers": true, |
| "JobFlowRole": "EMR_EC2_DefaultRole", |
| "ServiceRole": "EMR_DefaultRole", |
| "Tags": [ |
| { |
| "Key": "app", |
| "Value": "analytics" |
| }, |
| { |
| "Key": "environment", |
| "Value": "development" |
| } |
| ] |
| } |
| """, |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="facebook_default", |
| conn_type="facebook_social", |
| extra=""" |
| { "account_id": "<AD_ACCOUNT_ID>", |
| "app_id": "<FACEBOOK_APP_ID>", |
| "app_secret": "<FACEBOOK_APP_SECRET>", |
| "access_token": "<FACEBOOK_AD_ACCESS_TOKEN>" |
| } |
| """, |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="fs_default", |
| conn_type="fs", |
| extra='{"path": "/"}', |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="ftp_default", |
| conn_type="ftp", |
| host="localhost", |
| port=21, |
| login="airflow", |
| password="airflow", |
| extra='{"key_file": "~/.ssh/id_rsa", "no_host_key_check": true}', |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="google_cloud_default", |
| conn_type="google_cloud_platform", |
| schema="default", |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="hive_cli_default", |
| conn_type="hive_cli", |
| port=10000, |
| host="localhost", |
| extra='{"use_beeline": true, "auth": ""}', |
| schema="default", |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="hiveserver2_default", |
| conn_type="hiveserver2", |
| host="localhost", |
| schema="default", |
| port=10000, |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="http_default", |
| conn_type="http", |
| host="https://www.httpbin.org/", |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="iceberg_default", |
| conn_type="iceberg", |
| host="https://api.iceberg.io/ws/v1", |
| ), |
| session, |
| ) |
    merge_conn(
        Connection(conn_id="impala_default", conn_type="impala", host="localhost", port=21050),
        session,
    )
| merge_conn( |
| Connection( |
| conn_id="kafka_default", |
| conn_type="kafka", |
| extra=json.dumps({"bootstrap.servers": "broker:29092"}), |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="kubernetes_default", |
| conn_type="kubernetes", |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="kylin_default", |
| conn_type="kylin", |
| host="localhost", |
| port=7070, |
| login="ADMIN", |
| password="KYLIN", |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="leveldb_default", |
| conn_type="leveldb", |
| host="localhost", |
| ), |
| session, |
| ) |
| merge_conn(Connection(conn_id="livy_default", conn_type="livy", host="livy", port=8998), session) |
| merge_conn( |
| Connection( |
| conn_id="local_mysql", |
| conn_type="mysql", |
| host="localhost", |
| login="airflow", |
| password="airflow", |
| schema="airflow", |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="metastore_default", |
| conn_type="hive_metastore", |
| host="localhost", |
| extra='{"authMechanism": "PLAIN"}', |
| port=9083, |
| ), |
| session, |
| ) |
| merge_conn(Connection(conn_id="mongo_default", conn_type="mongo", host="mongo", port=27017), session) |
| merge_conn( |
| Connection( |
| conn_id="mssql_default", |
| conn_type="mssql", |
| host="localhost", |
| port=1433, |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="mysql_default", |
| conn_type="mysql", |
| login="root", |
| schema="airflow", |
| host="mysql", |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="opsgenie_default", |
| conn_type="http", |
| host="", |
| password="", |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="oracle_default", |
| conn_type="oracle", |
| host="localhost", |
| login="root", |
| password="password", |
| schema="schema", |
| port=1521, |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="oss_default", |
| conn_type="oss", |
| extra="""{ |
| "auth_type": "AK", |
| "access_key_id": "<ACCESS_KEY_ID>", |
| "access_key_secret": "<ACCESS_KEY_SECRET>", |
| "region": "<YOUR_OSS_REGION>"} |
| """, |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="pig_cli_default", |
| conn_type="pig_cli", |
| schema="default", |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="pinot_admin_default", |
| conn_type="pinot", |
| host="localhost", |
| port=9000, |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="pinot_broker_default", |
| conn_type="pinot", |
| host="localhost", |
| port=9000, |
| extra='{"endpoint": "/query", "schema": "http"}', |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="postgres_default", |
| conn_type="postgres", |
| login="postgres", |
| password="airflow", |
| schema="airflow", |
| host="postgres", |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="presto_default", |
| conn_type="presto", |
| host="localhost", |
| schema="hive", |
| port=3400, |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="qdrant_default", |
| conn_type="qdrant", |
| host="qdrant", |
| port=6333, |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="redis_default", |
| conn_type="redis", |
| host="redis", |
| port=6379, |
| extra='{"db": 0}', |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="redshift_default", |
| conn_type="redshift", |
| extra="""{ |
| "iam": true, |
| "cluster_identifier": "<REDSHIFT_CLUSTER_IDENTIFIER>", |
| "port": 5439, |
| "profile": "default", |
| "db_user": "awsuser", |
| "database": "dev", |
| "region": "" |
| }""", |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="salesforce_default", |
| conn_type="salesforce", |
| login="username", |
| password="password", |
| extra='{"security_token": "security_token"}', |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="segment_default", |
| conn_type="segment", |
| extra='{"write_key": "my-segment-write-key"}', |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="sftp_default", |
| conn_type="sftp", |
| host="localhost", |
| port=22, |
| login="airflow", |
| extra='{"key_file": "~/.ssh/id_rsa", "no_host_key_check": true}', |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="spark_default", |
| conn_type="spark", |
| host="yarn", |
| extra='{"queue": "root.default"}', |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="sqlite_default", |
| conn_type="sqlite", |
| host=os.path.join(gettempdir(), "sqlite_default.db"), |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="ssh_default", |
| conn_type="ssh", |
| host="localhost", |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="tableau_default", |
| conn_type="tableau", |
| host="https://tableau.server.url", |
| login="user", |
| password="password", |
| extra='{"site_id": "my_site"}', |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="tabular_default", |
| conn_type="tabular", |
| host="https://api.tabulardata.io/ws/v1", |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="teradata_default", |
| conn_type="teradata", |
| host="localhost", |
| login="user", |
| password="password", |
| schema="schema", |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="trino_default", |
| conn_type="trino", |
| host="localhost", |
| schema="hive", |
| port=3400, |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="vertica_default", |
| conn_type="vertica", |
| host="localhost", |
| port=5433, |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="wasb_default", |
| conn_type="wasb", |
| extra='{"sas_token": null}', |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="webhdfs_default", |
| conn_type="hdfs", |
| host="localhost", |
| port=50070, |
| ), |
| session, |
| ) |
| merge_conn( |
| Connection( |
| conn_id="yandexcloud_default", |
| conn_type="yandexcloud", |
| schema="default", |
| ), |
| session, |
| ) |
| |
| |
| def _get_flask_db(sql_database_uri): |
| from flask import Flask |
| from flask_sqlalchemy import SQLAlchemy |
| |
| from airflow.www.session import AirflowDatabaseSessionInterface |
| |
| flask_app = Flask(__name__) |
| flask_app.config["SQLALCHEMY_DATABASE_URI"] = sql_database_uri |
| flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False |
| db = SQLAlchemy(flask_app) |
| AirflowDatabaseSessionInterface(app=flask_app, db=db, table="session", key_prefix="") |
| return db |
| |
| |
| def _create_db_from_orm(session): |
| from alembic import command |
| |
| from airflow.models.base import Base |
| from airflow.providers.fab.auth_manager.models import Model |
| |
| def _create_flask_session_tbl(sql_database_uri): |
| db = _get_flask_db(sql_database_uri) |
| db.create_all() |
| |
| with create_global_lock(session=session, lock=DBLocks.MIGRATIONS): |
| engine = session.get_bind().engine |
| Base.metadata.create_all(engine) |
| Model.metadata.create_all(engine) |
| _create_flask_session_tbl(engine.url) |
| # stamp the migration head |
| config = _get_alembic_config() |
| command.stamp(config, "head") |
| |
| |
| @provide_session |
| def initdb(session: Session = NEW_SESSION, load_connections: bool = True, use_migration_files: bool = False): |
| """Initialize Airflow database.""" |
| import_all_models() |
| |
| db_exists = _get_current_revision(session) |
| if db_exists or use_migration_files: |
| upgradedb(session=session, use_migration_files=use_migration_files) |
| else: |
| _create_db_from_orm(session=session) |
| if conf.getboolean("database", "LOAD_DEFAULT_CONNECTIONS") and load_connections: |
| create_default_connections(session=session) |
| # Add default pool & sync log_template |
| add_default_pool_if_not_exists(session=session) |
| synchronize_log_template(session=session) |
| |
| |
| def _get_alembic_config(): |
| from alembic.config import Config |
| |
| package_dir = os.path.dirname(airflow.__file__) |
| directory = os.path.join(package_dir, "migrations") |
| alembic_file = conf.get("database", "alembic_ini_file_path") |
| if os.path.isabs(alembic_file): |
| config = Config(alembic_file) |
| else: |
| config = Config(os.path.join(package_dir, alembic_file)) |
| config.set_main_option("script_location", directory.replace("%", "%%")) |
| config.set_main_option("sqlalchemy.url", settings.SQL_ALCHEMY_CONN.replace("%", "%%")) |
| return config |
| |
| |
| def _get_script_object(config=None) -> ScriptDirectory: |
| from alembic.script import ScriptDirectory |
| |
| if not config: |
| config = _get_alembic_config() |
| return ScriptDirectory.from_config(config) |
| |
| |
| def _get_current_revision(session): |
| from alembic.migration import MigrationContext |
| |
| conn = session.connection() |
| |
| migration_ctx = MigrationContext.configure(conn) |
| |
| return migration_ctx.get_current_revision() |
| |
| |
| def check_migrations(timeout): |
| """ |
| Wait for all airflow migrations to complete. |
| |
| :param timeout: Timeout for the migration in seconds |
| :return: None |
| """ |
    timeout = timeout or 1  # run the loop at least once
| with _configured_alembic_environment() as env: |
| context = env.get_context() |
| source_heads = None |
| db_heads = None |
| for ticker in range(timeout): |
| source_heads = set(env.script.get_heads()) |
| db_heads = set(context.get_current_heads()) |
| if source_heads == db_heads: |
| return |
| time.sleep(1) |
| log.info("Waiting for migrations... %s second(s)", ticker) |
| raise TimeoutError( |
| f"There are still unapplied migrations after {timeout} seconds. Migration" |
| f"Head(s) in DB: {db_heads} | Migration Head(s) in Source Code: {source_heads}" |
| ) |
| |
| |
| @contextlib.contextmanager |
| def _configured_alembic_environment() -> Generator[EnvironmentContext, None, None]: |
| from alembic.runtime.environment import EnvironmentContext |
| |
| config = _get_alembic_config() |
| script = _get_script_object(config) |
| |
| with EnvironmentContext( |
| config, |
| script, |
| ) as env, settings.engine.connect() as connection: |
| alembic_logger = logging.getLogger("alembic") |
| level = alembic_logger.level |
| alembic_logger.setLevel(logging.WARNING) |
| env.configure(connection) |
| alembic_logger.setLevel(level) |
| |
| yield env |
| |
| |
| def check_and_run_migrations(): |
| """Check and run migrations if necessary. Only use in a tty.""" |
| with _configured_alembic_environment() as env: |
| context = env.get_context() |
| source_heads = set(env.script.get_heads()) |
| db_heads = set(context.get_current_heads()) |
| db_command = None |
| command_name = None |
| verb = None |
| if len(db_heads) < 1: |
| db_command = initdb |
| command_name = "init" |
| verb = "initialize" |
| elif source_heads != db_heads: |
| db_command = upgradedb |
| command_name = "upgrade" |
| verb = "upgrade" |
| |
| if sys.stdout.isatty() and verb: |
| print() |
| question = f"Please confirm database {verb} (or wait 4 seconds to skip it). Are you sure? [y/N]" |
| try: |
| answer = helpers.prompt_with_timeout(question, timeout=4, default=False) |
| if answer: |
| try: |
| db_command() |
| print(f"DB {verb} done") |
| except Exception as error: |
| from airflow.version import version |
| |
| print(error) |
                    print(
                        "You still have unapplied migrations. "
                        f"You may need to {verb} the database by running `airflow db {command_name}`. "
                        f"Make sure the command is run using Airflow version {version}.",
                        file=sys.stderr,
                    )
| sys.exit(1) |
| except AirflowException: |
| pass |
| elif source_heads != db_heads: |
| from airflow.version import version |
| |
| print( |
| f"ERROR: You need to {verb} the database. Please run `airflow db {command_name}`. " |
| f"Make sure the command is run using Airflow version {version}.", |
| file=sys.stderr, |
| ) |
| sys.exit(1) |
| |
| |
| def _reserialize_dags(*, session: Session) -> None: |
| from airflow.models.dagbag import DagBag |
| from airflow.models.serialized_dag import SerializedDagModel |
| |
| session.execute(delete(SerializedDagModel).execution_options(synchronize_session=False)) |
| dagbag = DagBag(collect_dags=False) |
| dagbag.collect_dags(only_if_updated=False) |
| dagbag.sync_to_db(session=session) |
| |
| |
| @provide_session |
| def synchronize_log_template(*, session: Session = NEW_SESSION) -> None: |
| """Synchronize log template configs with table. |
| |
    This checks if the last row fully matches the current config values, and
    inserts a new row if not.
| """ |
| # NOTE: SELECT queries in this function are INTENTIONALLY written with the |
| # SQL builder style, not the ORM query API. This avoids configuring the ORM |
| # unless we need to insert something, speeding up CLI in general. |
| |
| from airflow.models.tasklog import LogTemplate |
| |
| metadata = reflect_tables([LogTemplate], session) |
| log_template_table: Table | None = metadata.tables.get(LogTemplate.__tablename__) |
| |
| if log_template_table is None: |
| log.info("Log template table does not exist (added in 2.3.0); skipping log template sync.") |
| return |
| |
| filename = conf.get("logging", "log_filename_template") |
| elasticsearch_id = conf.get("elasticsearch", "log_id_template") |
| |
| stored = session.execute( |
| select( |
| log_template_table.c.filename, |
| log_template_table.c.elasticsearch_id, |
| ) |
| .order_by(log_template_table.c.id.desc()) |
| .limit(1) |
| ).first() |
| |
    # If the table is empty and the current config values are the defaults, seed the
    # table with the pre-2.3.0 values so that old logs are still retrievable.
| if not stored: |
| is_default_log_id = elasticsearch_id == conf.get_default_value("elasticsearch", "log_id_template") |
| is_default_filename = filename == conf.get_default_value("logging", "log_filename_template") |
| if is_default_log_id and is_default_filename: |
| session.add( |
| LogTemplate( |
| filename="{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log", |
| elasticsearch_id="{dag_id}-{task_id}-{execution_date}-{try_number}", |
| ) |
| ) |
| |
| # Before checking if the _current_ value exists, we need to check if the old config value we upgraded in |
| # place exists! |
| pre_upgrade_filename = conf.upgraded_values.get(("logging", "log_filename_template"), filename) |
| pre_upgrade_elasticsearch_id = conf.upgraded_values.get( |
| ("elasticsearch", "log_id_template"), elasticsearch_id |
| ) |
| if pre_upgrade_filename != filename or pre_upgrade_elasticsearch_id != elasticsearch_id: |
| # The previous non-upgraded value likely won't be the _latest_ value (as after we've recorded the |
| # recorded the upgraded value it will be second-to-newest), so we'll have to just search which is okay |
| # as this is a table with a tiny number of rows |
| row = session.execute( |
| select(log_template_table.c.id) |
| .where( |
| or_( |
| log_template_table.c.filename == pre_upgrade_filename, |
| log_template_table.c.elasticsearch_id == pre_upgrade_elasticsearch_id, |
| ) |
| ) |
| .order_by(log_template_table.c.id.desc()) |
| .limit(1) |
| ).first() |
| if not row: |
| session.add( |
| LogTemplate(filename=pre_upgrade_filename, elasticsearch_id=pre_upgrade_elasticsearch_id) |
| ) |
| |
| if not stored or stored.filename != filename or stored.elasticsearch_id != elasticsearch_id: |
| session.add(LogTemplate(filename=filename, elasticsearch_id=elasticsearch_id)) |
| |
| |
| def check_conn_id_duplicates(session: Session) -> Iterable[str]: |
| """ |
| Check unique conn_id in connection table. |
| |
| :param session: session of the sqlalchemy |
| """ |
| from airflow.models.connection import Connection |
| |
| try: |
| dups = session.scalars( |
| select(Connection.conn_id).group_by(Connection.conn_id).having(func.count() > 1) |
| ).all() |
| except (exc.OperationalError, exc.ProgrammingError): |
        # fallback if the table hasn't been created yet
| session.rollback() |
| return |
| if dups: |
        yield (
            "Seems you have a non-unique conn_id in the connection table.\n"
| "You have to manage those duplicate connections " |
| "before upgrading the database.\n" |
| f"Duplicated conn_id: {dups}" |
| ) |
| |
| |
def check_username_duplicates(session: Session) -> Iterable[str]:
    """
    Check username uniqueness in the User & RegisterUser tables.

    :param session: session of the sqlalchemy
    """
| from airflow.providers.fab.auth_manager.models import RegisterUser, User |
| |
| for model in [User, RegisterUser]: |
| dups = [] |
| try: |
| dups = session.execute( |
| select(model.username) # type: ignore[attr-defined] |
| .group_by(model.username) # type: ignore[attr-defined] |
| .having(func.count() > 1) |
| ).all() |
| except (exc.OperationalError, exc.ProgrammingError): |
            # fallback if the table hasn't been created yet
| session.rollback() |
| if dups: |
| yield ( |
| f"Seems you have mixed case usernames in {model.__table__.name} table.\n" # type: ignore |
| "You have to rename or delete those mixed case usernames " |
| "before upgrading the database.\n" |
| f"usernames with mixed cases: {[dup.username for dup in dups]}" |
| ) |
| |
| |
| def reflect_tables(tables: list[MappedClassProtocol | str] | None, session): |
| """ |
    When running checks prior to upgrades, we use reflection to determine the current state of the database.
| |
    This function gets the current state of each table in the set of models
    provided and returns a SQLAlchemy metadata object containing them.
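
    A usage sketch (``TaskInstance`` stands in for any mapped model; plain table names work too):

    .. code-block:: python

        from airflow.models.taskinstance import TaskInstance

        metadata = reflect_tables([TaskInstance, "task_fail"], session)
        ti_table = metadata.tables.get(TaskInstance.__tablename__)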
| """ |
| import sqlalchemy.schema |
| |
| bind = session.bind |
| metadata = sqlalchemy.schema.MetaData() |
| |
| if tables is None: |
| metadata.reflect(bind=bind, resolve_fks=False) |
| else: |
| for tbl in tables: |
| try: |
| table_name = tbl if isinstance(tbl, str) else tbl.__tablename__ |
| metadata.reflect(bind=bind, only=[table_name], extend_existing=True, resolve_fks=False) |
| except exc.InvalidRequestError: |
| continue |
| return metadata |
| |
| |
| def check_table_for_duplicates( |
| *, session: Session, table_name: str, uniqueness: list[str], version: str |
| ) -> Iterable[str]: |
| """ |
| Check table for duplicates, given a list of columns which define the uniqueness of the table. |
| |
| Usage example: |
| |
| .. code-block:: python |
| |
| def check_task_fail_for_duplicates(session): |
| from airflow.models.taskfail import TaskFail |
| |
| metadata = reflect_tables([TaskFail], session) |
| task_fail = metadata.tables.get(TaskFail.__tablename__) # type: ignore |
| if task_fail is None: # table not there |
| return |
| if "run_id" in task_fail.columns: # upgrade already applied |
| return |
| yield from check_table_for_duplicates( |
| table_name=task_fail.name, |
| uniqueness=["dag_id", "task_id", "execution_date"], |
| session=session, |
| version="2.3", |
| ) |
| |
    :param table_name: table name to check
    :param uniqueness: uniqueness constraint to evaluate against
    :param session: session of the sqlalchemy
    :param version: Airflow version used when naming the table that receives the duplicate rows
| """ |
| minimal_table_obj = table(table_name, *(column(x) for x in uniqueness)) |
| try: |
| subquery = session.execute( |
| select(minimal_table_obj, func.count().label("dupe_count")) |
| .group_by(*(text(x) for x in uniqueness)) |
| .having(func.count() > text("1")) |
| .subquery() |
| ) |
| dupe_count = session.scalar(select(func.sum(subquery.c.dupe_count))) |
| if not dupe_count: |
| # there are no duplicates; nothing to do. |
| return |
| |
| log.warning("Found %s duplicates in table %s. Will attempt to move them.", dupe_count, table_name) |
| |
| metadata = reflect_tables(tables=[table_name], session=session) |
        if table_name not in metadata.tables:
            yield f"Table {table_name} does not exist in the database."
            return
| |
        # We can't use the model here since it may differ from the db state because
        # this function is run prior to migration. Use the reflected table instead.
| table_obj = metadata.tables[table_name] |
| |
| _move_duplicate_data_to_new_table( |
| session=session, |
| source_table=table_obj, |
| subquery=subquery, |
| uniqueness=uniqueness, |
| target_table_name=_format_airflow_moved_table_name(table_name, version, "duplicates"), |
| ) |
| except (exc.OperationalError, exc.ProgrammingError): |
| # fallback if `table_name` hasn't been created yet |
| session.rollback() |
| |
| |
| def check_conn_type_null(session: Session) -> Iterable[str]: |
| """ |
| Check nullable conn_type column in Connection table. |
| |
| :param session: session of the sqlalchemy |
| """ |
| from airflow.models.connection import Connection |
| |
| try: |
| n_nulls = session.scalars(select(Connection.conn_id).where(Connection.conn_type.is_(None))).all() |
| except (exc.OperationalError, exc.ProgrammingError, exc.InternalError): |
        # fallback if the table hasn't been created yet
| session.rollback() |
| return |
| |
| if n_nulls: |
| yield ( |
| "The conn_type column in the connection " |
| "table must contain content.\n" |
| "Make sure you don't have null " |
| "in the conn_type column.\n" |
| f"Null conn_type conn_id: {n_nulls}" |
| ) |
| |
| |
| def _format_dangling_error(source_table, target_table, invalid_count, reason): |
| noun = "row" if invalid_count == 1 else "rows" |
| return ( |
| f"The {source_table} table has {invalid_count} {noun} {reason}, which " |
| f"is invalid. We could not move them out of the way because the " |
| f"{target_table} table already exists in your database. Please either " |
| f"drop the {target_table} table, or manually delete the invalid rows " |
| f"from the {source_table} table." |
| ) |
| |
| |
| def check_run_id_null(session: Session) -> Iterable[str]: |
| from airflow.models.dagrun import DagRun |
| |
| metadata = reflect_tables([DagRun], session) |
| |
    # We can't use the model here since it may differ from the db state because
    # this function is run prior to migration. Use the reflected table instead.
| dagrun_table = metadata.tables.get(DagRun.__tablename__) |
| if dagrun_table is None: |
| return |
| |
| invalid_dagrun_filter = or_( |
| dagrun_table.c.dag_id.is_(None), |
| dagrun_table.c.run_id.is_(None), |
| dagrun_table.c.execution_date.is_(None), |
| ) |
| invalid_dagrun_count = session.scalar(select(func.count(dagrun_table.c.id)).where(invalid_dagrun_filter)) |
| if invalid_dagrun_count > 0: |
| dagrun_dangling_table_name = _format_airflow_moved_table_name(dagrun_table.name, "2.2", "dangling") |
| if dagrun_dangling_table_name in inspect(session.get_bind()).get_table_names(): |
| yield _format_dangling_error( |
| source_table=dagrun_table.name, |
| target_table=dagrun_dangling_table_name, |
| invalid_count=invalid_dagrun_count, |
| reason="with a NULL dag_id, run_id, or execution_date", |
| ) |
| return |
| |
| bind = session.get_bind() |
| dialect_name = bind.dialect.name |
| _create_table_as( |
| dialect_name=dialect_name, |
| source_query=dagrun_table.select(invalid_dagrun_filter), |
| target_table_name=dagrun_dangling_table_name, |
| source_table_name=dagrun_table.name, |
| session=session, |
| ) |
| delete = dagrun_table.delete().where(invalid_dagrun_filter) |
| session.execute(delete) |
| |
| |
| def _create_table_as( |
| *, |
| session, |
| dialect_name: str, |
| source_query: Query, |
| target_table_name: str, |
| source_table_name: str, |
| ): |
| """ |
| Create a new table with rows from query. |
| |
| We have to handle CTAS differently for different dialects. |
| """ |
| if dialect_name == "mysql": |
        # MySQL with replication needs this split into two queries, so just do it for all MySQL
| # ERROR 1786 (HY000): Statement violates GTID consistency: CREATE TABLE ... SELECT. |
| session.execute(text(f"CREATE TABLE {target_table_name} LIKE {source_table_name}")) |
| session.execute( |
| text( |
| f"INSERT INTO {target_table_name} {source_query.selectable.compile(bind=session.get_bind())}" |
| ) |
| ) |
| else: |
| # Postgres and SQLite both support the same "CREATE TABLE a AS SELECT ..." syntax |
| select_table = source_query.selectable.compile(bind=session.get_bind()) |
| session.execute(text(f"CREATE TABLE {target_table_name} AS {select_table}")) |
| |
| |
| def _move_dangling_data_to_new_table( |
| session, source_table: Table, source_query: Query, target_table_name: str |
| ): |
| bind = session.get_bind() |
| dialect_name = bind.dialect.name |
| |
| # First: Create moved rows from new table |
| log.debug("running CTAS for table %s", target_table_name) |
| _create_table_as( |
| dialect_name=dialect_name, |
| source_query=source_query, |
| target_table_name=target_table_name, |
| source_table_name=source_table.name, |
| session=session, |
| ) |
| session.commit() |
| |
| target_table = source_table.to_metadata(source_table.metadata, name=target_table_name) |
| log.debug("checking whether rows were moved for table %s", target_table_name) |
| moved_rows_exist_query = select(1).select_from(target_table).limit(1) |
| first_moved_row = session.execute(moved_rows_exist_query).all() |
| session.commit() |
| |
| if not first_moved_row: |
| log.debug("no rows moved; dropping %s", target_table_name) |
| # no bad rows were found; drop moved rows table. |
| target_table.drop(bind=session.get_bind(), checkfirst=True) |
| else: |
| log.debug("rows moved; purging from %s", source_table.name) |
| if dialect_name == "sqlite": |
| pk_cols = source_table.primary_key.columns |
| |
            delete = source_table.delete().where(
                tuple_(*pk_cols).in_(select(*target_table.primary_key.columns))
            )
| else: |
| delete = source_table.delete().where( |
                and_(*(col == target_table.c[col.name] for col in source_table.primary_key.columns))
| ) |
| log.debug(delete.compile()) |
| session.execute(delete) |
| session.commit() |
| |
| log.debug("exiting move function") |
| |
| |
def _dangling_against_dag_run(session, source_table, dag_run):
    """Given a source table, generate a query returning each row that lacks a corresponding dag run."""
| source_to_dag_run_join_cond = and_( |
| source_table.c.dag_id == dag_run.c.dag_id, |
| source_table.c.execution_date == dag_run.c.execution_date, |
| ) |
| |
| return ( |
| select(*(c.label(c.name) for c in source_table.c)) |
| .join(dag_run, source_to_dag_run_join_cond, isouter=True) |
| .where(dag_run.c.dag_id.is_(None)) |
| ) |
| |
| |
| def _dangling_against_task_instance(session, source_table, dag_run, task_instance): |
| """ |
    Given a source table, generate a query returning each row that lacks a valid task instance (or dag run).
| |
| This is used to identify rows that need to be removed from tables prior to adding a TI fk. |
| |
| Since this check is applied prior to running the migrations, we have to use different |
| query logic depending on which revision the database is at. |
| |
| """ |
| if "run_id" not in task_instance.c: |
| # db is < 2.2.0 |
| dr_join_cond = and_( |
| source_table.c.dag_id == dag_run.c.dag_id, |
| source_table.c.execution_date == dag_run.c.execution_date, |
| ) |
| ti_join_cond = and_( |
| dag_run.c.dag_id == task_instance.c.dag_id, |
| dag_run.c.execution_date == task_instance.c.execution_date, |
| source_table.c.task_id == task_instance.c.task_id, |
| ) |
| else: |
| # db is 2.2.0 <= version < 2.3.0 |
| dr_join_cond = and_( |
| source_table.c.dag_id == dag_run.c.dag_id, |
| source_table.c.execution_date == dag_run.c.execution_date, |
| ) |
| ti_join_cond = and_( |
| dag_run.c.dag_id == task_instance.c.dag_id, |
| dag_run.c.run_id == task_instance.c.run_id, |
| source_table.c.task_id == task_instance.c.task_id, |
| ) |
| |
| return ( |
| select(*(c.label(c.name) for c in source_table.c)) |
| .outerjoin(dag_run, dr_join_cond) |
| .outerjoin(task_instance, ti_join_cond) |
| .where(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None))) |
| ) |
| |
| |
| def _move_duplicate_data_to_new_table( |
| session, source_table: Table, subquery: Query, uniqueness: list[str], target_table_name: str |
| ): |
| """ |
    When adding a uniqueness constraint, we should first ensure that there are no duplicate rows.

    This function accepts a subquery that should return one record for each row with duplicates (e.g.
    a group by with having count(*) > 1). We select from ``source_table`` all rows matching the
    subquery result and store them in ``target_table_name``. Then, to purge the duplicates from the
    source table, we do a DELETE FROM with a join to the target table (which now contains the dupes).
| |
| :param session: sqlalchemy session for metadata db |
| :param source_table: table to purge dupes from |
| :param subquery: the subquery that returns the duplicate rows |
| :param uniqueness: the string list of columns used to define the uniqueness for the table. used in |
| building the DELETE FROM join condition. |
| :param target_table_name: name of the table in which to park the duplicate rows |
| """ |
| bind = session.get_bind() |
| dialect_name = bind.dialect.name |
| |
| query = ( |
| select(*(source_table.c[x.name].label(str(x.name)) for x in source_table.columns)) |
| .select_from(source_table) |
| .join(subquery, and_(*(source_table.c[x] == subquery.c[x] for x in uniqueness))) |
| ) |
| |
| _create_table_as( |
| session=session, |
| dialect_name=dialect_name, |
| source_query=query, |
| target_table_name=target_table_name, |
| source_table_name=source_table.name, |
| ) |
| |
| # we must ensure that the CTAS table is created prior to the DELETE step since we have to join to it |
| session.commit() |
| |
| metadata = reflect_tables([target_table_name], session) |
| target_table = metadata.tables[target_table_name] |
| where_clause = and_(*(source_table.c[x] == target_table.c[x] for x in uniqueness)) |
| |
| if dialect_name == "sqlite": |
| subq = query.selectable.with_only_columns([text(f"{source_table}.ROWID")]) |
| delete = source_table.delete().where(column("ROWID").in_(subq)) |
| else: |
        delete = source_table.delete().where(where_clause)
| |
| session.execute(delete) |
| |
| |
| def check_bad_references(session: Session) -> Iterable[str]: |
| """ |
| Go through each table and look for records that can't be mapped to a dag run. |
| |
| When we find such "dangling" rows we back them up in a special table and delete them |
| from the main table. |
| |
| Starting in Airflow 2.2, we began a process of replacing `execution_date` with `run_id` in many tables. |
| """ |
| from airflow.models.dagrun import DagRun |
| from airflow.models.renderedtifields import RenderedTaskInstanceFields |
| from airflow.models.taskfail import TaskFail |
| from airflow.models.taskinstance import TaskInstance |
| from airflow.models.taskreschedule import TaskReschedule |
| from airflow.models.xcom import XCom |
| |
| @dataclass |
| class BadReferenceConfig: |
| """ |
| Bad reference config class. |
| |
| :param bad_rows_func: function that returns subquery which determines whether bad rows exist |
| :param join_tables: table objects referenced in subquery |
| :param ref_table: information-only identifier for categorizing the missing ref |
| """ |
| |
| bad_rows_func: Callable |
| join_tables: list[str] |
| ref_table: str |
| |
| missing_dag_run_config = BadReferenceConfig( |
| bad_rows_func=_dangling_against_dag_run, |
| join_tables=["dag_run"], |
| ref_table="dag_run", |
| ) |
| |
| missing_ti_config = BadReferenceConfig( |
| bad_rows_func=_dangling_against_task_instance, |
| join_tables=["dag_run", "task_instance"], |
| ref_table="task_instance", |
| ) |
| |
| models_list: list[tuple[MappedClassProtocol, str, BadReferenceConfig]] = [ |
| (TaskInstance, "2.2", missing_dag_run_config), |
| (TaskReschedule, "2.2", missing_ti_config), |
| (RenderedTaskInstanceFields, "2.3", missing_ti_config), |
| (TaskFail, "2.3", missing_ti_config), |
| (XCom, "2.3", missing_ti_config), |
| ] |
| metadata = reflect_tables([*(x[0] for x in models_list), DagRun, TaskInstance], session) |
| |
| if ( |
| not metadata.tables |
| or metadata.tables.get(DagRun.__tablename__) is None |
| or metadata.tables.get(TaskInstance.__tablename__) is None |
| ): |
| # Key table doesn't exist -- likely empty DB. |
| return |
| |
| existing_table_names = set(inspect(session.get_bind()).get_table_names()) |
| errored = False |
| |
| for model, change_version, bad_ref_cfg in models_list: |
| log.debug("checking model %s", model.__tablename__) |
| # We can't use the model here since it may differ from the db state due to |
| # this function is run prior to migration. Use the reflected table instead. |
| source_table = metadata.tables.get(model.__tablename__) # type: ignore |
| if source_table is None: |
| continue |
| |
| # Migration already applied, don't check again. |
| if "run_id" in source_table.columns: |
| continue |
| |
| func_kwargs = {x: metadata.tables[x] for x in bad_ref_cfg.join_tables} |
| bad_rows_query = bad_ref_cfg.bad_rows_func(session, source_table, **func_kwargs) |
| |
| dangling_table_name = _format_airflow_moved_table_name(source_table.name, change_version, "dangling") |
| if dangling_table_name in existing_table_names: |
| invalid_row_count = get_query_count(bad_rows_query, session=session) |
| if invalid_row_count: |
| yield _format_dangling_error( |
| source_table=source_table.name, |
| target_table=dangling_table_name, |
| invalid_count=invalid_row_count, |
| reason=f"without a corresponding {bad_ref_cfg.ref_table} row", |
| ) |
| errored = True |
| continue |
| |
| log.debug("moving data for table %s", source_table.name) |
| _move_dangling_data_to_new_table( |
| session, |
| source_table, |
| bad_rows_query, |
| dangling_table_name, |
| ) |
| |
| if errored: |
| session.rollback() |
| else: |
| session.commit() |
| |
| |
| @provide_session |
def _check_migration_errors(session: Session = NEW_SESSION) -> Iterable[str]:
    """Run all migration pre-checks and yield any error messages.

    :param session: session of the sqlalchemy
    """
| check_functions: tuple[Callable[..., Iterable[str]], ...] = ( |
| check_conn_id_duplicates, |
| check_conn_type_null, |
| check_run_id_null, |
| check_bad_references, |
| check_username_duplicates, |
| ) |
| for check_fn in check_functions: |
| log.debug("running check function %s", check_fn.__name__) |
| yield from check_fn(session=session) |
| |
| |
| def _offline_migration(migration_func: Callable, config, revision): |
| with warnings.catch_warnings(): |
| warnings.simplefilter("ignore") |
| logging.disable(logging.CRITICAL) |
| migration_func(config, revision, sql=True) |
| logging.disable(logging.NOTSET) |
| |
| |
| def print_happy_cat(message): |
| if sys.stdout.isatty(): |
| size = os.get_terminal_size().columns |
| else: |
| size = 0 |
| print(message.center(size)) |
| print("""/\\_/\\""".center(size)) |
| print("""(='_' )""".center(size)) |
| print("""(,(") (")""".center(size)) |
| print("""^^^""".center(size)) |
| return |
| |
| |
| def _revision_greater(config, this_rev, base_rev): |
| # Check if there is history between the revisions and the start revision |
| # This ensures that the revisions are above `min_revision` |
| script = _get_script_object(config) |
| try: |
| list(script.revision_map.iterate_revisions(upper=this_rev, lower=base_rev)) |
| return True |
| except Exception: |
| return False |
| |
| |
| def _revisions_above_min_for_offline(config, revisions) -> None: |
| """ |
| Check that all supplied revision ids are above the minimum revision for the dialect. |
| |
| :param config: Alembic config |
| :param revisions: list of Alembic revision ids |
| :return: None |
| """ |
| dbname = settings.engine.dialect.name |
| if dbname == "sqlite": |
| raise SystemExit("Offline migration not supported for SQLite.") |
| min_version, min_revision = ("2.2.0", "7b2661a43ba3") if dbname == "mssql" else ("2.0.0", "e959f08ac86c") |
| |
| # Check if there is history between the revisions and the start revision |
| # This ensures that the revisions are above `min_revision` |
| for rev in revisions: |
| if not _revision_greater(config, rev, min_revision): |
| raise ValueError( |
| f"Error while checking history for revision range {min_revision}:{rev}. " |
| f"Check that {rev} is a valid revision. " |
| f"For dialect {dbname!r}, supported revision for offline migration is from {min_revision} " |
| f"which corresponds to Airflow {min_version}." |
| ) |
| |
| |
| @provide_session |
| def upgradedb( |
| *, |
| to_revision: str | None = None, |
| from_revision: str | None = None, |
| show_sql_only: bool = False, |
| reserialize_dags: bool = True, |
| session: Session = NEW_SESSION, |
| use_migration_files: bool = False, |
| ): |
| """ |
    Upgrade the DB.
| |
| :param to_revision: Optional Alembic revision ID to upgrade *to*. |
| If omitted, upgrades to latest revision. |
    :param from_revision: Optional Alembic revision ID to upgrade *from*.
        Only supported together with ``show_sql_only=True``.
| :param show_sql_only: if True, migration statements will be printed but not executed. |
| :param session: sqlalchemy session with connection to Airflow metadata database |
| :return: None |
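
    A usage sketch (revision ids are taken from ``_REVISION_HEADS_MAP`` above):

    .. code-block:: python

        # Apply all migrations up to the latest head:
        upgradedb()

        # Print, but do not execute, the SQL for a bounded range:
        upgradedb(from_revision="e959f08ac86c", to_revision="677fdbb7fc54", show_sql_only=True)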
| """ |
| if from_revision and not show_sql_only: |
        raise AirflowException("`from_revision` is only supported with `show_sql_only=True`.")
| |
    if not settings.SQL_ALCHEMY_CONN:
        raise RuntimeError("The settings.SQL_ALCHEMY_CONN is not set. This is a critical assertion.")

    # alembic adds significant import time, so we import it lazily
    from alembic import command
| |
| import_all_models() |
| |
| config = _get_alembic_config() |
| |
| if show_sql_only: |
| if not from_revision: |
| from_revision = _get_current_revision(session) |
| |
| if not to_revision: |
| script = _get_script_object() |
| to_revision = script.get_current_head() |
| |
| if to_revision == from_revision: |
| print_happy_cat("No migrations to apply; nothing to do.") |
| return |
| |
| if not _revision_greater(config, to_revision, from_revision): |
| raise ValueError( |
| f"Requested *to* revision {to_revision} is older than *from* revision {from_revision}. " |
| "Please check your requested versions / revisions." |
| ) |
| _revisions_above_min_for_offline(config=config, revisions=[from_revision, to_revision]) |
| |
| _offline_migration(command.upgrade, config, f"{from_revision}:{to_revision}") |
| return # only running sql; our job is done |
| |
| errors_seen = False |
| for err in _check_migration_errors(session=session): |
| if not errors_seen: |
| log.error("Automatic migration is not available") |
| errors_seen = True |
| log.error("%s", err) |
| |
| if errors_seen: |
        sys.exit(1)
| |
| if not to_revision and not _get_current_revision(session=session) and not use_migration_files: |
        # New DB; initialize (without loading default connections) and exit
| initdb(session=session, load_connections=False) |
| return |
| with create_global_lock(session=session, lock=DBLocks.MIGRATIONS): |
| import sqlalchemy.pool |
| |
| log.info("Creating tables") |
| val = os.environ.get("AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE") |
| try: |
| # Reconfigure the ORM to use _EXACTLY_ one connection, otherwise some db engines hang forever |
| # trying to ALTER TABLEs |
| os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE"] = "1" |
| settings.reconfigure_orm(pool_class=sqlalchemy.pool.SingletonThreadPool) |
| command.upgrade(config, revision=to_revision or "heads") |
| finally: |
| if val is None: |
| os.environ.pop("AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE") |
| else: |
| os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE"] = val |
| settings.reconfigure_orm() |
| |
| if reserialize_dags: |
| _reserialize_dags(session=session) |
| add_default_pool_if_not_exists(session=session) |
| synchronize_log_template(session=session) |
| |
| |
| @provide_session |
| def resetdb(session: Session = NEW_SESSION, skip_init: bool = False, use_migration_files: bool = False): |
| """Clear out the database.""" |
| if not settings.engine: |
        raise RuntimeError("The settings.engine must be set. This is a critical assertion.")
| log.info("Dropping tables that exist") |
| |
| import_all_models() |
| |
| connection = settings.engine.connect() |
| |
| with create_global_lock(session=session, lock=DBLocks.MIGRATIONS), connection.begin(): |
| drop_airflow_models(connection) |
| drop_airflow_moved_tables(connection) |
| |
| if not skip_init: |
| initdb(session=session, use_migration_files=use_migration_files) |
| |
| |
| @provide_session |
| def bootstrap_dagbag(session: Session = NEW_SESSION): |
| from airflow.models.dag import DAG |
| from airflow.models.dagbag import DagBag |
| |
| dagbag = DagBag() |
| # Save DAGs in the ORM |
| dagbag.sync_to_db(session=session) |
| |
| # Deactivate the unknown ones |
| DAG.deactivate_unknown_dags(dagbag.dags.keys(), session=session) |
| |
| |
| @provide_session |
| def downgrade(*, to_revision, from_revision=None, show_sql_only=False, session: Session = NEW_SESSION): |
| """ |
| Downgrade the airflow metastore schema to a prior version. |
| |
    :param to_revision: The alembic revision to downgrade *to*.
    :param show_sql_only: if True, print sql statements but do not run them
    :param from_revision: if supplied, alembic revision to downgrade *from*. This may only
        be used in conjunction with ``show_sql_only=True`` because if we actually run the commands,
        we should only downgrade from the *current* revision.
    :param session: sqlalchemy session for connection to airflow metadata database
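
    A usage sketch (the revision id is illustrative, taken from ``_REVISION_HEADS_MAP``):

    .. code-block:: python

        # Generate, but do not run, the SQL needed to step the schema back:
        downgrade(to_revision="e959f08ac86c", show_sql_only=True)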
| """ |
| if from_revision and not show_sql_only: |
        raise ValueError(
            "`from_revision` can't be combined with `show_sql_only=False`. When actually "
| "applying a downgrade (instead of just generating sql), we always " |
| "downgrade from current revision." |
| ) |
| |
| if not settings.SQL_ALCHEMY_CONN: |
        raise RuntimeError("The settings.SQL_ALCHEMY_CONN is not set.")
| |
| # alembic adds significant import time, so we import it lazily |
| from alembic import command |
| |
| log.info("Attempting downgrade to revision %s", to_revision) |
| config = _get_alembic_config() |
| |
| with create_global_lock(session=session, lock=DBLocks.MIGRATIONS): |
| if show_sql_only: |
| log.warning("Generating sql scripts for manual migration.") |
| if not from_revision: |
| from_revision = _get_current_revision(session) |
| revision_range = f"{from_revision}:{to_revision}" |
| _offline_migration(command.downgrade, config=config, revision=revision_range) |
| else: |
| log.info("Applying downgrade migrations.") |
| command.downgrade(config, revision=to_revision, sql=show_sql_only) |
| |
| |
| def drop_airflow_models(connection): |
| """ |
| Drop all airflow models. |
| |
| :param connection: SQLAlchemy Connection |
| :return: None |
| """ |
| from airflow.models.base import Base |
| from airflow.providers.fab.auth_manager.models import Model |
| |
| Base.metadata.drop_all(connection) |
| Model.metadata.drop_all(connection) |
| db = _get_flask_db(connection.engine.url) |
| db.drop_all() |
| # alembic adds significant import time, so we import it lazily |
| from alembic.migration import MigrationContext |
| |
| migration_ctx = MigrationContext.configure(connection) |
| version = migration_ctx._version |
| if inspect(connection).has_table(version.name): |
| version.drop(connection) |
| |
| |
| def drop_airflow_moved_tables(connection): |
| from airflow.models.base import Base |
| from airflow.settings import AIRFLOW_MOVED_TABLE_PREFIX |
| |
| tables = set(inspect(connection).get_table_names()) |
| to_delete = [Table(x, Base.metadata) for x in tables if x.startswith(AIRFLOW_MOVED_TABLE_PREFIX)] |
| for tbl in to_delete: |
| tbl.drop(settings.engine, checkfirst=False) |
| Base.metadata.remove(tbl) |
| |
| |
| @provide_session |
| def check(session: Session = NEW_SESSION): |
| """ |
| Check if the database works. |
| |
| :param session: session of the sqlalchemy |
| """ |
| session.execute(text("select 1 as is_alive;")) |
| log.info("Connection successful.") |
| |
| |
| @enum.unique |
| class DBLocks(enum.IntEnum): |
| """ |
| Cross-db Identifiers for advisory global database locks. |
| |
    Postgres uses int64 lock ids so we use the integer value, MySQL uses names, so we
    call ``str()``, which is implemented using the ``_name_`` field.
| """ |
| |
| MIGRATIONS = enum.auto() |
| SCHEDULER_CRITICAL_SECTION = enum.auto() |
| |
| def __str__(self): |
| return f"airflow_{self._name_}" |
| |
| |
| @contextlib.contextmanager |
| def create_global_lock( |
| session: Session, |
| lock: DBLocks, |
| lock_timeout: int = 1800, |
| ) -> Generator[None, None, None]: |
| """Contextmanager that will create and teardown a global db lock.""" |
| conn = session.get_bind().connect() |
| dialect = conn.dialect |
| try: |
| if dialect.name == "postgresql": |
| conn.execute(text("SET LOCK_TIMEOUT to :timeout"), {"timeout": lock_timeout}) |
| conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value}) |
| elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6): |
| conn.execute(text("SELECT GET_LOCK(:id, :timeout)"), {"id": str(lock), "timeout": lock_timeout}) |
| |
| yield |
| finally: |
| if dialect.name == "postgresql": |
| conn.execute(text("SET LOCK_TIMEOUT TO DEFAULT")) |
| (unlocked,) = conn.execute(text("SELECT pg_advisory_unlock(:id)"), {"id": lock.value}).fetchone() |
| if not unlocked: |
| raise RuntimeError("Error releasing DB lock!") |
| elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6): |
| conn.execute(text("select RELEASE_LOCK(:id)"), {"id": str(lock)}) |
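

# A usage sketch for create_global_lock (this is how it is used above around migrations):
#
#     with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
#         ...  # schema-changing work runs while the advisory lock is held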
| |
| |
| def compare_type(context, inspected_column, metadata_column, inspected_type, metadata_type): |
| """ |
    Compare types between ORM and DB.

    Return False if the metadata_type is the same as the inspected_type
    or None to allow the default implementation to compare these
    types. A return value of True means the two types do not
    match and should result in a type change operation.
| """ |
| if context.dialect.name == "mysql": |
| from sqlalchemy import String |
| from sqlalchemy.dialects import mysql |
| |
| if isinstance(inspected_type, mysql.VARCHAR) and isinstance(metadata_type, String): |
| # This is a hack to get around MySQL VARCHAR collation |
| # not being possible to change from utf8_bin to utf8mb3_bin. |
| # We only make sure lengths are the same |
| if inspected_type.length != metadata_type.length: |
| return True |
| return False |
| return None |
| |
| |
| def compare_server_default( |
| context, inspected_column, metadata_column, inspected_default, metadata_default, rendered_metadata_default |
| ): |
| """ |
    Compare server defaults between ORM and DB.

    Return True if the defaults are different, False if not, or None to allow the default
    implementation to compare these defaults.

    In SQLite, task_instance.map_index & task_reschedule.map_index
    do not compare accurately. Sometimes they are equal, sometimes they are not.
    Alembic warned that this feature has varied accuracy depending on backends.
    See: (https://alembic.sqlalchemy.org/en/latest/api/runtime.html#alembic.runtime.
    environment.EnvironmentContext.configure.params.compare_server_default)
| """ |
| dialect_name = context.connection.dialect.name |
| if dialect_name in ["sqlite"]: |
| return False |
| if ( |
| dialect_name == "mysql" |
| and metadata_column.name == "pool_slots" |
| and metadata_column.table.name == "task_instance" |
| ): |
        # We removed the server_default value in the ORM to avoid an expensive migration
        # (it was removed from the Postgres database in migration head 7b2661a43ba3).
| # As a side note, server default value here was only actually needed for the migration |
| # where we added the column in the first place -- now that it exists and all |
| # existing rows are populated with a value this server default is never used. |
| return False |
| return None |
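

def _example_alembic_wiring(context, connection, target_metadata) -> None:
    # Illustrative sketch only (hypothetical function and parameter names):
    # both comparators above are designed to be passed to Alembic's
    # ``EnvironmentContext.configure()`` autogenerate hooks, roughly like this.
    context.configure(
        connection=connection,
        target_metadata=target_metadata,
        compare_type=compare_type,
        compare_server_default=compare_server_default,
    )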
| |
| |
| def get_sqla_model_classes(): |
| """ |
| Get all SQLAlchemy class mappers. |
| |
    SQLAlchemy < 1.4 does not support ``registry.mappers``, so we use
    try/except to handle both cases.
| """ |
| from airflow.models.base import Base |
| |
| try: |
| return [mapper.class_ for mapper in Base.registry.mappers] |
| except AttributeError: |
| return Base._decl_class_registry.values() |
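

def _example_list_mapped_tables() -> None:
    # Illustrative sketch only (hypothetical helper): log the table behind
    # each mapped model class returned by the helper above. Assumes
    # declarative models exposing ``__tablename__``.
    for cls in get_sqla_model_classes():
        log.info("mapped model %s -> table %s", cls.__name__, cls.__tablename__)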
| |
| |
| def get_query_count(query_stmt: Select, *, session: Session) -> int: |
| """Get count of a query. |
| |
| A SELECT COUNT() FROM is issued against the subquery built from the |
| given statement. The ORDER BY clause is stripped from the statement |
| since it's unnecessary for COUNT, and can impact query planning and |
| degrade performance. |
| |
| :meta private: |
| """ |
| count_stmt = select(func.count()).select_from(query_stmt.order_by(None).subquery()) |
| return session.scalar(count_stmt) |
| |
| |
| def check_query_exists(query_stmt: Select, *, session: Session) -> bool: |
| """Check whether there is at least one row matching a query. |
| |
| A SELECT 1 FROM is issued against the subquery built from the given |
| statement. The ORDER BY clause is stripped from the statement since it's |
| unnecessary, and can impact query planning and degrade performance. |
| |
| :meta private: |
| """ |
    exists_stmt = select(literal(True)).select_from(query_stmt.order_by(None).subquery())
    # scalar() returns None rather than False when the subquery yields no row,
    # so coerce the result to honor the annotated bool return type.
    return session.scalar(exists_stmt) is not None
| |
| |
| def exists_query(*where: ClauseElement, session: Session) -> bool: |
| """Check whether there is at least one row matching given clauses. |
| |
    This does a SELECT 1 WHERE ... LIMIT 1 and checks the result.
| |
| :meta private: |
| """ |
| stmt = select(literal(True)).where(*where).limit(1) |
| return session.scalar(stmt) is not None |
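

def _example_query_helpers(session: Session) -> None:
    # Illustrative sketch only (hypothetical usage): the three query helpers
    # above applied to one predicate. ``Connection`` is just an example model;
    # any mapped class works the same way.
    from airflow.models.connection import Connection

    stmt = select(Connection).where(Connection.conn_type == "sqlite")
    n = get_query_count(stmt, session=session)  # SELECT COUNT() over a subquery
    has_rows = check_query_exists(stmt, session=session)  # SELECT 1 over a subquery
    same = exists_query(Connection.conn_type == "sqlite", session=session)  # SELECT 1 ... LIMIT 1
    assert (n > 0) == has_rows == same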
| |
| |
| @attrs.define(slots=True) |
| class LazySelectSequence(Sequence[T]): |
| """List-like interface to lazily access a database model query. |
| |
| The intended use case is inside a task execution context, where we manage an |
| active SQLAlchemy session in the background. |
| |
| This is an abstract base class. Each use case should subclass, and implement |
| the following static methods: |
| |
    * ``_rebuild_select`` is called when a lazy sequence is unpickled. Since it
      is not easy to pickle SQLAlchemy constructs, this class serializes the
      SELECT statements into plain text for storage. This method is called on
      deserialization to convert the textual clause back into an ORM SELECT.
| * ``_process_row`` is called when an item is accessed. The lazy sequence |
| uses ``session.execute()`` to fetch rows from the database, and this |
| method should know how to process each row into a value. |
| |
| :meta private: |
| """ |
| |
| _select_asc: ClauseElement |
| _select_desc: ClauseElement |
| _session: Session = attrs.field(kw_only=True, factory=get_current_task_instance_session) |
| _len: int | None = attrs.field(init=False, default=None) |
| |
| @classmethod |
| def from_select( |
| cls, |
| select: Select, |
| *, |
| order_by: Sequence[ClauseElement], |
| session: Session | None = None, |
| ) -> Self: |
| s1 = select |
| for col in order_by: |
| s1 = s1.order_by(col.asc()) |
| s2 = select |
| for col in order_by: |
| s2 = s2.order_by(col.desc()) |
| return cls(s1, s2, session=session or get_current_task_instance_session()) |
| |
| @staticmethod |
| def _rebuild_select(stmt: TextClause) -> Select: |
| """Rebuild a textual statement into an ORM-configured SELECT statement. |
| |
        This should do something like ``select(field).from_statement(stmt)`` to
        re-attach ORM information to the textual SQL statement.
| """ |
| raise NotImplementedError |
| |
| @staticmethod |
| def _process_row(row: Row) -> T: |
| """Process a SELECT-ed row into the end value.""" |
| raise NotImplementedError |
| |
| def __repr__(self) -> str: |
| counter = "item" if (length := len(self)) == 1 else "items" |
| return f"LazySelectSequence([{length} {counter}])" |
| |
| def __str__(self) -> str: |
| counter = "item" if (length := len(self)) == 1 else "items" |
| return f"LazySelectSequence([{length} {counter}])" |
| |
| def __getstate__(self) -> Any: |
| # We don't want to go to the trouble of serializing SQLAlchemy objects. |
| # Converting the statement into a SQL string is the best we can get. |
        # The literal_binds compile argument inlines all the values into the SQL
        # string to simplify cross-process communication as much as possible.
        # Theoretically we could do the same for count(), but calculating only
        # that eagerly should be performant enough.
| s1 = str(self._select_asc.compile(self._session.get_bind(), compile_kwargs={"literal_binds": True})) |
| s2 = str(self._select_desc.compile(self._session.get_bind(), compile_kwargs={"literal_binds": True})) |
| return (s1, s2, len(self)) |
| |
| def __setstate__(self, state: Any) -> None: |
| s1, s2, self._len = state |
| self._select_asc = self._rebuild_select(text(s1)) |
| self._select_desc = self._rebuild_select(text(s2)) |
| self._session = get_current_task_instance_session() |
| |
| def __bool__(self) -> bool: |
| return check_query_exists(self._select_asc, session=self._session) |
| |
| def __eq__(self, other: Any) -> bool: |
| if not isinstance(other, collections.abc.Sequence): |
| return NotImplemented |
| z = itertools.zip_longest(iter(self), iter(other), fillvalue=object()) |
| return all(x == y for x, y in z) |
| |
| def __reversed__(self) -> Iterator[T]: |
| return iter(self._process_row(r) for r in self._session.execute(self._select_desc)) |
| |
| def __iter__(self) -> Iterator[T]: |
| return iter(self._process_row(r) for r in self._session.execute(self._select_asc)) |
| |
| def __len__(self) -> int: |
| if self._len is None: |
| self._len = get_query_count(self._select_asc, session=self._session) |
| return self._len |
| |
| @overload |
| def __getitem__(self, key: int) -> T: ... |
| |
| @overload |
| def __getitem__(self, key: slice) -> Sequence[T]: ... |
| |
| def __getitem__(self, key: int | slice) -> T | Sequence[T]: |
| if isinstance(key, int): |
| if key >= 0: |
| stmt = self._select_asc.offset(key) |
| else: |
| stmt = self._select_desc.offset(-1 - key) |
| if (row := self._session.execute(stmt.limit(1)).one_or_none()) is None: |
| raise IndexError(key) |
| return self._process_row(row) |
| elif isinstance(key, slice): |
| # This implements the slicing syntax. We want to optimize negative |
| # slicing (e.g. seq[-10:]) by not doing an additional COUNT query |
| # if possible. We can do this unless the start and stop have |
            # different signs (i.e. one is positive and the other negative).
| start, stop, reverse = _coerce_slice(key) |
| if start >= 0: |
| if stop is None: |
| stmt = self._select_asc.offset(start) |
| elif stop >= 0: |
| stmt = self._select_asc.slice(start, stop) |
| else: |
| stmt = self._select_asc.slice(start, len(self) + stop) |
| rows = [self._process_row(row) for row in self._session.execute(stmt)] |
| if reverse: |
| rows.reverse() |
| else: |
| if stop is None: |
| stmt = self._select_desc.limit(-start) |
| elif stop < 0: |
| stmt = self._select_desc.slice(-stop, -start) |
| else: |
| stmt = self._select_desc.slice(len(self) - stop, -start) |
| rows = [self._process_row(row) for row in self._session.execute(stmt)] |
| if not reverse: |
| rows.reverse() |
| return rows |
| raise TypeError(f"Sequence indices must be integers or slices, not {type(key).__name__}") |
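

class _ExampleStringSequence(LazySelectSequence[str]):
    """Illustrative sketch only: a minimal (hypothetical) concrete subclass.

    It reads a single ``value`` column and yields each row as ``str``. A real
    subclass would rebuild the SELECT against its actual ORM entities; here
    ``select(column("value")).from_statement()`` keeps the sketch
    self-contained. After a pickle round-trip, ``__setstate__`` feeds the
    serialized SQL back through ``_rebuild_select``.
    """

    @staticmethod
    def _rebuild_select(stmt: TextClause) -> Select:
        # Re-wrap the serialized SQL text into a SELECT over the same column.
        return select(column("value")).from_statement(stmt)

    @staticmethod
    def _process_row(row: Row) -> str:
        return str(row[0])


# Building such a sequence would mirror the real subclasses, e.g.:
#     _ExampleStringSequence.from_select(
#         select(column("value")).select_from(table("example")),
#         order_by=[column("id")],
#     )
# inside an active task-instance session.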
| |
| |
| def _coerce_index(value: Any) -> int | None: |
| """Check slice attribute's type and convert it to int. |
| |
| See CPython documentation on this: |
| https://docs.python.org/3/reference/datamodel.html#object.__index__ |
| """ |
| if value is None or isinstance(value, int): |
| return value |
| if (index := getattr(value, "__index__", None)) is not None: |
| return index() |
| raise TypeError("slice indices must be integers or None or have an __index__ method") |
| |
| |
| def _coerce_slice(key: slice) -> tuple[int, int | None, bool]: |
| """Check slice content and convert it for SQL. |
| |
| See CPython documentation on this: |
| https://docs.python.org/3/reference/datamodel.html#slice-objects |
| """ |
| if key.step is None or key.step == 1: |
| reverse = False |
| elif key.step == -1: |
| reverse = True |
| else: |
| raise ValueError("non-trivial slice step not supported") |
| return _coerce_index(key.start) or 0, _coerce_index(key.stop), reverse |
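

def _example_slice_coercion() -> None:
    # Illustrative sketch only (hypothetical helper): how common slices map to
    # the (start, stop, reverse) triples consumed by
    # ``LazySelectSequence.__getitem__``.
    assert _coerce_slice(slice(None, 10)) == (0, 10, False)  # seq[:10]
    assert _coerce_slice(slice(-10, None)) == (-10, None, False)  # seq[-10:]
    assert _coerce_slice(slice(None, None, -1)) == (0, None, True)  # seq[::-1]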