blob: f5a256fb8d7c608d7c71a6fec8d58bfbe1df1d40 [file] [log] [blame]
"""Integration between SQLAlchemy and Presto.

Some of this code is based on
https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py
which is released under the MIT license.
"""
from __future__ import absolute_import
from __future__ import unicode_literals
import re
import sqlalchemy
from sqlalchemy import exc
from sqlalchemy import types
from sqlalchemy import util
# TODO: Presto's tinyint should not need to borrow the MySQL dialect's tinyint type (see the mysql import below)
from sqlalchemy.sql import text
try:
from sqlalchemy.databases import mysql
mysql_tinyinteger = mysql.MSTinyInteger
except ImportError:
# Required for SQLAlchemy>=2.0
from sqlalchemy.dialects import mysql
mysql_tinyinteger = mysql.base.MSTinyInteger
from sqlalchemy.engine import default
from sqlalchemy.sql import compiler, bindparam
from sqlalchemy.sql.compiler import SQLCompiler
from pyhive import presto
from pyhive.common import UniversalSet
def _parse_major_minor(version_string):
    """Return the leading ``major.minor`` of a version string as a float.

    ``"1.4.46"`` -> ``1.4``, ``"2.0.0b4"`` -> ``2.0``, and ``"2.0"`` (no patch
    component) -> ``2.0``.

    The float form is kept because callers compare it against literals such as
    ``1.4``. NOTE: a float cannot represent a two-digit minor version faithfully
    (``"2.10"`` would become ``2.1``); switch to a tuple if that ever matters.

    :raises ValueError: if ``version_string`` does not start with ``major.minor``.
    """
    match = re.match(r"(\d+)\.(\d+)", version_string)
    if match is None:
        raise ValueError("Unrecognized SQLAlchemy version: {!r}".format(version_string))
    return float(match.group(0))


# major.minor of the installed SQLAlchemy; used to choose between row-access APIs.
sqlalchemy_version = _parse_major_minor(sqlalchemy.__version__)
class PrestoIdentifierPreparer(compiler.IdentifierPreparer):
    """Identifier preparer for the Presto dialect.

    ``reserved_words`` is a ``UniversalSet``, so every identifier counts as
    reserved and is therefore quoted. Quoting unconditionally keeps things
    simple and avoids maintaining a keyword list across Presto upgrades.
    """

    reserved_words = UniversalSet()
# Presto type name (lower-case, without parameters such as "(255)") -> SQLAlchemy
# type class used for reflected columns. Types not listed here are reflected as
# ``NullType`` and a warning is emitted by ``PrestoDialect.get_columns``.
_type_map = {
    'boolean': types.Boolean,
    # SQLAlchemy core has no tinyint type, so the MySQL dialect's one is borrowed.
    'tinyint': mysql_tinyinteger,
    'smallint': types.SmallInteger,
    'integer': types.Integer,
    'bigint': types.BigInteger,
    'real': types.Float,
    'double': types.Float,
    'char': types.CHAR,
    'varchar': types.String,
    'time': types.TIME,
    'timestamp': types.TIMESTAMP,
    'date': types.DATE,
    'varbinary': types.VARBINARY,
}
class PrestoCompiler(SQLCompiler):
    """SQL statement compiler applying Presto-specific function spellings."""

    def visit_char_length_func(self, fn, **kw):
        # Render SQLAlchemy's char_length(x) as Presto's length(x); the
        # parenthesized argument list comes from the generic compiler.
        args_sql = self.function_argspec(fn, **kw)
        return 'length{}'.format(args_sql)
class PrestoTypeCompiler(compiler.GenericTypeCompiler):
    """Render SQLAlchemy column types as Presto DDL type names.

    CLOB, NCLOB and DATETIME are rejected with ``ValueError``. FLOAT is
    rendered as DOUBLE and TEXT as VARCHAR; all other types use the generic
    behaviour.
    """

    def visit_CLOB(self, type_, **kw):
        raise ValueError("Presto does not support the CLOB column type.")

    def visit_NCLOB(self, type_, **kw):
        raise ValueError("Presto does not support the NCLOB column type.")

    def visit_DATETIME(self, type_, **kw):
        raise ValueError("Presto does not support the DATETIME column type.")

    def visit_FLOAT(self, type_, **kw):
        # Any precision carried by the FLOAT type is not rendered.
        return 'DOUBLE'

    def visit_TEXT(self, type_, **kw):
        # TEXT with a (truthy) length keeps it; otherwise plain VARCHAR.
        length = type_.length
        return 'VARCHAR' if not length else 'VARCHAR({:d})'.format(length)
