| """Integration between SQLAlchemy and Presto. |
| |
| Some code based on |
| https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py |
| which is released under the MIT license. |
| """ |
| |
| from __future__ import absolute_import |
| from __future__ import unicode_literals |
| |
| import re |
| import sqlalchemy |
| from sqlalchemy import exc |
| from sqlalchemy import types |
| from sqlalchemy import util |
# TODO: stop borrowing MySQL's TINYINT type (imported further down) for Presto's tinyint.
| from sqlalchemy.sql import text |
| try: |
| from sqlalchemy.databases import mysql |
| mysql_tinyinteger = mysql.MSTinyInteger |
| except ImportError: |
| # Required for SQLAlchemy>=2.0 |
| from sqlalchemy.dialects import mysql |
| mysql_tinyinteger = mysql.base.MSTinyInteger |
| from sqlalchemy.engine import default |
| from sqlalchemy.sql import compiler, bindparam |
| from sqlalchemy.sql.compiler import SQLCompiler |
| |
| from pyhive import presto |
| from pyhive.common import UniversalSet |
| |
# "major.minor" of the installed SQLAlchemy as a float (e.g. "1.4.46" -> 1.4); compared
# against 1.4 in PrestoDialect.get_indexes to pick the row-access API.
# NOTE(review): the pattern requires a third version component, and a float cannot order a
# two-digit minor correctly ("1.10" -> 1.1) -- confirm this is acceptable for supported releases.
sqlalchemy_version = float(
    re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1)
)
| |
class PrestoIdentifierPreparer(compiler.IdentifierPreparer):
    """Identifier preparer used by the Presto dialect."""

    # Treat every word as reserved so that all identifiers get quoted; this is simpler than
    # maintaining a keyword list and easier to keep working across upgrades.
    reserved_words = UniversalSet()
| |
| |
# Presto type name (as looked up from the "Type" column of SHOW COLUMNS in
# PrestoDialect.get_columns) -> SQLAlchemy type class used for reflection.
_type_map = dict(
    boolean=types.Boolean,
    tinyint=mysql_tinyinteger,
    smallint=types.SmallInteger,
    integer=types.Integer,
    bigint=types.BigInteger,
    real=types.Float,
    double=types.Float,
    varchar=types.String,
    timestamp=types.TIMESTAMP,
    date=types.DATE,
    varbinary=types.VARBINARY,
)
| |
| |
class PrestoCompiler(SQLCompiler):
    """Statement compiler for the Presto dialect."""

    def visit_char_length_func(self, fn, **kw):
        """Render the ``char_length`` function as ``length`` followed by its argument spec."""
        argspec = self.function_argspec(fn, **kw)
        return 'length{}'.format(argspec)
| |
| |
class PrestoTypeCompiler(compiler.GenericTypeCompiler):
    """Type compiler that renders SQLAlchemy types as Presto DDL type names.

    CLOB, NCLOB and DATETIME are rejected with ``ValueError``; FLOAT is rendered as
    ``DOUBLE`` and TEXT as ``VARCHAR``.
    """

    def visit_CLOB(self, type_, **kw):
        """Reject CLOB columns."""
        raise ValueError("Presto does not support the CLOB column type.")

    def visit_NCLOB(self, type_, **kw):
        """Reject NCLOB columns."""
        raise ValueError("Presto does not support the NCLOB column type.")

    def visit_DATETIME(self, type_, **kw):
        """Reject DATETIME columns."""
        raise ValueError("Presto does not support the DATETIME column type.")

    def visit_FLOAT(self, type_, **kw):
        """Render FLOAT as ``DOUBLE``."""
        return 'DOUBLE'

    def visit_TEXT(self, type_, **kw):
        """Render TEXT as ``VARCHAR``, including the length when one is set."""
        length = type_.length
        if not length:
            return 'VARCHAR'
        return 'VARCHAR({:d})'.format(length)
