blob: db89d57b5104db3dbd4335d081ec259d20dbcb84 [file] [log] [blame]
# coding: utf-8
from __future__ import absolute_import
from __future__ import unicode_literals
import abc
import re
import contextlib
import functools
import pytest
import sqlalchemy
from builtins import object
from future.utils import with_metaclass
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.schema import Index
from sqlalchemy.schema import MetaData
from sqlalchemy.schema import Table
from sqlalchemy.sql import expression, text
from sqlalchemy import String
# Installed SQLAlchemy "<major>.<minor>" as a float (e.g. "1.4.52" -> 1.4), used
# to choose between the pre-1.4 and the 1.4+/2.x APIs in the code below.
# Only the leading "<major>.<minor>" is required, so version strings without a
# patch component (e.g. "2.0") parse as well.
# NOTE(review): a float cannot order a two-digit minor (1.10 would read as 1.1);
# acceptable for SQLAlchemy's real 1.x/2.x series, but a tuple would be sturdier.
sqlalchemy_version = float(re.match(r"(\d+\.\d+)", sqlalchemy.__version__).group(1))
def with_engine_connection(fn):
    """Pass a connection to the given function and handle cleanup.

    The decorated method is invoked as
    ``fn(self, engine, connection, *args, **kwargs)``, where ``engine`` is the
    result of ``self.create_engine()`` and ``connection`` is the result of
    ``engine.connect()``. The connection is closed and the engine is disposed
    even if ``fn`` raises. Whatever ``fn`` returns is returned to the caller.
    """
    @functools.wraps(fn)
    def wrapped_fn(self, *args, **kwargs):
        engine = self.create_engine()
        try:
            # contextlib.closing calls connection.close() when the block exits,
            # whether normally or via an exception.
            with contextlib.closing(engine.connect()) as connection:
                return fn(self, engine, connection, *args, **kwargs)
        finally:
            # Dispose of the engine created for this single call.
            engine.dispose()
    return wrapped_fn
def reflect_table(engine, connection, table, include_columns, exclude_columns, resolve_fks):
    """Reflect ``table`` in place using whichever API the installed SQLAlchemy offers.

    Before 1.4 this goes through ``engine.dialect.reflecttable`` on ``connection``.
    From 1.4 on it uses an inspector created from ``engine``, so ``connection``
    is not used on that path.
    """
    options = dict(
        include_columns=include_columns,
        exclude_columns=exclude_columns,
        resolve_fks=resolve_fks,
    )
    if sqlalchemy_version < 1.4:
        # Legacy dialect-level reflection entry point.
        engine.dialect.reflecttable(connection, table, **options)
        return
    inspector = sqlalchemy.inspect(engine)
    inspector.reflect_table(table, **options)