class PrestoDialect(default.DefaultDialect):
    """SQLAlchemy dialect for Presto, using ``pyhive.presto`` as the DB-API module.

    Reflection is built on ``SHOW SCHEMAS``, ``SHOW TABLES``, ``SHOW COLUMNS`` and
    ``information_schema.views``. The dialect reports no primary keys or
    foreign keys, and rollback is a no-op.
    """

    name = 'presto'
    driver = 'rest'
    # DB-API placeholder style for bound parameters: %(name)s.
    paramstyle = 'pyformat'
    preparer = PrestoIdentifierPreparer
    statement_compiler = PrestoCompiler
    supports_alter = False
    supports_pk_autoincrement = False
    supports_default_values = False
    supports_empty_insert = False
    supports_multivalues_insert = True
    supports_unicode_statements = True
    supports_unicode_binds = True
    supports_statement_cache = False
    returns_unicode_strings = True
    description_encoding = None
    supports_native_boolean = True
    type_compiler = PrestoTypeCompiler

    @classmethod
    def dbapi(cls):
        """Return the DB-API module (hook name used by SQLAlchemy before 2.0)."""
        return presto

    @classmethod
    def import_dbapi(cls):
        """Return the DB-API module (hook name used by SQLAlchemy 2.0+)."""
        return presto

    def create_connect_args(self, url):
        """Translate a SQLAlchemy URL into DB-API ``connect()`` keyword arguments.

        ``url.database`` is ``catalog`` or ``catalog/schema`` and defaults to the
        ``hive`` catalog; the port defaults to 8080. Query-string parameters are
        merged in, but ``catalog``/``schema`` taken from the database part are
        applied afterwards and therefore win.

        :returns: ``([], kwargs)`` -- no positional arguments.
        :raises ValueError: if the database part has more than two
            ``/``-separated components.
        """
        db_parts = (url.database or 'hive').split('/')
        kwargs = {
            'host': url.host,
            'port': url.port or 8080,
            'username': url.username,
            'password': url.password,
        }
        kwargs.update(url.query)
        if len(db_parts) == 1:
            kwargs['catalog'] = db_parts[0]
        elif len(db_parts) == 2:
            kwargs['catalog'] = db_parts[0]
            kwargs['schema'] = db_parts[1]
        else:
            raise ValueError("Unexpected database format {}".format(url.database))
        return [], kwargs

    def get_schema_names(self, connection, **kw):
        """Return the schema names listed by ``SHOW SCHEMAS``."""
        return [row.Schema for row in connection.execute(text('SHOW SCHEMAS'))]

    def _get_table_columns(self, connection, table_name, schema):
        """Run ``SHOW COLUMNS`` for the (optionally schema-qualified) table.

        :returns: the SQLAlchemy result of the query (one row per column).
        :raises sqlalchemy.exc.NoSuchTableError: if the error message says the
            table does not exist. Any other database error is re-raised unchanged.
        """
        full_table = self.identifier_preparer.quote_identifier(table_name)
        if schema:
            full_table = self.identifier_preparer.quote_identifier(schema) + '.' + full_table
        try:
            return connection.execute(text('SHOW COLUMNS FROM {}'.format(full_table)))
        except (presto.DatabaseError, exc.DatabaseError) as e:
            # Normally SQLAlchemy should wrap this exception in sqlalchemy.exc.DatabaseError, which
            # it successfully does in the Hive version. The difference with Presto is that this
            # error is raised when fetching the cursor's description rather than the initial execute
            # call. SQLAlchemy doesn't handle this. Thus, we catch the unwrapped
            # presto.DatabaseError here.
            # The error payload is either a dict carrying a 'message' key or a plain string.
            if e.args and isinstance(e.args[0], dict):
                msg = e.args[0].get('message')
            elif e.args and isinstance(e.args[0], str):
                msg = e.args[0]
            else:
                msg = None
            # Does the table exist?
            regex = r"Table\ \'.*{}\'\ does\ not\ exist".format(re.escape(table_name))
            if msg and re.search(regex, msg, re.IGNORECASE):
                raise exc.NoSuchTableError(table_name) from e
            raise

    def has_table(self, connection, table_name, schema=None, **kw):
        """Return True if ``SHOW COLUMNS`` succeeds for the table, False if it is missing."""
        try:
            result = self._get_table_columns(connection, table_name, schema)
        except exc.NoSuchTableError:
            return False
        # Only existence matters; release the cursor rather than leaving it open.
        result.close()
        return True

    def get_columns(self, connection, table_name, schema=None, **kw):
        """Reflect columns from ``SHOW COLUMNS``.

        Each entry has ``name``, ``type``, ``nullable`` and ``default`` keys.
        Parameterized types such as ``varchar(255)`` or ``decimal(10,2)`` are
        looked up by their base name. Unrecognized types emit a warning and are
        reflected as ``NullType``.
        """
        rows = self._get_table_columns(connection, table_name, schema)
        result = []
        for row in rows:
            # Drop inline type parameters, e.g. "varchar(255)" -> "varchar" and
            # "timestamp(3) with time zone" -> "timestamp with time zone".
            base_type = re.sub(r'\(.*\)', '', row.Type)
            try:
                coltype = _type_map[base_type]
            except KeyError:
                util.warn("Did not recognize type '%s' of column '%s'" % (row.Type, row.Column))
                coltype = types.NullType
            result.append({
                'name': row.Column,
                'type': coltype,
                # newer Presto no longer includes this column
                'nullable': getattr(row, 'Null', True),
                'default': None,
            })
        return result

    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
        # Presto has no support for foreign keys.
        return []

    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
        """Report an empty primary key (Presto has no support for primary keys).

        SQLAlchemy's reflection contract expects a dict with
        ``constrained_columns`` and ``name`` keys, not a list.
        """
        return {'constrained_columns': [], 'name': None}

    def get_indexes(self, connection, table_name, schema=None, **kw):
        """Report partition-key columns as one non-unique pseudo-index named ``partition``.

        Returns ``[]`` when no partition-key columns are found.
        """
        rows = self._get_table_columns(connection, table_name, schema)
        part_key = 'Partition Key'
        # From SQLAlchemy 1.4, name-based row access goes through Row._mapping.
        use_mapping = sqlalchemy_version >= 1.4
        col_names = []
        for row in rows:
            if use_mapping:
                row = row._mapping
            # Presto puts this information in one of 3 places depending on version
            # - a boolean column named "Partition Key"
            # - a string in the "Comment" column
            # - a string in the "Extra" column
            # "Comment"/"Extra" may be NULL, so coerce them to '' before string tests.
            comment = row['Comment'] or ''
            extra = (row['Extra'] if 'Extra' in row else None) or ''
            is_partition_key = (
                (part_key in row and row[part_key])
                or comment.startswith(part_key)
                or 'partition key' in extra
            )
            if is_partition_key:
                col_names.append(row['Column'])
        if col_names:
            return [{'name': 'partition', 'column_names': col_names, 'unique': False}]
        return []

    def _get_default_schema_name(self, connection):
        # Defer to DefaultDialect. The intended query was 'SELECT CURRENT_SCHEMA()',
        # but (see get_table_names) no CURRENT_SCHEMA function was found for Presto.
        return super()._get_default_schema_name(connection)

    def get_table_names(self, connection, schema=None, **kw):
        """Return table names from ``SHOW TABLES`` (optionally ``FROM <schema>``)."""
        query = 'SHOW TABLES'
        # N.B. This is incorrect, if no schema is provided, the current/default schema should be used
        # with a call to an overridden self._get_default_schema_name(connection), but I could not
        # see how to implement that as there is no CURRENT_SCHEMA function
        # default_schema = self._get_default_schema_name(connection)
        if schema:
            query += ' FROM ' + self.identifier_preparer.quote_identifier(schema)
        return [row.Table for row in connection.execute(text(query))]

    def get_view_names(self, connection, schema=None, **kw):
        """Return view names from ``information_schema.views``, filtered by schema if given."""
        if schema:
            view_name_query = """
            SELECT table_name
            FROM information_schema.views
            WHERE table_schema = :schema
            """
            query = text(view_name_query).bindparams(
                bindparam("schema", type_=types.Unicode)
            )
        else:
            # N.B. This is incorrect, if no schema is provided, the current/default schema should
            # be used with a call to self._get_default_schema_name(connection), but I could not
            # see how to implement that
            # default_schema = self._get_default_schema_name(connection)
            view_name_query = """
            SELECT table_name
            FROM information_schema.views
            """
            query = text(view_name_query)
        result = connection.execute(query, dict(schema=schema))
        return [row[0] for row in result]

    def do_rollback(self, dbapi_connection):
        # No transactions for Presto
        pass

    def _check_unicode_returns(self, connection, additional_tests=None):
        # requests gives back Unicode strings
        return True

    def _check_unicode_description(self, connection):
        # requests gives back Unicode strings
        return True