| """Integration between SQLAlchemy and Hive. |
| |
| Some code based on |
| https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py |
| which is released under the MIT license. |
| """ |
| |
| from __future__ import absolute_import |
| from __future__ import unicode_literals |
| |
| import datetime |
| import decimal |
| import logging |
| |
| import re |
| from sqlalchemy import exc |
| from sqlalchemy.sql import text |
| try: |
| from sqlalchemy import processors |
| except ImportError: |
| # Required for SQLAlchemy>=2.0 |
| from sqlalchemy.engine import processors |
| from sqlalchemy import types |
| from sqlalchemy import util |
| # TODO shouldn't use mysql type |
| try: |
| from sqlalchemy.databases import mysql |
| mysql_tinyinteger = mysql.MSTinyInteger |
| except ImportError: |
    # Required for SQLAlchemy>=2.0
| from sqlalchemy.dialects import mysql |
| mysql_tinyinteger = mysql.base.MSTinyInteger |
| from sqlalchemy.engine import default |
| from sqlalchemy.sql import compiler |
| from sqlalchemy.sql.compiler import SQLCompiler |
| |
| from pyhive import hive |
| from pyhive.common import UniversalSet |
| |
| from dateutil.parser import parse |
| from decimal import Decimal |
| |
| _logger = logging.getLogger(__name__) |
| |
class HiveStringTypeBase(types.TypeDecorator):
    """Translates strings returned by Thrift into something else.

    Base class for result-only type decorators: Hive values arrive over
    Thrift as strings and subclasses convert them to Python objects.
    """
    impl = types.String

    def process_bind_param(self, value, dialect):
        # These types are read-only translations of result values; binding
        # Python values back into Hive statements is intentionally unsupported.
        raise NotImplementedError("Writing to Hive not supported")
| |
| |
class HiveDate(HiveStringTypeBase):
    """Translates date strings to date objects"""
    impl = types.DATE

    def process_result_value(self, value, dialect):
        return processors.str_to_date(value)

    def result_processor(self, dialect, coltype):
        def process(value):
            # Coerce a raw result value into a datetime.date (or None).
            if value is None:
                return None
            # datetime is a subclass of date, so check it first.
            if isinstance(value, datetime.datetime):
                return value.date()
            if isinstance(value, datetime.date):
                return value
            return parse(value).date()

        return process

    def adapt(self, impltype, **kwargs):
        # Always expose the declared impl type when SQLAlchemy adapts us.
        return self.impl
| |
| |
class HiveTimestamp(HiveStringTypeBase):
    """Translates timestamp strings to datetime objects"""
    impl = types.TIMESTAMP

    def process_result_value(self, value, dialect):
        return processors.str_to_datetime(value)

    def result_processor(self, dialect, coltype):
        def process(value):
            # Coerce a raw result value into a datetime.datetime (or None).
            if value is None:
                return None
            if isinstance(value, datetime.datetime):
                return value
            return parse(value)

        return process

    def adapt(self, impltype, **kwargs):
        # Always expose the declared impl type when SQLAlchemy adapts us.
        return self.impl
| |
| |
class HiveDecimal(HiveStringTypeBase):
    """Translates strings to decimals"""
    impl = types.DECIMAL

    def process_result_value(self, value, dialect):
        # None passes through; everything else becomes a Decimal.
        if value is None:
            return None
        return decimal.Decimal(value)

    def result_processor(self, dialect, coltype):
        def process(value):
            # Already-decimal values pass through untouched.
            if value is None:
                return None
            return value if isinstance(value, Decimal) else Decimal(value)

        return process

    def adapt(self, impltype, **kwargs):
        # Always expose the declared impl type when SQLAlchemy adapts us.
        return self.impl
| |
| |
class HiveIdentifierPreparer(compiler.IdentifierPreparer):
    """Identifier preparer that backtick-quotes every identifier."""

    # Treating every word as reserved forces quoting across the board,
    # which sidesteps maintaining a Hive keyword list.
    reserved_words = UniversalSet()

    def __init__(self, dialect):
        # Hive quotes identifiers with backticks rather than double quotes.
        super(HiveIdentifierPreparer, self).__init__(dialect, initial_quote='`')
