# blob: 78f925ac84b5311538ea906b779982a7d81f48ef
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from collections import defaultdict
from contextlib import contextmanager
def get_mssql_table_constraints(conn, table_name) -> dict[str, dict[str, list[str]]]:
    """
    Look up the primary key, unique and foreign key constraints of a table.

    Some tables (for example ``task_instance``) have no explicit primary key
    constraint name, so SQL Server generates one. Querying
    ``INFORMATION_SCHEMA`` lets callers discover the actual constraint names
    together with the columns each constraint covers.

    :param conn: sql connection object; only ``conn.execute(query).fetchall()`` is used
    :param table_name: table name, interpolated verbatim into the query text
    :return: a nested mapping of ``constraint type -> constraint name -> [column names]``.
        The returned object is a ``defaultdict`` of ``defaultdict(list)``, so
        looking up a missing type or name yields an empty container.
    """
    sql = f"""SELECT tc.CONSTRAINT_NAME , tc.CONSTRAINT_TYPE, ccu.COLUMN_NAME
FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc
JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE AS ccu ON ccu.CONSTRAINT_NAME = tc.CONSTRAINT_NAME
WHERE tc.TABLE_NAME = '{table_name}' AND
(tc.CONSTRAINT_TYPE = 'PRIMARY KEY' or UPPER(tc.CONSTRAINT_TYPE) = 'UNIQUE'
or UPPER(tc.CONSTRAINT_TYPE) = 'FOREIGN KEY')
"""
    rows = conn.execute(sql).fetchall()

    # Group rows as {constraint type: {constraint name: [columns...]}}; a
    # multi-column constraint arrives as one row per column.
    columns_by_constraint = defaultdict(lambda: defaultdict(list))
    for row in rows:
        name, kind, column = row
        columns_by_constraint[kind][name].append(column)
    return columns_by_constraint
@contextmanager
def disable_sqlite_fkeys(op):
    """
    Turn off SQLite foreign key enforcement for the duration of the ``with`` block.

    When the bound dialect is ``sqlite``, ``PRAGMA foreign_keys=off`` is
    executed before the block and ``PRAGMA foreign_keys=on`` after it. For any
    other dialect nothing is executed and ``op`` is yielded unchanged.

    :param op: object exposing ``get_bind()`` and ``execute()``
        (presumably the Alembic operations object -- confirm against callers)
    :return: context manager yielding the same ``op`` object
    """
    if op.get_bind().dialect.name != "sqlite":
        yield op
        return

    op.execute("PRAGMA foreign_keys=off")
    try:
        yield op
    finally:
        # Re-enable enforcement even if the wrapped block raises; otherwise
        # foreign keys would stay off after a failure.
        op.execute("PRAGMA foreign_keys=on")