blob: 70a1a1e6b47f4a36c4055a1a5c39e2f67eed930a [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import sqlalchemy as sa
from alembic import context
from lazy_object_proxy import Proxy
######################################
# Note about this module:
#
# It loads the dialect-specific type dynamically at runtime. For IDE/typing
# support there is an associated db_types.pyi. If you add a new type here, add
# a simple version of it there too.
######################################
def _mssql_TIMESTAMP():
    """Return the type that ``TIMESTAMP`` resolves to when the bound dialect is ``mssql``.

    The returned class is a subclass of ``sqlalchemy.dialects.mssql.DATETIME2``
    whose ``precision`` keyword defaults to 6. Callers can still override it
    by keyword.
    """
    # Imported inside the factory so the mssql dialect module is only loaded
    # when this factory is actually selected.
    from sqlalchemy.dialects.mssql import DATETIME2 as _MssqlDatetime2

    # The class keeps the name DATETIME2 and an unchanged __init__ signature.
    # SQLAlchemy inspects both, for example when adapting or rendering the type.
    class DATETIME2(_MssqlDatetime2):
        def __init__(self, *args, precision=6, **kwargs):
            super().__init__(*args, precision=precision, **kwargs)

    return DATETIME2
def _mysql_TIMESTAMP():
    """Return the type that ``TIMESTAMP`` resolves to when the bound dialect is ``mysql``.

    The returned class subclasses ``sqlalchemy.dialects.mysql.TIMESTAMP``. It
    defaults ``fsp`` to 6 and ``timezone`` to ``True``, and either can be
    overridden by keyword.
    """
    # Deferred import: the mysql dialect module is only loaded when needed.
    from sqlalchemy.dialects.mysql import TIMESTAMP as _MysqlTimestamp

    # Keep the class name and the __init__ signature identical. SQLAlchemy
    # introspects them, for example when rendering the type.
    class TIMESTAMP(_MysqlTimestamp):
        def __init__(self, *args, fsp=6, timezone=True, **kwargs):
            super().__init__(*args, fsp=fsp, timezone=timezone, **kwargs)

    return TIMESTAMP
def _sa_TIMESTAMP():
    """Return the generic ``TIMESTAMP`` fallback.

    This is used for any dialect without a ``_<dialect>_TIMESTAMP`` factory in
    this module. The returned class subclasses ``sqlalchemy.TIMESTAMP`` with
    ``timezone`` defaulting to ``True``.
    """
    generic_base = sa.TIMESTAMP

    class TIMESTAMP(generic_base):
        def __init__(self, *args, timezone=True, **kwargs):
            super().__init__(*args, timezone=timezone, **kwargs)

    return TIMESTAMP
def _sa_StringID():
    """Return Airflow's ``StringID`` type.

    This is the generic fallback for any dialect that has no
    ``_<dialect>_StringID`` factory in this module.
    """
    # Imported lazily: resolved only when this factory is called.
    from airflow.models.base import StringID as string_id_type

    return string_id_type
def __getattr__(name):
    """Module-level attribute hook (PEP 562) that supplies ``TIMESTAMP`` and ``StringID``.

    For those two names it returns a lazy ``Proxy``. When the proxy is resolved,
    it does the following:

    1. Reads the dialect name from ``context.get_bind()``.
    2. Picks the factory ``_<dialect>_<name>`` from this module, falling back
       to ``_sa_<name>``.
    3. Calls the factory and stores the result in the module globals under
       ``name``, so later lookups no longer reach this hook.

    Raises:
        AttributeError: for any other attribute name.
    """
    if name not in ("TIMESTAMP", "StringID"):
        raise AttributeError(f"module {__name__} has no attribute {name}")

    def lazy_load():
        namespace = globals()
        dialect_name = context.get_bind().dialect.name
        # Prefer the dialect-specific factory; otherwise use the generic one.
        factory = namespace.get(f"_{dialect_name}_{name}", None) or namespace.get(f"_sa_{name}")
        resolved = factory()
        namespace[name] = resolved
        return resolved

    # Before v1.4 of the Helm chart, the migration environment was not
    # initialized correctly, so calling `context.get_bind()` at import/top
    # level would fail. Returning a lazy proxy spares migration authors from
    # having to worry about that.
    return Proxy(lazy_load)