| |
| |
# Maps Hive type names (as they appear in DESCRIBE output, with any
# parameters already stripped, e.g. 'decimal(10,1)' -> 'decimal') to
# SQLAlchemy type classes. Complex types (array/map/struct/uniontype)
# and binary are surfaced as strings, since that is how they arrive
# over Thrift.
_type_map = {
    'boolean': types.Boolean,
    'tinyint': mysql_tinyinteger,
    'smallint': types.SmallInteger,
    'int': types.Integer,
    'bigint': types.BigInteger,
    'float': types.Float,
    'double': types.Float,
    'string': types.String,
    'varchar': types.String,
    'char': types.String,
    'date': HiveDate,
    'timestamp': HiveTimestamp,
    'binary': types.String,
    'array': types.String,
    'map': types.String,
    'struct': types.String,
    'uniontype': types.String,
    'decimal': HiveDecimal,
}
| |
| |
class HiveCompiler(SQLCompiler):
    """Statement compiler with Hive-specific SQL rewrites."""

    def visit_concat_op_binary(self, binary, operator, **kw):
        # Hive has no || operator; render concatenation as concat(a, b).
        left = self.process(binary.left)
        right = self.process(binary.right)
        return "concat(%s, %s)" % (left, right)

    def visit_insert(self, *args, **kwargs):
        """Rewrite standard INSERT syntax into Hive's INSERT INTO TABLE form."""
        sql = super(HiveCompiler, self).visit_insert(*args, **kwargs)
        # INSERT INTO `pyhive_test_database`.`test_table` (`a`) SELECT ...
        # => INSERT INTO TABLE `pyhive_test_database`.`test_table` SELECT ...
        pattern = r'^(INSERT INTO) ([^\s]+) \([^\)]*\)'
        assert re.search(pattern, sql), "Unexpected visit_insert result: {}".format(sql)
        return re.sub(pattern, r'\1 TABLE \2', sql)

    def visit_column(self, *args, **kwargs):
        """Strip the schema qualifier from schema.table.column references."""
        rendered = super(HiveCompiler, self).visit_column(*args, **kwargs)
        dots = rendered.count('.')
        assert dots in (0, 1, 2), "Unexpected visit_column result {}".format(rendered)
        if dots == 2:
            # Hive rejects schema-qualified column references, so drop the
            # leading schema part and keep table.column.
            rendered = rendered.split('.', 1)[1]
        return rendered

    def visit_char_length_func(self, fn, **kw):
        # Hive spells CHAR_LENGTH as length().
        return 'length' + self.function_argspec(fn, **kw)
| |
| |
class HiveTypeCompiler(compiler.GenericTypeCompiler):
    """Renders SQLAlchemy generic types as Hive DDL type names."""

    def visit_INTEGER(self, type_):
        return 'INT'

    def visit_NUMERIC(self, type_):
        return 'DECIMAL'

    # Hive has no bounded character types; all textual types become STRING.
    def visit_CHAR(self, type_):
        return 'STRING'

    def visit_VARCHAR(self, type_):
        return 'STRING'

    def visit_NCHAR(self, type_):
        return 'STRING'

    def visit_TEXT(self, type_):
        return 'STRING'

    def visit_CLOB(self, type_):
        return 'STRING'

    def visit_BLOB(self, type_):
        return 'BINARY'

    # Hive has no standalone TIME or DATE DDL types here; all temporal
    # types are rendered as TIMESTAMP.
    def visit_TIME(self, type_):
        return 'TIMESTAMP'

    def visit_DATE(self, type_):
        return 'TIMESTAMP'

    def visit_DATETIME(self, type_):
        return 'TIMESTAMP'
