| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| from __future__ import annotations |
| |
| from collections import defaultdict |
| from contextlib import contextmanager |
| |
| |
def get_mssql_table_constraints(conn, table_name) -> dict[str, dict[str, list[str]]]:
    """
    Return the primary key, unique and foreign key constraints of a table.

    Some tables like `task_instance` are missing the primary key constraint
    name, and the name is auto-generated by the SQL server, so this function
    helps to retrieve any primary key, unique or foreign key constraint name
    along with the columns it covers.

    :param conn: sql connection object; it must provide ``execute(query)``
        returning an object with ``fetchall()``
    :param table_name: table name
    :return: a dictionary mapping constraint type to a dictionary mapping
        constraint name to the list of its column names. Both levels are
        ``defaultdict`` instances, so looking up a missing constraint type
        or name yields an empty container instead of raising ``KeyError``.
    """
    # The name is embedded in a quoted SQL string literal below; doubling any
    # single quote keeps it a single literal instead of terminating it early.
    quoted_table_name = str(table_name).replace("'", "''")
    query = f"""SELECT tc.CONSTRAINT_NAME , tc.CONSTRAINT_TYPE, ccu.COLUMN_NAME
     FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc
     JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE AS ccu ON ccu.CONSTRAINT_NAME = tc.CONSTRAINT_NAME
     WHERE tc.TABLE_NAME = '{quoted_table_name}' AND
     (tc.CONSTRAINT_TYPE = 'PRIMARY KEY' or UPPER(tc.CONSTRAINT_TYPE) = 'UNIQUE'
     or UPPER(tc.CONSTRAINT_TYPE) = 'FOREIGN KEY')
    """
    result = conn.execute(query).fetchall()
    # Group rows as {constraint type: {constraint name: [column, ...]}};
    # a multi-column constraint contributes one row per column.
    constraint_dict = defaultdict(lambda: defaultdict(list))
    for constraint, constraint_type, col_name in result:
        constraint_dict[constraint_type][constraint].append(col_name)
    return constraint_dict
| |
| |
@contextmanager
def disable_sqlite_fkeys(op):
    """
    Context manager that turns SQLite foreign key enforcement off for its body.

    When the dialect of ``op.get_bind()`` is ``sqlite``, ``PRAGMA foreign_keys=off``
    is executed on entry and ``PRAGMA foreign_keys=on`` on exit, including when
    the body raises. For any other dialect nothing is executed.

    :param op: object providing ``get_bind()`` and ``execute(sql)``
        (presumably Alembic's ``op`` -- verify against callers)
    :return: yields the same ``op`` object that was passed in
    """
    if op.get_bind().dialect.name != "sqlite":
        yield op
        return

    op.execute("PRAGMA foreign_keys=off")
    try:
        yield op
    finally:
        # Re-enable enforcement even if the body failed, so the connection is
        # not left with foreign keys switched off.
        op.execute("PRAGMA foreign_keys=on")