| |
| |
class PrestoDialect(default.DefaultDialect):
    """SQLAlchemy dialect backed by the ``pyhive.presto`` DB-API module.

    Reflection is implemented with ``SHOW SCHEMAS`` / ``SHOW TABLES`` / ``SHOW COLUMNS`` and
    ``information_schema.views``. Rollback is a no-op.
    """

    name = 'presto'
    driver = 'rest'
    paramstyle = 'pyformat'
    preparer = PrestoIdentifierPreparer
    statement_compiler = PrestoCompiler
    supports_alter = False
    supports_pk_autoincrement = False
    supports_default_values = False
    supports_empty_insert = False
    supports_multivalues_insert = True
    supports_unicode_statements = True
    supports_unicode_binds = True
    supports_statement_cache = False
    returns_unicode_strings = True
    description_encoding = None
    supports_native_boolean = True
    type_compiler = PrestoTypeCompiler

    @classmethod
    def dbapi(cls):
        """Return the DB-API module (``pyhive.presto``)."""
        return presto

    @classmethod
    def import_dbapi(cls):
        """Return the DB-API module (``pyhive.presto``)."""
        return presto

    def create_connect_args(self, url):
        """Translate a SQLAlchemy URL into ``([], kwargs)`` for the DB-API ``connect`` call.

        The URL's database part is ``catalog`` or ``catalog/schema`` and defaults to
        ``hive``; the port defaults to 8080. URL query parameters are merged into the
        keyword arguments.

        :raises ValueError: if the database part has more than two ``/``-separated pieces.
        """
        db_parts = (url.database or 'hive').split('/')
        kwargs = {
            'host': url.host,
            'port': url.port or 8080,
            'username': url.username,
            'password': url.password,
        }
        kwargs.update(url.query)
        if len(db_parts) == 1:
            kwargs['catalog'] = db_parts[0]
        elif len(db_parts) == 2:
            kwargs['catalog'] = db_parts[0]
            kwargs['schema'] = db_parts[1]
        else:
            raise ValueError("Unexpected database format {}".format(url.database))
        return [], kwargs

    def get_schema_names(self, connection, **kw):
        """Return the schema names reported by ``SHOW SCHEMAS``."""
        return [row.Schema for row in connection.execute(text('SHOW SCHEMAS'))]

    def _get_table_columns(self, connection, table_name, schema):
        """Execute ``SHOW COLUMNS`` for a table and return the result.

        :raises sqlalchemy.exc.NoSuchTableError: if the database error message says the
            table does not exist. Any other database error is re-raised unchanged.
        """
        full_table = self.identifier_preparer.quote_identifier(table_name)
        if schema:
            full_table = self.identifier_preparer.quote_identifier(schema) + '.' + full_table
        try:
            return connection.execute(text('SHOW COLUMNS FROM {}'.format(full_table)))
        except (presto.DatabaseError, exc.DatabaseError) as e:
            # Normally SQLAlchemy should wrap this exception in sqlalchemy.exc.DatabaseError,
            # which it successfully does in the Hive version. The difference with Presto is
            # that this error is raised when fetching the cursor's description rather than the
            # initial execute call. SQLAlchemy doesn't handle this. Thus, we also catch the
            # unwrapped presto.DatabaseError here.
            # The first exception argument is either a dict carrying a 'message' entry or a
            # plain string; anything else yields no message.
            first_arg = e.args[0] if e.args else None
            if isinstance(first_arg, dict):
                msg = first_arg.get('message')
            elif isinstance(first_arg, str):
                msg = first_arg
            else:
                msg = None
            # Does the table exist?
            regex = r"Table\ \'.*{}\'\ does\ not\ exist".format(re.escape(table_name))
            if msg and re.search(regex, msg, re.IGNORECASE):
                raise exc.NoSuchTableError(table_name)
            raise

    def has_table(self, connection, table_name, schema=None, **kw):
        """Return True if ``SHOW COLUMNS`` succeeds for the table, False if it is missing."""
        try:
            self._get_table_columns(connection, table_name, schema)
            return True
        except exc.NoSuchTableError:
            return False

    @staticmethod
    def _resolve_column_type(type_name):
        """Map a Presto type name to a SQLAlchemy type class, or None if unrecognized.

        An exact match in ``_type_map`` wins. Otherwise a parameterized name such as
        ``varchar(255)`` is resolved through its base name (``varchar``).
        """
        coltype = _type_map.get(type_name)
        if coltype is None and type_name:
            # Only strip a trailing "(...)" group; names with trailing words
            # (e.g. "timestamp with time zone") are deliberately left unrecognized.
            parameterized = re.match(r'^(\w+)\s*\(.*\)$', type_name)
            if parameterized:
                coltype = _type_map.get(parameterized.group(1))
        return coltype

    def get_columns(self, connection, table_name, schema=None, **kw):
        """Return column descriptions (name, type, nullable, default) for a table.

        Unrecognized column types produce a warning and are reported as ``NullType``.
        """
        rows = self._get_table_columns(connection, table_name, schema)
        result = []
        for row in rows:
            coltype = self._resolve_column_type(row.Type)
            if coltype is None:
                util.warn("Did not recognize type '%s' of column '%s'" % (row.Type, row.Column))
                coltype = types.NullType
            result.append({
                'name': row.Column,
                'type': coltype,
                # newer Presto no longer includes this column
                'nullable': getattr(row, 'Null', True),
                'default': None,
            })
        return result

    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
        """Return an empty list: no foreign keys are reported for Presto tables."""
        return []

    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
        """Return an empty (falsy) value: no primary key is reported for Presto tables."""
        # NOTE(review): SQLAlchemy documents a dict return value here; the empty list is kept
        # for backward compatibility -- confirm before changing the shape.
        return []

    def get_indexes(self, connection, table_name, schema=None, **kw):
        """Report the table's partition columns as a single non-unique index.

        Returns ``[]`` when the table has no partition columns.
        """
        rows = self._get_table_columns(connection, table_name, schema)
        part_key = 'Partition Key'
        col_names = []
        for row in rows:
            # Presto puts this information in one of 3 places depending on version
            # - a boolean column named "Partition Key"
            # - a string in the "Comment" column
            # - a string in the "Extra" column
            if sqlalchemy_version >= 1.4:
                # Key-based access and membership tests need the mapping view here.
                row = row._mapping
            # Any of these columns may be absent or NULL; treat both as an empty string.
            comment = (row['Comment'] if 'Comment' in row else None) or ''
            extra = (row['Extra'] if 'Extra' in row else None) or ''
            is_partition_key = (
                (part_key in row and row[part_key])
                or comment.startswith(part_key)
                or 'partition key' in extra
            )
            if is_partition_key:
                col_names.append(row['Column'])
        if col_names:
            return [{'name': 'partition', 'column_names': col_names, 'unique': False}]
        return []

    def _get_default_schema_name(self, connection):
        """Defer to the base dialect's default schema name."""
        # The commented-out alternative 'SELECT CURRENT_SCHEMA()' is not used.
        return super()._get_default_schema_name(connection)

    def get_table_names(self, connection, schema=None, **kw):
        """Return table names from ``SHOW TABLES``, optionally limited to ``schema``."""
        query = 'SHOW TABLES'
        # N.B. This is incorrect, if no schema is provided, the current/default schema should
        # be used with a call to an overridden self._get_default_schema_name(connection), but
        # I could not see how to implement that as there is no CURRENT_SCHEMA function
        if schema:
            query += ' FROM ' + self.identifier_preparer.quote_identifier(schema)
        return [row.Table for row in connection.execute(text(query))]

    def get_view_names(self, connection, schema=None, **kw):
        """Return view names from ``information_schema.views``.

        With ``schema`` given, only views in that schema are returned; otherwise the query
        is not filtered by schema.
        """
        if schema:
            view_name_query = """
                SELECT table_name
                FROM information_schema.views
                WHERE table_schema = :schema
            """
            query = text(view_name_query).bindparams(
                bindparam("schema", type_=types.Unicode)
            )
            result = connection.execute(query, {"schema": schema})
        else:
            # N.B. This is incorrect, if no schema is provided, the current/default schema
            # should be used with a call to self._get_default_schema_name(connection), but I
            # could not see how to implement that
            view_name_query = """
                SELECT table_name
                FROM information_schema.views
            """
            # No bind parameters in this statement, so none are passed.
            result = connection.execute(text(view_name_query))
        return [row[0] for row in result]

    def do_rollback(self, dbapi_connection):
        """Do nothing: there are no transactions to roll back for Presto."""
        # No transactions for Presto
        pass

    def _check_unicode_returns(self, connection, additional_tests=None):
        """Always report Unicode results without probing the connection."""
        # requests gives back Unicode strings
        return True

    def _check_unicode_description(self, connection):
        """Always report Unicode cursor descriptions without probing the connection."""
        # requests gives back Unicode strings
        return True