| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
'''This module is intended to standardize workflows when working with various databases
   such as Impala, PostgreSQL, etc. Even with PEP 249 (DB API 2), workflows differ
   slightly. For example, PostgreSQL does not allow changing databases from within a
   connection.

'''
| import hashlib |
| import impala.dbapi |
| import re |
| import shelve |
| from abc import ABCMeta, abstractmethod |
| from contextlib import closing |
| from decimal import Decimal as PyDecimal |
| from itertools import combinations, ifilter, izip |
| from logging import getLogger |
| from os import symlink, unlink |
| from pyparsing import ( |
| alphanums, |
| delimitedList, |
| Forward, |
| Group, |
| Literal, |
| nums, |
| Suppress, |
| Word) |
| from tempfile import gettempdir |
| from textwrap import dedent |
| from threading import Lock |
| from time import time |
| |
| from tests.comparison.common import ( |
| ArrayColumn, |
| Column, |
| MapColumn, |
| StructColumn, |
| Table, |
| TableExprList) |
| from tests.comparison.db_types import ( |
| Char, |
| Decimal, |
| Double, |
| EXACT_TYPES, |
| Float, |
| get_char_class, |
| get_decimal_class, |
| get_varchar_class, |
| Int, |
| String, |
| Timestamp, |
| TinyInt, |
| VarChar) |
| |
| LOG = getLogger(__name__) |
| |
| HIVE = "HIVE" |
| IMPALA = "IMPALA" |
| MYSQL = "MYSQL" |
| ORACLE = "ORACLE" |
| POSTGRESQL = "POSTGRESQL" |
| |
| |
| class DbCursor(object): |
| '''Wraps a DB API 2 cursor to provide access to the related conn. This class |
| implements the DB API 2 interface by delegation. |
| |
| ''' |
| |
| @staticmethod |
| def describe_common_tables(cursors): |
| '''Find and return a TableExprList containing Table objects that the given conns |
| have in common. |
| ''' |
| common_table_names = None |
| for cursor in cursors: |
| table_names = set(cursor.list_table_names()) |
| if common_table_names is None: |
| common_table_names = table_names |
| else: |
| common_table_names &= table_names |
| common_table_names = sorted(common_table_names) |
| |
| tables = TableExprList() |
| for table_name in common_table_names: |
| common_table = None |
| mismatch = False |
| for cursor in cursors: |
| table = cursor.describe_table(table_name) |
| if common_table is None: |
| common_table = table |
| continue |
        if not table.cols:
          LOG.debug('Ignoring table %s. It has no columns.', table_name)
          mismatch = True
          break
| if len(common_table.cols) != len(table.cols): |
| LOG.debug('Ignoring table %s.' |
| ' It has a different number of columns across databases.', table_name) |
| mismatch = True |
| break |
| if common_table.primary_key_names != table.primary_key_names: |
| LOG.debug( |
| 'Ignoring table {name} because of differing primary keys: ' |
| '{common_table_keys} vs. {table_keys}'.format( |
| name=table_name, common_table_keys=common_table.primary_key_names, |
| table_keys=table.primary_key_names)) |
| mismatch = True |
| break |
| for left, right in izip(common_table.cols, table.cols): |
| if not (left.name == right.name and left.type == right.type): |
| LOG.debug('Ignoring table %s. It has different columns %s vs %s.' % |
| (table_name, left, right)) |
| mismatch = True |
| break |
| if mismatch: |
| break |
| if not mismatch: |
| tables.append(common_table) |
| |
| return tables |
| |
| SQL_TYPE_PATTERN = re.compile(r'([^()]+)(\((\d+,? ?)*\))?') |
| TYPE_NAME_ALIASES = \ |
| dict((type_.name().upper(), type_.name().upper()) for type_ in EXACT_TYPES) |
| TYPES_BY_NAME = dict((type_.name().upper(), type_) for type_ in EXACT_TYPES) |
| EXACT_TYPES_TO_SQL = dict((type_, type_.name().upper()) for type_ in EXACT_TYPES) |
| |
| @classmethod |
| def make_insert_sql_from_data(cls, table, rows): |
| if not rows: |
| raise Exception('At least one row is required') |
| if not table.cols: |
| raise Exception('At least one col is required') |
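
    # For example, for a hypothetical table "t" with an Int column and a String
    # column, rows [(1, 'a'), (None, "b'c")] would yield:
    #   INSERT INTO t VALUES (1, 'a'), (NULL, 'b''c')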
| |
| sql = 'INSERT INTO %s VALUES ' % table.name |
| for row_idx, row in enumerate(rows): |
| if row_idx > 0: |
| sql += ', ' |
| sql += '(' |
| for col_idx, col in enumerate(table.cols): |
| if col_idx > 0: |
| sql += ', ' |
| val = row[col_idx] |
| if val is None: |
| sql += 'NULL' |
| elif issubclass(col.type, Timestamp): |
| sql += "TIMESTAMP '%s'" % val |
| elif issubclass(col.type, Char): |
| sql += "'%s'" % val.replace("'", "''") |
| else: |
| sql += str(val) |
| sql += ')' |
| return sql |
| |
| def __init__(self, conn, cursor): |
| self._conn = conn |
| self._cursor = cursor |
| |
| def __getattr__(self, attr): |
| return getattr(self._cursor, attr) |
| |
| def __setattr__(self, attr, value): |
| # Transfer unknown attributes to the underlying cursor. |
| if attr not in ["_conn", "_cursor"]: |
| setattr(self._cursor, attr, value) |
| object.__setattr__(self, attr, value) |
| |
| def __enter__(self): |
| return self |
| |
| def __exit__(self, exc_type, exc_value, traceback): |
| self.close(quiet=True) |
| |
| @property |
| def db_type(self): |
| return self._conn.db_type |
| |
| @property |
| def conn(self): |
| return self._conn |
| |
| @property |
| def db_name(self): |
| return self._conn.db_name |
| |
| def execute(self, sql, *args, **kwargs): |
| LOG.debug('%s: %s' % (self.db_type, sql)) |
| if self.conn.sql_log: |
| self.conn.sql_log.write('\nQuery: %s' % sql) |
| return self._cursor.execute(sql, *args, **kwargs) |
| |
| def execute_and_fetchall(self, sql, *args, **kwargs): |
| self.execute(sql, *args, **kwargs) |
| return self.fetchall() |
| |
| def close(self, quiet=False): |
| try: |
| self._cursor.close() |
| except Exception as e: |
| if quiet: |
| LOG.debug('Error closing cursor: %s', e, exc_info=True) |
| else: |
| raise e |
| |
| def reconnect(self): |
| self.conn.reconnect() |
| self._cursor = self.conn.cursor()._cursor |
| |
| def list_db_names(self): |
| '''Return a list of database names always in lowercase.''' |
| rows = self.execute_and_fetchall(self.make_list_db_names_sql()) |
| return [row[0].lower() for row in rows] |
| |
| def make_list_db_names_sql(self): |
| return 'SHOW DATABASES' |
| |
| def create_db(self, db_name): |
| LOG.info("Creating database %s", db_name) |
| db_name = db_name.lower() |
| self.execute('CREATE DATABASE ' + db_name) |
| |
| def create_db_if_not_exists(self, db_name): |
| LOG.info("Creating database %s", db_name) |
| db_name = db_name.lower() |
| self.execute('CREATE DATABASE IF NOT EXISTS ' + db_name) |
| |
| def drop_db_if_exists(self, db_name): |
| '''This should not be called from a conn to the database being dropped.''' |
| db_name = db_name.lower() |
| if db_name not in self.list_db_names(): |
| return |
| if self.db_name and self.db_name.lower() == db_name: |
| raise Exception('Cannot drop database while still connected to it') |
| self.drop_db(db_name) |
| |
| def drop_db(self, db_name): |
| LOG.info("Dropping database %s", db_name) |
| db_name = db_name.lower() |
| self.execute('DROP DATABASE ' + db_name) |
| |
| def ensure_empty_db(self, db_name): |
| self.drop_db_if_exists(db_name) |
| self.create_db(db_name) |
| |
| def list_table_names(self): |
| '''Return a list of table names always in lowercase.''' |
| rows = self.execute_and_fetchall(self.make_list_table_names_sql()) |
| return [row[0].lower() for row in rows] |
| |
| def make_list_table_names_sql(self): |
| return 'SHOW TABLES' |
| |
| def parse_col_desc(self, data_type): |
    ''' Returns a parsed representation of a column type based on describe output.
    data_type is a string that should look like this:
| "bigint" |
| or like this: |
| "array<struct< |
| field_51:int, |
| field_52:bigint, |
| field_53:int, |
| field_54:boolean |
| >>" |
| In the first case, this method would return: 'bigint' |
| In the second case, it would return |
| ['array', |
| ['struct', |
| ['field_51', 'int'], |
| ['field_52', 'bigint'], |
| ['field_53', 'int'], |
| ['field_54', 'boolean']]] |
| This output is used to create the appropriate columns by self.create_column(). |
| ''' |
| |
    # Note: LPAR/RPAR below are bound to the angle brackets '<'/'>', and LBRA/RBRA
    # to the parentheses '('/')'.
    COMMA, LPAR, RPAR, COLON, LBRA, RBRA = map(Suppress, ",<>:()")
| |
| t_bigint = Literal('bigint') |
| t_int = Literal('int') |
| t_integer = Literal('integer') |
| t_smallint = Literal('smallint') |
| t_tinyint = Literal('tinyint') |
| t_boolean = Literal('boolean') |
| t_string = Literal('string') |
| t_timestamp = Literal('timestamp') |
| t_timestamp_without_time_zone = Literal('timestamp without time zone') |
| t_float = Literal('float') |
| t_double = Literal('double') |
| t_real = Literal('real') |
| t_double_precision = Literal('double precision') |
| |
| t_decimal = Group(Literal('decimal') + LBRA + Word(nums) + COMMA + Word(nums) + RBRA) |
| t_numeric = Group(Literal('numeric') + LBRA + Word(nums) + COMMA + Word(nums) + RBRA) |
| t_char = Group(Literal('char') + LBRA + Word(nums) + RBRA) |
| t_character = Group(Literal('character') + LBRA + Word(nums) + RBRA) |
| t_varchar = (Group(Literal('varchar') + LBRA + Word(nums) + RBRA) | |
| Literal('varchar')) |
| t_character_varying = Group(Literal('character varying') + LBRA + Word(nums) + RBRA) |
| |
| t_struct = Forward() |
| t_array = Forward() |
| t_map = Forward() |
| |
| complex_type = (t_struct | t_array | t_map) |
| |
| any_type = ( |
| complex_type | |
| t_bigint | |
| t_int | |
| t_integer | |
| t_smallint | |
| t_tinyint | |
| t_boolean | |
| t_string | |
| t_timestamp | |
| t_timestamp_without_time_zone | |
| t_float | |
| t_double | |
| t_real | |
| t_double_precision | |
| t_decimal | |
| t_numeric | |
| t_char | |
| t_character | |
| t_character_varying | |
| t_varchar) |
| |
| struct_field_name = Word(alphanums + '_') |
| struct_field_pair = Group(struct_field_name + COLON + any_type) |
| |
| t_struct << Group(Literal('struct') + LPAR + delimitedList(struct_field_pair) + RPAR) |
| t_array << Group(Literal('array') + LPAR + any_type + RPAR) |
| t_map << Group(Literal('map') + LPAR + any_type + COMMA + any_type + RPAR) |
| |
| return any_type.parseString(data_type)[0] |
| |
| def create_column(self, col_name, col_type): |
| ''' Takes the output from parse_col_desc and creates the right column type. This |
| method returns one of Column, ArrayColumn, MapColumn, StructColumn.''' |
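    # For example, create_column('c', ['array', 'int']) returns an ArrayColumn
    # whose item is an Int column named 'item'.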
    if isinstance(col_type, str):
      if col_type.upper() == 'VARCHAR':
        # The grammar allows a bare 'varchar' with no size; treat it as STRING.
        col_type = 'STRING'
| type_name = self.TYPE_NAME_ALIASES.get(col_type.upper()) |
| return Column(owner=None, |
| name=col_name.lower(), |
| exact_type=self.TYPES_BY_NAME[type_name]) |
| |
| general_class = col_type[0] |
| |
| if general_class.upper() == 'ARRAY': |
| return ArrayColumn( |
| owner=None, |
| name=col_name.lower(), |
| item=self.create_column(col_name='item', col_type=col_type[1])) |
| |
| if general_class.upper() == 'MAP': |
| return MapColumn( |
| owner=None, |
| name=col_name.lower(), |
| key=self.create_column(col_name='key', col_type=col_type[1]), |
| value=self.create_column(col_name='value', col_type=col_type[2])) |
| |
| if general_class.upper() == 'STRUCT': |
| struct_col = StructColumn(owner=None, name=col_name.lower()) |
| for field_name, field_type in col_type[1:]: |
| struct_col.add_col(self.create_column(field_name, field_type)) |
| return struct_col |
| |
| general_class = self.TYPE_NAME_ALIASES.get(col_type[0].upper()) |
| |
| if general_class.upper() == 'DECIMAL': |
| return Column(owner=None, |
| name=col_name.lower(), |
| exact_type=get_decimal_class(int(col_type[1]), int(col_type[2]))) |
| |
| if general_class.upper() == 'CHAR': |
| return Column(owner=None, |
| name=col_name.lower(), |
| exact_type=get_char_class(int(col_type[1]))) |
| |
| if general_class.upper() == 'VARCHAR': |
| type_size = int(col_type[1]) |
| if type_size <= VarChar.MAX: |
| cur_type = get_varchar_class(type_size) |
| else: |
| cur_type = self.TYPES_BY_NAME['STRING'] |
| return Column(owner=None, |
| name=col_name.lower(), |
| exact_type=cur_type) |
| |
| raise Exception('unable to parse: {0}, type: {1}'.format(col_name, col_type)) |
| |
| def create_table_from_describe(self, table_name, describe_rows): |
| primary_key_names = self._fetch_primary_key_names(table_name) |
| table = Table(table_name.lower()) |
| for row in describe_rows: |
| col_name, data_type = row[:2] |
| col_type = self.parse_col_desc(data_type) |
| col = self.create_column(col_name, col_type) |
| col.is_primary_key = col_name in primary_key_names |
| table.add_col(col) |
| return table |
| |
| def describe_table(self, table_name): |
| '''Return a Table with table and col names always in lowercase.''' |
| describe_rows = self.execute_and_fetchall(self.make_describe_table_sql(table_name)) |
| table = self.create_table_from_describe(table_name, describe_rows) |
| self.load_unique_col_metadata(table) |
| return table |
| |
| def make_describe_table_sql(self, table_name): |
| return 'DESCRIBE ' + table_name |
| |
| def parse_data_type(self, type_name, type_size): |
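    # For example, parse_data_type('DECIMAL', (10, 2)) returns the Decimal subclass
    # produced by get_decimal_class(10, 2).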
| if type_name in ('DECIMAL', 'NUMERIC'): |
| return get_decimal_class(*type_size) |
| if type_name == 'CHAR': |
| return get_char_class(*type_size) |
| if type_name == 'VARCHAR': |
| if type_size and type_size[0] <= VarChar.MAX: |
| return get_varchar_class(*type_size) |
| type_name = 'STRING' |
| return self.TYPES_BY_NAME[type_name] |
| |
| def create_table(self, table): |
| LOG.info('Creating table %s', table.name) |
| if not table.cols: |
| raise Exception('At least one col is required') |
| table_sql = self.make_create_table_sql(table) |
| LOG.debug(table_sql) |
| self.execute(table_sql) |
| LOG.debug('Created table %s', table.name) |
| |
| def make_create_table_sql(self, table): |
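    # A sketch of the output for a hypothetical two-column table "t" whose "id"
    # column is the primary key (the NULL constraint is omitted for engines with
    # implicitly nullable types):
    #   CREATE TABLE t (id BIGINT, val VARCHAR(10) NULL, PRIMARY KEY (id))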
| column_declarations = [] |
| primary_key_names = [] |
| for col in table.cols: |
| if col.is_primary_key: |
| null_constraint = '' |
| primary_key_names.append(col.name) |
      elif self.conn.data_types_are_implicitly_nullable:
| null_constraint = '' |
| else: |
| null_constraint = ' NULL' |
| column_declaration = '{name} {col_type}{null_constraint}'.format( |
| name=col.name, col_type=self.get_sql_for_data_type(col.exact_type), |
| null_constraint=null_constraint) |
| column_declarations.append(column_declaration) |
| |
| if primary_key_names: |
| primary_key_constraint = ', PRIMARY KEY ({keys})'.format( |
| keys=', '.join(primary_key_names)) |
| else: |
| primary_key_constraint = '' |
| |
| create_table = ('CREATE TABLE {table_name} (' |
| '{all_columns}' |
| '{primary_key_constraint}' |
| ')'.format( |
| table_name=table.name, all_columns=', '.join(column_declarations), |
| primary_key_constraint=primary_key_constraint)) |
| return create_table |
| |
| def get_sql_for_data_type(self, data_type): |
| if issubclass(data_type, VarChar): |
| return 'VARCHAR(%s)' % data_type.MAX |
| if issubclass(data_type, Char): |
| return 'CHAR(%s)' % data_type.MAX |
| if issubclass(data_type, Decimal): |
| return 'DECIMAL(%s, %s)' % (data_type.MAX_DIGITS, data_type.MAX_FRACTIONAL_DIGITS) |
| return self.EXACT_TYPES_TO_SQL[data_type] |
| |
| def drop_table(self, table_name, if_exists=True): |
| LOG.info('Dropping table %s', table_name) |
| self.execute('DROP TABLE IF EXISTS ' + table_name.lower()) |
| |
| def drop_view(self, view_name, if_exists=True): |
| LOG.info('Dropping view %s', view_name) |
| self.execute('DROP VIEW IF EXISTS ' + view_name.lower()) |
| |
| def index_table(self, table_name): |
| LOG.info('Indexing table %s', table_name) |
| table = self.describe_table(table_name) |
| for col in table.cols: |
| index_name = '%s_%s' % (table_name, col.name) |
| if self.db_name: |
| index_name = '%s_%s' % (self.db_name, index_name) |
      # Older versions of Postgres require the index name to be included in the
      # statement. Indexing tables with long column names used to fail because there
      # is a limit on how long an index name is allowed to be, which is why
      # index_name is hashed and truncated to 10 characters.
| index_name = 'ind_' + hashlib.sha256(index_name).hexdigest()[:10] |
| self.execute('CREATE INDEX %s ON %s(%s)' % (index_name, table_name, col.name)) |
| |
| def search_for_unique_cols(self, table=None, table_name=None, depth=2): |
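    # Returns a list of column sets found to be unique, e.g. [{a}, {b, c}] means
    # column "a" is unique by itself and columns "b" and "c" are unique in
    # combination.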
| if not table: |
| table = self.describe_table(table_name) |
| sql_templ = 'SELECT COUNT(*) FROM %s GROUP BY %%s HAVING COUNT(*) > 1' % table.name |
| unique_cols = list() |
| for current_depth in xrange(1, depth + 1): |
| for cols in combinations(table.cols, current_depth): # redundant combos excluded |
| cols = set(cols) |
| if any(ifilter(lambda unique_subset: unique_subset < cols, unique_cols)): |
| # cols contains a combo known to be unique |
| continue |
| col_names = ', '.join(col.name for col in cols) |
| sql = sql_templ % col_names |
| LOG.debug('Checking column combo (%s) for uniqueness' % col_names) |
| self.execute(sql) |
| if not self.fetchone(): |
| LOG.debug('Found unique column combo (%s)' % col_names) |
| unique_cols.append(cols) |
| return unique_cols |
| |
| def persist_unique_col_metadata(self, table): |
| if not table.unique_cols: |
| return |
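    # The store layout is {db_type: {db_name: {table_name: [[col name, ...], ...]}}},
    # mirrored by load_unique_col_metadata() below.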
| with closing(shelve.open('/tmp/query_generator.shelve', writeback=True)) as store: |
| if self.db_type not in store: |
| store[self.db_type] = dict() |
| db_type_store = store[self.db_type] |
| if self.db_name not in db_type_store: |
| db_type_store[self.db_name] = dict() |
| db_store = db_type_store[self.db_name] |
| db_store[table.name] = [[col.name for col in cols] for cols in table.unique_cols] |
| |
| def load_unique_col_metadata(self, table): |
| with closing(shelve.open('/tmp/query_generator.shelve')) as store: |
| db_type_store = store.get(self.db_type) |
| if not db_type_store: |
| return |
| db_store = db_type_store.get(self.db_name) |
| if not db_store: |
| return |
| unique_col_names = db_store.get(table.name) |
| if not unique_col_names: |
| return |
| unique_cols = list() |
| for entry in unique_col_names: |
| col_names = set(entry) |
| cols = set((col for col in table.cols if col.name in col_names)) |
| if len(col_names) != len(cols): |
| raise Exception("Incorrect unique column data for %s" % table.name) |
| unique_cols.append(cols) |
| table.unique_cols = unique_cols |
| |
| def _fetch_primary_key_names(self, table_name): |
| """ |
| This must return a tuple of strings representing the primary keys of table_name, |
| or an empty tuple if there are no primary keys. |
| """ |
    # This is the base method. Since we haven't tested this on Oracle or MySQL and
    # don't plan to implement it for those databases, the base method returns an
    # empty tuple.
| return () |
| |
| |
| class DbConnection(object): |
| |
| __metaclass__ = ABCMeta |
| |
| LOCK = Lock() |
| |
| PORT = None |
| USER_NAME = None |
| PASSWORD = None |
| |
| _CURSOR_CLASS = DbCursor |
| |
| def __init__(self, host_name="localhost", port=None, user_name=None, password=None, |
| db_name=None, log_sql=False): |
| self._host_name = host_name |
| self._port = port or self.PORT |
| self._user_name = user_name or self.USER_NAME |
| self._password = password or self.PASSWORD |
| self.db_name = db_name |
| self._conn = None |
| self._connect() |
| |
| if log_sql: |
| with DbConnection.LOCK: |
| sql_log_path = gettempdir() + '/sql_log_%s_%s.sql' \ |
| % (self.db_type.lower(), time()) |
| self.sql_log = open(sql_log_path, 'w') |
| link = gettempdir() + '/sql_log_%s.sql' % self.db_type.lower() |
        try:
          unlink(link)
        except OSError as e:
          if 'No such file' not in str(e):
            raise
        symlink(sql_log_path, link)
| else: |
| self.sql_log = None |
| |
| def _clone(self, db_name, **kwargs): |
| return type(self)(host_name=self._host_name, port=self._port, |
| user_name=self._user_name, password=self._password, db_name=db_name, **kwargs) |
| |
| def clone(self, db_name): |
| return self._clone(db_name) |
| |
| def __getattr__(self, attr): |
| if attr == "_conn": |
| raise AttributeError() |
| return getattr(self._conn, attr) |
| |
| def __setattr__(self, attr, value): |
| _conn = getattr(self, "_conn", None) |
| if not _conn or not hasattr(_conn, attr): |
| object.__setattr__(self, attr, value) |
| else: |
| setattr(self._conn, attr, value) |
| |
| def __enter__(self): |
| return self |
| |
| def __exit__(self, exc_type, exc_value, traceback): |
| self.close(quiet=True) |
| |
| @property |
| def db_type(self): |
| return self._DB_TYPE |
| |
| @abstractmethod |
| def _connect(self): |
| pass |
| |
| def reconnect(self): |
| self.close(quiet=True) |
| self._connect() |
| |
| def close(self, quiet=False): |
| '''Close the underlying conn.''' |
| if not self._conn: |
| return |
| try: |
| self._conn.close() |
| self._conn = None |
| except Exception as e: |
| if quiet: |
| LOG.debug('Error closing connection: %s' % e) |
| else: |
| raise e |
| |
| def cursor(self): |
| return self._CURSOR_CLASS(self, self._conn.cursor()) |
| |
| @property |
| def supports_kill(self): |
| return False |
| |
| def kill(self): |
| '''Kill the current connection and any currently running queries associated with the |
| conn. |
| ''' |
| raise Exception('Killing connection is not supported') |
| |
| @property |
| def supports_index_creation(self): |
| return True |
| |
| @property |
  def data_types_are_implicitly_nullable(self):
| return False |
| |
| |
| class ImpalaCursor(DbCursor): |
| |
  PK_SEARCH_PATTERN = re.compile(r'PRIMARY KEY \((?P<keys>.*?)\)')
| STORAGE_FORMATS_WITH_PRIMARY_KEYS = ('KUDU',) |
| |
| @classmethod |
| def make_insert_sql_from_data(cls, table, rows): |
| if not rows: |
| raise Exception('At least one row is required') |
| if not table.cols: |
| raise Exception('At least one col is required') |
| |
| sql = 'INSERT INTO %s VALUES ' % table.name |
| for row_idx, row in enumerate(rows): |
| if row_idx > 0: |
| sql += ', ' |
| sql += '(' |
| for col_idx, col in enumerate(table.cols): |
| if col_idx > 0: |
| sql += ', ' |
| val = row[col_idx] |
| if val is None: |
| sql += 'NULL' |
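        # Unlike the base class, Impala accepts a plain quoted string as a
        # TIMESTAMP literal, so the TIMESTAMP keyword prefix is omitted here.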
| elif issubclass(col.type, Timestamp): |
| sql += "'%s'" % val |
| elif issubclass(col.type, Char): |
| val = val.replace("'", "''") |
| sql += "'%s'" % val |
| else: |
| sql += str(val) |
| sql += ')' |
| return sql |
| |
| @property |
| def cluster(self): |
| return self.conn.cluster |
| |
| def _fetch_primary_key_names(self, table_name): |
| self.execute("SHOW CREATE TABLE {0}".format(table_name)) |
| # This returns 1 column with 1 multiline string row, resembling: |
| # |
| # CREATE TABLE db.table ( |
| # pk1 BIGINT, |
| # pk2 BIGINT, |
| # col BIGINT, |
| # PRIMARY KEY (pk1, pk2) |
| # ) |
| # |
| # Even a 1-column primary key will be shown as a PRIMARY KEY constraint, like: |
| # |
| # CREATE TABLE db.table ( |
| # pk1 BIGINT, |
| # col BIGINT, |
| # PRIMARY KEY (pk1) |
| # ) |
| (raw_result,) = self.fetchone() |
| search_result = ImpalaCursor.PK_SEARCH_PATTERN.search(raw_result) |
| if search_result is None: |
| return () |
| else: |
| return tuple(search_result.group("keys").split(", ")) |
| |
| def invalidate_metadata(self, table_name=None): |
| self.execute("INVALIDATE METADATA %s" % (table_name or "")) |
| |
| def drop_db(self, db_name): |
| '''This should not be called when connected to the database being dropped.''' |
| LOG.info("Dropping database %s", db_name) |
| self.execute('DROP DATABASE %s CASCADE' % db_name) |
| |
| def compute_stats(self, table_name=None): |
| if table_name: |
| self.execute("COMPUTE STATS %s" % table_name) |
| else: |
| for table_name in self.list_table_names(): |
| self.execute("COMPUTE STATS %s" % table_name) |
| |
| def create_table(self, table): |
| # The Hive metastore has a limit to the amount of schema it can store inline. |
| # Beyond this limit, the schema needs to be stored in HDFS and Hive is given a |
| # URL instead. |
| if table.storage_format == "AVRO" and not table.schema_location \ |
| and len(table.get_avro_schema()) > 4000: |
| self.upload_avro_schema(table) |
| super(ImpalaCursor, self).create_table(table) |
| |
| def upload_avro_schema(self, table): |
| self.ensure_schema_location(table) |
| avro_schema = table.get_avro_schema() |
| self.cluster.hdfs.create_client().write( |
| table.schema_location, data=avro_schema, overwrite=True) |
| |
| def ensure_schema_location(self, table): |
| if not table.schema_location: |
| db_name = self.conn.db_name or "default" |
| table.schema_location = '%s/%s.db/%s.avsc' % ( |
| self.cluster.hive.warehouse_dir, db_name, table.name) |
| |
| def ensure_storage_location(self, table): |
| if not table.storage_location: |
| db_name = self.conn.db_name or "default" |
| table.storage_location = '%s/%s.db/data/%s' % ( |
| self.cluster.hive.warehouse_dir, db_name, table.name) |
| |
| def make_create_table_sql(self, table): |
| sql = super(ImpalaCursor, self).make_create_table_sql(table) |
| |
| if table.primary_keys: |
| if table.storage_format in ImpalaCursor.STORAGE_FORMATS_WITH_PRIMARY_KEYS: |
| # IMPALA-4424 adds support for parametrizing the partitions; for now, on our |
| # small scale, this is ok, especially since the model is to migrate tables from |
| # Impala into Postgres anyway. 3 was chosen for the buckets because our |
| # minicluster tends to have 3 tablet servers, but otherwise it's arbitrary and |
| # provides valid syntax for creating Kudu tables in Impala. |
| sql += '\nPARTITION BY HASH ({col}) PARTITIONS 3'.format( |
| col=table.primary_key_names[0]) |
| else: |
| raise Exception( |
| 'table representation has primary keys {keys} but is not in a format that ' |
| 'supports them: {storage_format}'.format( |
| keys=str(table.primary_key_names), |
| storage_format=table.storage_format)) |
| elif table.storage_format in ImpalaCursor.STORAGE_FORMATS_WITH_PRIMARY_KEYS: |
| raise Exception( |
| 'table representation has storage format {storage_format} ' |
| 'but does not have any primary keys'.format( |
| storage_format=table.storage_format)) |
| |
| if table.storage_format != 'TEXTFILE': |
| sql += "\nSTORED AS " + table.storage_format |
| if table.storage_location: |
| sql = sql.replace("CREATE TABLE", "CREATE EXTERNAL TABLE") |
| sql += "\nLOCATION '%s'" % table.storage_location |
| if table.storage_format == 'AVRO': |
| if table.schema_location: |
| sql += "\nTBLPROPERTIES ('avro.schema.url' = '%s')" % table.schema_location |
| else: |
| avro_schema = table.get_avro_schema() |
| if len(avro_schema) > 4000: |
| raise Exception("Avro schema exceeds 4000 character limit. Create a file" |
| " containing the schema instead and set 'table.schema_location'.") |
| sql += "\nTBLPROPERTIES ('avro.schema.literal' = '%s')" % avro_schema |
| return sql |
| |
| def get_sql_for_data_type(self, data_type): |
| if issubclass(data_type, String): |
| return 'STRING' |
| return super(ImpalaCursor, self).get_sql_for_data_type(data_type) |
| |
| def close(self, quiet=False): |
| try: |
| # Explicitly close the operation to avoid issues like IMPALA-2562. |
      # This can be removed if https://github.com/cloudera/impyla/pull/142 is merged.
| self._cursor.close_operation() |
| self._cursor.close() |
| except Exception as e: |
| if quiet: |
| LOG.debug('Error closing cursor: %s', e, exc_info=True) |
| else: |
| raise e |
| |
| |
| class ImpalaConnection(DbConnection): |
| |
| PORT = 21050 # For HS2 |
| |
| _DB_TYPE = IMPALA |
| _CURSOR_CLASS = ImpalaCursor |
| _KERBEROS_SERVICE_NAME = 'impala' |
| _NON_KERBEROS_AUTH_MECH = 'NOSASL' |
| |
| def __init__(self, use_kerberos=False, use_ssl=False, ca_cert=None, **kwargs): |
| self._use_kerberos = use_kerberos |
| self.cluster = None |
| self._use_ssl = use_ssl |
| self._ca_cert = ca_cert |
| DbConnection.__init__(self, **kwargs) |
| |
  def clone(self, db_name):
    clone = self._clone(db_name, use_kerberos=self._use_kerberos,
        use_ssl=self._use_ssl, ca_cert=self._ca_cert)
    clone.cluster = self.cluster
    return clone
| |
| @property |
  def data_types_are_implicitly_nullable(self):
| return True |
| |
| @property |
| def supports_index_creation(self): |
| return False |
| |
| def cursor(self): |
| cursor = super(ImpalaConnection, self).cursor() |
| cursor.arraysize = 1024 # Try to match the default batch size |
| return cursor |
| |
| def _connect(self): |
| self._conn = impala.dbapi.connect( |
| host=self._host_name, |
| port=self._port, |
| user=self._user_name, |
| password=self._password, |
| database=self.db_name, |
| timeout=(60 * 60), |
| auth_mechanism=('GSSAPI' if self._use_kerberos else self._NON_KERBEROS_AUTH_MECH), |
| kerberos_service_name=self._KERBEROS_SERVICE_NAME, |
| use_ssl=self._use_ssl, |
| ca_cert=self._ca_cert) |
| |
| |
| class HiveCursor(ImpalaCursor): |
| |
| def invalidate_metadata(self, table_name=None): |
| # There is no equivalent of "INVALIDATE METADATA" in Hive |
| pass |
| |
| def compute_stats(self, table_name=None): |
| if table_name: |
| self.execute("ANALYZE TABLE %s COMPUTE STATISTICS" % table_name) |
| self.execute("ANALYZE TABLE %s COMPUTE STATISTICS FOR COLUMNS" % table_name) |
| else: |
| for table_name in self.list_table_names(): |
| self.execute("ANALYZE TABLE %s COMPUTE STATISTICS" % table_name) |
| self.execute("ANALYZE TABLE %s COMPUTE STATISTICS FOR COLUMNS" % table_name) |
| |
| |
| class HiveConnection(ImpalaConnection): |
| |
| PORT = 11050 |
| |
| _DB_TYPE = HIVE |
| _CURSOR_CLASS = HiveCursor |
| _KERBEROS_SERVICE_NAME = 'hive' |
| _NON_KERBEROS_AUTH_MECH = 'PLAIN' |
| |
| |
| class PostgresqlCursor(DbCursor): |
| |
| TYPE_NAME_ALIASES = dict(DbCursor.TYPE_NAME_ALIASES) |
| TYPE_NAME_ALIASES.update({ |
| 'INTEGER': 'INT', |
| 'NUMERIC': 'DECIMAL', |
| 'REAL': 'FLOAT', |
| 'DOUBLE PRECISION': 'DOUBLE', |
| 'CHARACTER': 'CHAR', |
| 'CHARACTER VARYING': 'VARCHAR', |
| 'TIMESTAMP WITHOUT TIME ZONE': 'TIMESTAMP'}) |
| |
| EXACT_TYPES_TO_SQL = dict(DbCursor.EXACT_TYPES_TO_SQL) |
| EXACT_TYPES_TO_SQL.update({ |
| Double: 'DOUBLE PRECISION', |
| Float: 'REAL', |
| Int: 'INTEGER', |
| Timestamp: 'TIMESTAMP WITHOUT TIME ZONE', |
| TinyInt: 'SMALLINT'}) |
| |
| @classmethod |
| def make_insert_sql_from_data(cls, table, rows): |
| if not rows: |
| raise Exception('At least one row is required') |
| if not table.cols: |
| raise Exception('At least one col is required') |
| |
| sql = 'INSERT INTO %s VALUES ' % table.name |
| for row_idx, row in enumerate(rows): |
| if row_idx > 0: |
| sql += ', ' |
| sql += '(' |
| for col_idx, col in enumerate(table.cols): |
| if col_idx > 0: |
| sql += ', ' |
| val = row[col_idx] |
| if val is None: |
| sql += 'NULL' |
| elif issubclass(col.type, Timestamp): |
| sql += "TIMESTAMP '%s'" % val |
| elif issubclass(col.type, Char): |
| val = val.replace("'", "''") |
| val = val.replace('\\', '\\\\') |
| sql += "'%s'" % val |
| else: |
| sql += str(val) |
| sql += ')' |
| |
| return sql |
| |
| def make_list_db_names_sql(self): |
| return 'SELECT datname FROM pg_database' |
| |
| def make_list_table_names_sql(self): |
| return ''' |
| SELECT table_name |
| FROM information_schema.tables |
| WHERE table_schema = 'public' ''' |
| |
| def make_describe_table_sql(self, table_name): |
    # When doing a CREATE TABLE AS SELECT... a column may end up with a plain
    # "numeric" type that has no precision. We'll assume that's a DOUBLE.
| return ''' |
| SELECT column_name, |
| CASE data_type |
| WHEN 'character' THEN |
| data_type || '(' || character_maximum_length || ')' |
| WHEN 'character varying' THEN |
| data_type || '(' || character_maximum_length || ')' |
             WHEN 'numeric' THEN
               CASE WHEN numeric_precision IS NULL THEN 'double precision'
                    ELSE data_type
                         || '('
                         || numeric_precision
                         || ', '
                         || numeric_scale
                         || ')'
               END
| ELSE data_type |
| END data_type |
| FROM information_schema.columns |
| WHERE table_name = '%s' |
| ORDER BY ordinal_position''' % \ |
| table_name |
| |
| def get_sql_for_data_type(self, data_type): |
| if issubclass(data_type, String): |
| return 'VARCHAR(%s)' % String.MAX |
| return super(PostgresqlCursor, self).get_sql_for_data_type(data_type) |
| |
| def _fetch_primary_key_names(self, table_name): |
| # see: |
| # https://www.postgresql.org/docs/9.5/static/infoschema-key-column-usage.html |
| # https://www.postgresql.org/docs/9.5/static/infoschema-table-constraints.html |
| sql = dedent(''' |
| SELECT |
| key_cols.column_name AS column_name |
| FROM |
| information_schema.key_column_usage key_cols, |
| information_schema.table_constraints table_constraints |
| WHERE |
| key_cols.constraint_catalog = table_constraints.constraint_catalog AND |
| key_cols.table_name = table_constraints.table_name AND |
| key_cols.constraint_name = table_constraints.constraint_name AND |
| table_constraints.constraint_type = 'PRIMARY KEY' AND |
| key_cols.table_name = '{table_name}' |
| ORDER BY key_cols.ordinal_position'''.format(table_name=table_name)) |
| self.execute(sql) |
| rows = self.fetchall() |
| return tuple(row[0] for row in rows) |
| |
| |
| class PostgresqlConnection(DbConnection): |
| |
| PORT = 5432 |
| USER_NAME = "postgres" |
| |
| _DB_TYPE = POSTGRESQL |
| _CURSOR_CLASS = PostgresqlCursor |
| |
| @property |
| def supports_kill(self): |
| return True |
| |
| def kill(self): |
| self._conn.cancel() |
| |
| def _connect(self): |
| try: |
| import psycopg2 |
| except ImportError as e: |
| if "No module named psycopg2" not in str(e): |
| raise e |
| import os |
| import subprocess |
| from tests.util.shell_util import shell |
| LOG.info("psycopg2 module not found; attempting to install it...") |
| pip_path = os.path.join(os.environ["IMPALA_HOME"], "infra", "python", "env", |
| "bin", "pip") |
| try: |
| shell(pip_path + " install psycopg2", stdout=subprocess.PIPE, |
| stderr=subprocess.STDOUT) |
| LOG.info("psycopg2 installation complete.") |
| except Exception as e: |
| LOG.error("psycopg2 installation failed. Try installing python-dev and" |
| " libpq-dev then try again.") |
| raise e |
| import psycopg2 |
| self._conn = psycopg2.connect( |
| host=self._host_name, |
| port=self._port, |
| user=self._user_name, |
| password=self._password, |
| database=self.db_name) |
| self._conn.autocommit = True |
| |
| |
| class MySQLConnection(DbConnection): |
| |
| PORT = 3306 |
| USER_NAME = "root" |
| |
  def __init__(self, **kwargs):
    DbConnection.__init__(self, **kwargs)
    with self.cursor() as cursor:
      self._session_id = cursor.execute_and_fetchall('SELECT connection_id()')[0][0]
| |
| def _connect(self): |
| try: |
| import MySQLdb |
    except ImportError:
| print('Error importing MySQLdb. Please make sure it is installed. ' |
| 'See the README for details.') |
| raise |
| self._conn = MySQLdb.connect( |
| host=self._host_name, |
| port=self._port, |
| user=self._user_name, |
| passwd=self._password, |
| db=self.db_name) |
| self._conn.autocommit = True |
| |
  @property
  def supports_kill(self):
    return True

  def kill(self):
    # Open a second connection to the same database and kill this session from it.
    with self._clone(self.db_name) as conn:
      with conn.cursor() as cursor:
        cursor.execute('KILL %s' % self._session_id)
| |
| |
| class MySQLCursor(DbCursor): |
| |
  def describe_table(self, table_name):
    '''Return a Table with table and col names always in lowercase.'''
    rows = self.execute_and_fetchall(self.make_describe_table_sql(table_name))
    table = Table(table_name.lower())
    cols = table.cols  # This is a copy
    for row in rows:
      col_name, data_type = row[:2]
      if data_type == 'tinyint(1)':
        # Just assume this is a boolean...
        data_type = 'boolean'
      type_size = None
      if '(' in data_type:
        # Extract the size of the data type, e.g. 'decimal(10,2)' -> (10, 2).
        type_size = tuple(int(size) for size in
            data_type[data_type.index('(') + 1:data_type.index(')')].split(','))
        data_type = data_type[:data_type.index('(')]
      cols.append(Column(
          table, col_name.lower(), self.parse_data_type(data_type.upper(), type_size)))
    table.cols = cols
    return table
| |
| def index_table(self, table_name): |
| table = self.describe_table(table_name) |
    with self.conn.cursor() as cursor:
| for col in table.cols: |
| try: |
| cursor.execute('ALTER TABLE %s ADD INDEX (%s)' % (table_name, col.name)) |
| except Exception as e: |
| if 'Incorrect index name' not in str(e): |
| raise |
| # Some sort of MySQL bug... |
| LOG.warn('Could not create index on %s.%s: %s' % (table_name, col.name, e)) |
| |
| def make_create_table_sql(self, table): |
    table_sql = super(MySQLCursor, self).make_create_table_sql(table)
| table_sql += ' ENGINE = MYISAM' |
    return table_sql


# MySQLCursor is defined after MySQLConnection, so wire it in here; otherwise
# MySQLConnection would fall back to the base DbCursor.
MySQLConnection._CURSOR_CLASS = MySQLCursor


| class OracleCursor(DbCursor): |
| |
| def make_list_table_names_sql(self): |
| return 'SELECT table_name FROM user_tables' |
| |
| def drop_db(self, db_name): |
| self.execute('DROP USER %s CASCADE' % db_name) |
| |
| def drop_db_if_exists(self, db_name): |
| '''This should not be called from a conn to the database being dropped.''' |
| try: |
| self.drop_db(db_name) |
| except Exception as e: |
| if 'ORA-01918' not in str(e): # Ignore if the user doesn't exist |
| raise |
| |
| def create_db(self, db_name): |
| self.execute( |
| 'CREATE USER %s IDENTIFIED BY %s DEFAULT TABLESPACE USERS' % (db_name, db_name)) |
| self.execute('GRANT ALL PRIVILEGES TO %s' % db_name) |
| |
| def make_describe_table_sql(self, table_name): |
| # Recreate the data types as defined in the model |
| return ''' |
| SELECT |
| column_name, |
| CASE |
| WHEN data_type = 'NUMBER' AND data_scale = 0 THEN |
| data_type || '(' || data_precision || ')' |
| WHEN data_type = 'NUMBER' THEN |
| data_type || '(' || data_precision || ', ' || data_scale || ')' |
| WHEN data_type IN ('CHAR', 'VARCHAR2') THEN |
| data_type || '(' || data_length || ')' |
| WHEN data_type LIKE 'TIMESTAMP%%' THEN |
| 'TIMESTAMP' |
| ELSE data_type |
| END |
| FROM all_tab_columns |
| WHERE owner = '%s' AND table_name = '%s' |
| ORDER BY column_id''' \ |
      % (self.conn.schema.upper(), table_name.upper())
| |
| |
| class OracleConnection(DbConnection): |
| |
| PORT = 1521 |
| USER_NAME = 'system' |
| PASSWORD = 'oracle' |
| |
| _CURSOR_CLASS = OracleCursor |
| |
| def __init__(self, service='XE', **kwargs): |
| self._service = service |
| DbConnection.__init__(self, **kwargs) |
| self._conn.outputtypehandler = OracleTypeConverter.convert_type |
| self._conn.autocommit = True |
| |
| def clone(self, db_name): |
| return self._clone(db_name, service=self._service) |
| |
| @property |
| def schema(self): |
| return self.db_name or self._user_name |
| |
| def _connect(self): |
| try: |
| import cx_Oracle |
    except ImportError:
| print('Error importing cx_Oracle. Please make sure it is installed. ' |
| 'See the README for details.') |
| raise |
    self._conn = cx_Oracle.connect(
        '%(user)s/%(password)s@%(host)s:%(port)s/%(service)s' % {
            'user': self._user_name, 'password': self._password,
            'host': self._host_name, 'port': self._port, 'service': self._service})
| |
| def cursor(self): |
| cursor = super(OracleConnection, self).cursor() |
| if self.db_name and self.db_name != self._user_name: |
| cursor.execute('ALTER SESSION SET CURRENT_SCHEMA = %s' % self.db_name) |
| return cursor |
| |
| |
| class OracleTypeConverter(object): |
| |
| __imported_types = False |
| __char_type = None |
| __number_type = None |
| |
| @classmethod |
| def convert_type(cls, cursor, name, default_type, size, precision, scale): |
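    # cx_Oracle calls this handler once per column in a result set; returning None
    # keeps the default conversion. CHAR(1) columns are decoded as booleans
    # ('T'/'F') and NUMBERs with a scale as Python Decimals.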
| if not cls.__imported_types: |
| from cx_Oracle import FIXED_CHAR, NUMBER |
| cls.__char_type = FIXED_CHAR |
| cls.__number_type = NUMBER |
| cls.__imported_types = True |
| |
| if default_type == cls.__char_type and size == 1: |
| return cursor.var(str, 1, cursor.arraysize, outconverter=cls.convert_boolean) |
| if default_type == cls.__number_type and scale: |
| return cursor.var(str, 100, cursor.arraysize, outconverter=PyDecimal) |
| |
| @classmethod |
| def convert_boolean(cls, val): |
| if not val: |
| return |
| return val == 'T' |