class SqlAlchemyTestCase(with_metaclass(abc.ABCMeta, object)):
    """Dialect-independent SQLAlchemy tests shared by concrete backends.

    This is a mixin. A concrete subclass must supply ``create_engine()``, which
    ``with_engine_connection`` calls. It presumably also inherits a
    ``unittest.TestCase``-style base that provides ``assertEqual``,
    ``assertRaises`` and the other assertion helpers used here.

    The tests expect these fixture tables to already exist on the backend:
    ``one_row``, ``one_row_complex``, ``one_row_complex_null``, ``many_rows``
    and ``pyhive_test_database.dummy_table``.
    """

    @with_engine_connection
    def test_basic_query(self, engine, connection):
        """A plain textual SELECT yields exactly one single-column row."""
        result = connection.execute(text('SELECT * FROM one_row'))
        rows = result.fetchall()
        self.assertEqual(len(rows), 1)
        first = rows[0]
        # ``number_of_rows`` is the name of the fixture's only column.
        self.assertEqual(first.number_of_rows, 1)
        self.assertEqual(len(first), 1)

    @with_engine_connection
    def test_one_row_complex_null(self, engine, connection):
        """Every column of the all-NULL complex row comes back as None."""
        table = Table('one_row_complex_null', MetaData(), autoload_with=engine)
        rows = connection.execute(table.select()).fetchall()
        self.assertEqual(len(rows), 1)
        only_row = rows[0]
        self.assertEqual(list(only_row), [None] * len(only_row))

    @with_engine_connection
    def test_reflect_no_such_table(self, engine, connection):
        """Reflecting a missing table (in any schema) raises NoSuchTableError."""
        with self.assertRaises(NoSuchTableError):
            Table('this_does_not_exist', MetaData(), autoload_with=engine)
        with self.assertRaises(NoSuchTableError):
            Table('this_does_not_exist', MetaData(schema='also_does_not_exist'),
                  autoload_with=engine)

    @with_engine_connection
    def test_reflect_include_columns(self, engine, connection):
        """``include_columns`` limits reflection to the listed columns."""
        table = Table('one_row_complex', MetaData())
        reflect_table(engine, connection, table,
                      include_columns=['int'], exclude_columns=[], resolve_fks=True)
        self.assertEqual(len(table.c), 1)
        self.assertIsNotNone(table.c.int)
        # Columns that were not requested must be absent.
        self.assertRaises(AttributeError, getattr, table.c, 'tinyint')

    @with_engine_connection
    def test_reflect_with_schema(self, engine, connection):
        """Tables can be reflected from an explicitly named schema."""
        schema_metadata = MetaData(schema='pyhive_test_database')
        dummy = Table('dummy_table', schema_metadata, autoload_with=engine)
        self.assertEqual(len(dummy.c), 1)
        self.assertIsNotNone(dummy.c.a)

    @pytest.mark.filterwarnings('default:Omitting:sqlalchemy.exc.SAWarning')
    @with_engine_connection
    def test_reflect_partitions(self, engine, connection):
        """Reflection exposes the partition column ``b`` as an index named ``partition``."""
        def reflect_only(column_name):
            # Reflect ``many_rows`` restricted to a single column.
            partial = Table('many_rows', MetaData())
            reflect_table(engine, connection, partial, include_columns=[column_name],
                          exclude_columns=[], resolve_fks=True)
            return partial

        full = Table('many_rows', MetaData(), autoload_with=engine)
        self.assertEqual(len(full.c), 2)
        # The table's indexes must be rendered before the comparison Index is
        # built, because constructing that Index attaches it to the table.
        self.assertEqual(repr(full.indexes), repr({Index('partition', full.c.b)}))

        only_a = reflect_only('a')
        self.assertEqual(len(only_a.c), 1)
        self.assertFalse(only_a.c.a.index)
        self.assertFalse(only_a.indexes)

        only_b = reflect_only('b')
        self.assertEqual(len(only_b.c), 1)
        self.assertEqual(repr(only_b.indexes), repr({Index('partition', only_b.c.b)}))

    @with_engine_connection
    def test_unicode(self, engine, connection):
        """Non-ASCII bind values round-trip unchanged through the backend."""
        text_value = "中文"
        one_row = Table('one_row', MetaData())
        param = expression.bindparam("好", text_value, type_=String())
        if sqlalchemy_version < 1.4:
            query = sqlalchemy.select([param])
        else:
            query = sqlalchemy.select(param)
        returned = connection.execute(query.select_from(one_row)).scalar()
        self.assertEqual(returned, text_value)

    @with_engine_connection
    def test_reflect_schemas(self, engine, connection):
        """The inspector lists both the test schema and ``default``."""
        schema_names = sqlalchemy.inspect(engine).get_schema_names()
        for expected in ('pyhive_test_database', 'default'):
            self.assertIn(expected, schema_names)

    @with_engine_connection
    def test_get_table_names(self, engine, connection):
        """Table names are discoverable via MetaData.reflect and the inspector."""
        metadata = MetaData()
        metadata.reflect(bind=engine)
        for name in ('one_row', 'one_row_complex'):
            self.assertIn(name, metadata.tables)
        inspector = sqlalchemy.inspect(engine)
        self.assertIn('dummy_table', inspector.get_table_names(schema='pyhive_test_database'))

    @with_engine_connection
    def test_has_table(self, engine, connection):
        """Existence checks distinguish real tables from missing ones."""
        if sqlalchemy_version < 1.4:
            # Pre-1.4: bound MetaData with Table.exists().
            self.assertTrue(Table('one_row', MetaData(bind=engine)).exists())
            self.assertFalse(Table('this_table_does_not_exist', MetaData(bind=engine)).exists())
            return
        inspector = sqlalchemy.inspect(engine)
        self.assertTrue(inspector.has_table("one_row"))
        self.assertFalse(inspector.has_table("this_table_does_not_exist"))

    @with_engine_connection
    def test_char_length(self, engine, connection):
        """``char_length`` on the fixture string column matches Python's ``len``."""
        table = Table('one_row_complex', MetaData(), autoload_with=engine)
        length_expr = sqlalchemy.func.char_length(table.c.string)
        if sqlalchemy_version >= 1.4:
            query = sqlalchemy.select(length_expr)
        else:
            query = sqlalchemy.select([length_expr])
        measured = connection.execute(query).scalar()
        self.assertEqual(measured, len('a string'))