| |
| |
class HiveExecutionContext(default.DefaultExecutionContext):
    """This is pretty much the same as SQLiteExecutionContext to work around the same issue.

    http://docs.sqlalchemy.org/en/latest/dialects/sqlite.html#dotted-column-names

    engine = create_engine('hive://...', execution_options={'hive_raw_colnames': True})
    """

    @util.memoized_property
    def _preserve_raw_colnames(self):
        # When set, keep the raw "tablename.colname" labels from Hive.
        # Ideally, this would also gate on hive.resultset.use.unique.column.names
        return self.execution_options.get('hive_raw_colnames', False)

    def _translate_colname(self, colname):
        """Return (translated_name, raw_name_or_None) for one result column.

        Hive prefixes column names with the table name when
        hive.resultset.use.unique.column.names is true (the default).
        """
        if self._preserve_raw_colnames or '.' not in colname:
            return colname, None
        return colname.rsplit('.', 1)[-1], colname
| |
| |
class HiveDialect(default.DefaultDialect):
    """SQLAlchemy dialect for HiveServer2 over the binary Thrift transport."""
    name = 'hive'
    driver = 'thrift'
    execution_ctx_cls = HiveExecutionContext
    preparer = HiveIdentifierPreparer
    statement_compiler = HiveCompiler
    supports_views = True
    supports_alter = True
    supports_pk_autoincrement = False
    supports_default_values = False
    supports_empty_insert = False
    supports_native_decimal = True
    supports_native_boolean = True
    supports_unicode_statements = True
    supports_unicode_binds = True
    returns_unicode_strings = True
    description_encoding = None
    supports_multivalues_insert = True
    type_compiler = HiveTypeCompiler
    supports_sane_rowcount = False
    supports_statement_cache = False

    @classmethod
    def dbapi(cls):
        # DB-API module hook for SQLAlchemy < 2.0.
        return hive

    @classmethod
    def import_dbapi(cls):
        # SQLAlchemy >= 2.0 renamed dbapi() to import_dbapi(); keep both.
        return hive

    def create_connect_args(self, url):
        """Translate a SQLAlchemy URL into (args, kwargs) for hive.connect()."""
        kwargs = {
            'host': url.host,
            'port': url.port or 10000,
            'username': url.username,
            'password': url.password,
            'database': url.database or 'default',
        }
        # Forward any URL query parameters (e.g. auth options) unchanged.
        kwargs.update(url.query)
        return [], kwargs

    def get_schema_names(self, connection, **kw):
        # Equivalent to SHOW DATABASES
        return [row[0] for row in connection.execute(text('SHOW SCHEMAS'))]

    def get_view_names(self, connection, schema=None, **kw):
        # Hive does not provide functionality to query tableType
        # This allows reflection to not crash at the cost of being inaccurate
        return self.get_table_names(connection, schema, **kw)

    def _get_table_columns(self, connection, table_name, schema):
        """Run DESCRIBE on the table and return the raw result rows.

        Raises NoSuchTableError when the table does not exist, detected
        either from the OperationalError message or from the single-row
        "Table ... does not exist" result Hive sometimes returns instead.
        """
        full_table = table_name
        if schema:
            full_table = schema + '.' + table_name
        # TODO using TGetColumnsReq hangs after sending TFetchResultsReq.
        # Using DESCRIBE works but is uglier.
        try:
            # This needs the table name to be unescaped (no backticks).
            rows = connection.execute(text('DESCRIBE {}'.format(full_table))).fetchall()
        except exc.OperationalError as e:
            # Does the table exist?
            regex_fmt = r'TExecuteStatementResp.*SemanticException.*Table not found {}'
            regex = regex_fmt.format(re.escape(full_table))
            if re.search(regex, e.args[0]):
                raise exc.NoSuchTableError(full_table)
            else:
                raise
        else:
            # Hive is stupid: this is what I get from DESCRIBE some_schema.does_not_exist
            regex = r'Table .* does not exist'
            if len(rows) == 1 and re.match(regex, rows[0].col_name):
                raise exc.NoSuchTableError(full_table)
        return rows

    def has_table(self, connection, table_name, schema=None, **kw):
        """Return True if the table can be described, False otherwise."""
        try:
            self._get_table_columns(connection, table_name, schema)
            return True
        except exc.NoSuchTableError:
            return False

    def get_columns(self, connection, table_name, schema=None, **kw):
        """Reflect column names and types from DESCRIBE output."""
        rows = self._get_table_columns(connection, table_name, schema)
        # Strip whitespace
        rows = [[col.strip() if col else None for col in row] for row in rows]
        # Filter out empty rows and comment
        rows = [row for row in rows if row[0] and row[0] != '# col_name']
        result = []
        for (col_name, col_type, _comment) in rows:
            # Rows after this marker repeat the partition columns; stop here.
            if col_name == '# Partition Information':
                break
            # Take out the more detailed type information
            # e.g. 'map<int,int>' -> 'map'
            #      'decimal(10,1)' -> 'decimal'
            col_type = re.search(r'^\w+', col_type).group(0)
            try:
                coltype = _type_map[col_type]
            except KeyError:
                util.warn("Did not recognize type '%s' of column '%s'" % (col_type, col_name))
                coltype = types.NullType

            result.append({
                'name': col_name,
                'type': coltype,
                'nullable': True,
                'default': None,
            })
        return result

    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
        # Hive has no support for foreign keys.
        return []

    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
        # Hive has no support for primary keys.
        # NOTE(review): SQLAlchemy's Dialect API documents a dict return for
        # this hook; confirm callers tolerate the empty list before changing.
        return []

    def get_indexes(self, connection, table_name, schema=None, **kw):
        """Report partition columns as a single pseudo-index named 'partition'."""
        rows = self._get_table_columns(connection, table_name, schema)
        # Strip whitespace
        rows = [[col.strip() if col else None for col in row] for row in rows]
        # Filter out empty rows and comment
        rows = [row for row in rows if row[0] and row[0] != '# col_name']
        # NOTE(review): if rows is empty, `i` below would be unbound
        # (NameError); presumably DESCRIBE always yields at least one row.
        for i, (col_name, _col_type, _comment) in enumerate(rows):
            if col_name == '# Partition Information':
                break
        # Handle partition columns
        col_names = []
        for col_name, _col_type, _comment in rows[i + 1:]:
            col_names.append(col_name)
        if col_names:
            return [{'name': 'partition', 'column_names': col_names, 'unique': False}]
        else:
            return []

    def get_table_names(self, connection, schema=None, **kw):
        """List table names, tolerating both Hive and Spark SQL result shapes."""
        query = 'SHOW TABLES'
        if schema:
            query += ' IN ' + self.identifier_preparer.quote_identifier(schema)

        table_names = []

        for row in connection.execute(text(query)):
            # Hive returns 1 column
            if len(row) == 1:
                table_names.append(row[0])
            # Spark SQL returns 3 columns
            elif len(row) == 3:
                table_names.append(row[1])
            else:
                _logger.warning("Unexpected number of columns in SHOW TABLES result: {}".format(len(row)))
                table_names.append('UNKNOWN')

        return table_names

    def do_rollback(self, dbapi_connection):
        # No transactions for Hive
        pass

    def _check_unicode_returns(self, connection, additional_tests=None):
        # We decode everything as UTF-8
        return True

    def _check_unicode_description(self, connection):
        # We decode everything as UTF-8
        return True
| |
| |
class HiveHTTPDialect(HiveDialect):
    """Hive dialect speaking HiveServer2's HTTP transport."""

    name = "hive"
    scheme = "http"
    driver = "rest"

    def create_connect_args(self, url):
        """Translate a SQLAlchemy URL into (args, kwargs) for hive.connect().

        Unlike the Thrift dialect, this also passes the transport scheme so
        the connection uses HTTP(S).
        """
        kwargs = {
            "host": url.host,
            "port": url.port or 10000,
            "scheme": self.scheme,
            "username": url.username or None,
            "password": url.password or None,
            "database": url.database or "default",
        }
        # Forward any extra URL query parameters unchanged.
        if url.query:
            kwargs.update(url.query)
        # Bug fix: a second, unreachable `return ([], kwargs)` after this
        # statement was dead code and has been removed.
        return [], kwargs
| |
| |
class HiveHTTPSDialect(HiveHTTPDialect):
    """Hive dialect speaking HiveServer2's HTTP transport over TLS.

    Identical to HiveHTTPDialect except the scheme passed to the
    connection is "https".
    """

    name = "hive"
    scheme = "https"