| # coding: utf-8 |
| from __future__ import absolute_import |
| from __future__ import unicode_literals |
| |
| import abc |
| import re |
| import contextlib |
| import functools |
| |
| import pytest |
| import sqlalchemy |
| from builtins import object |
| from future.utils import with_metaclass |
| from sqlalchemy.exc import NoSuchTableError |
| from sqlalchemy.schema import Index |
| from sqlalchemy.schema import MetaData |
| from sqlalchemy.schema import Table |
| from sqlalchemy.sql import expression, text |
| from sqlalchemy import String |
| |
# Leading "major.minor" part of the installed SQLAlchemy release, as a float
# (for example 1.4). The code below compares it against 1.4 to choose between
# the older and the newer SQLAlchemy call styles.
sqlalchemy_version = float(
    re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).groups()[0]
)
| |
def with_engine_connection(fn):
    """Pass an engine and an open connection to ``fn`` and handle cleanup.

    The engine is taken from ``self.create_engine()`` and the connection from
    ``engine.connect()``. ``fn`` is then invoked as
    ``fn(self, engine, connection, *args, **kwargs)``. The connection is
    closed and the engine disposed afterwards, whether ``fn`` returns
    normally or raises.

    :param fn: callable accepting ``(self, engine, connection, ...)``.
    :returns: a wrapper accepting ``(self, ...)`` that returns whatever
        ``fn`` returns; any exception raised by ``fn`` propagates unchanged.
    """
    @functools.wraps(fn)
    def wrapped_fn(self, *args, **kwargs):
        engine = self.create_engine()
        try:
            with contextlib.closing(engine.connect()) as connection:
                # Hand the result back so the decorator stays transparent to
                # callers; cleanup below still runs before control returns.
                return fn(self, engine, connection, *args, **kwargs)
        finally:
            engine.dispose()
    return wrapped_fn
| |
def reflect_table(engine, connection, table, include_columns, exclude_columns, resolve_fks):
    """Reflect ``table`` through the API matching the installed SQLAlchemy.

    When ``sqlalchemy_version`` is at least 1.4 the reflection goes through
    ``sqlalchemy.inspect(engine).reflect_table``; otherwise it goes through
    ``engine.dialect.reflecttable`` using ``connection``. The three filtering
    options are forwarded unchanged as keyword arguments in both cases.
    """
    options = dict(
        include_columns=include_columns,
        exclude_columns=exclude_columns,
        resolve_fks=resolve_fks,
    )

    if sqlalchemy_version < 1.4:
        # Older releases: the dialect performs reflection on a connection.
        engine.dialect.reflecttable(connection, table, **options)
        return

    inspector = sqlalchemy.inspect(engine)
    inspector.reflect_table(table, **options)
| |
| |
class SqlAlchemyTestCase(with_metaclass(abc.ABCMeta, object)):
    """Shared SQLAlchemy tests run through ``with_engine_connection``.

    Each test receives an engine and an open connection from the decorator,
    which calls ``self.create_engine()``. Neither ``create_engine`` nor the
    ``assert*`` helpers are defined in the part of the class shown here;
    presumably a concrete subclass supplies them (e.g. by also inheriting
    from ``unittest.TestCase``) -- confirm against the subclasses.
    """

    @with_engine_connection
    def test_basic_query(self, engine, connection):
        """``SELECT * FROM one_row`` yields one row holding one column equal to 1."""
        fetched = connection.execute(text('SELECT * FROM one_row')).fetchall()
        self.assertEqual(len(fetched), 1)
        only_row = fetched[0]
        self.assertEqual(only_row.number_of_rows, 1)  # number_of_rows is the column name
        self.assertEqual(len(only_row), 1)

    @with_engine_connection
    def test_one_row_complex_null(self, engine, connection):
        """Every column of the single ``one_row_complex_null`` row comes back as None."""
        null_table = Table('one_row_complex_null', MetaData(), autoload_with=engine)
        fetched = connection.execute(null_table.select()).fetchall()
        self.assertEqual(len(fetched), 1)
        only_row = fetched[0]
        self.assertEqual(list(only_row), [None] * len(only_row))

    @with_engine_connection
    def test_reflect_no_such_table(self, engine, connection):
        """reflecttable should throw an exception on an invalid table"""
        # Without a schema on the MetaData.
        self.assertRaises(
            NoSuchTableError,
            Table, 'this_does_not_exist', MetaData(), autoload_with=engine)
        # With a schema that does not exist either.
        self.assertRaises(
            NoSuchTableError,
            Table, 'this_does_not_exist', MetaData(schema='also_does_not_exist'),
            autoload_with=engine)

    @with_engine_connection
    def test_reflect_include_columns(self, engine, connection):
        """When passed include_columns, reflecttable should filter out other columns"""
        table = Table('one_row_complex', MetaData())
        reflect_table(
            engine, connection, table,
            include_columns=['int'], exclude_columns=[], resolve_fks=True,
        )

        self.assertEqual(len(table.c), 1)
        self.assertIsNotNone(table.c.int)
        # A column that was not requested must not be reachable as an attribute.
        self.assertRaises(AttributeError, getattr, table.c, 'tinyint')

    @with_engine_connection
    def test_reflect_with_schema(self, engine, connection):
        """A table reflected through a schema-qualified MetaData exposes its one column ``a``."""
        schema_metadata = MetaData(schema='pyhive_test_database')
        dummy_table = Table('dummy_table', schema_metadata, autoload_with=engine)
        self.assertEqual(len(dummy_table.c), 1)
        self.assertIsNotNone(dummy_table.c.a)

    @pytest.mark.filterwarnings('default:Omitting:sqlalchemy.exc.SAWarning')
    @with_engine_connection
    def test_reflect_partitions(self, engine, connection):
        """reflecttable should get the partition column as an index"""
        # Full reflection: both columns, and an index on the partition column.
        full = Table('many_rows', MetaData(), autoload_with=engine)
        self.assertEqual(len(full.c), 2)
        # Take the reflected repr BEFORE building the expected Index, exactly
        # as the original comparison evaluated it.
        # NOTE(review): SQLAlchemy appears to attach a new Index to the table
        # owning its columns, which would alter ``full.indexes`` -- confirm.
        reflected = repr(full.indexes)
        self.assertEqual(reflected, repr({Index('partition', full.c.b)}))

        # Only the non-partition column: no index expected at all.
        only_a = Table('many_rows', MetaData())
        reflect_table(
            engine, connection, only_a,
            include_columns=['a'], exclude_columns=[], resolve_fks=True,
        )
        self.assertEqual(len(only_a.c), 1)
        self.assertFalse(only_a.c.a.index)
        self.assertFalse(only_a.indexes)

        # Only the partition column: the index is still reported.
        only_b = Table('many_rows', MetaData())
        reflect_table(
            engine, connection, only_b,
            include_columns=['b'], exclude_columns=[], resolve_fks=True,
        )
        self.assertEqual(len(only_b.c), 1)
        reflected = repr(only_b.indexes)
        self.assertEqual(reflected, repr({Index('partition', only_b.c.b)}))

    @with_engine_connection
    def test_unicode(self, engine, connection):
        """Verify that unicode strings make it through SQLAlchemy and the backend"""
        unicode_str = "中文"
        one_row = Table('one_row', MetaData())
        param = expression.bindparam("好", unicode_str, type_=String())

        # select() takes positional columns from 1.4 on, a list before that.
        if sqlalchemy_version >= 1.4:
            query = sqlalchemy.select(param)
        else:
            query = sqlalchemy.select([param])

        returned_str = connection.execute(query.select_from(one_row)).scalar()
        self.assertEqual(returned_str, unicode_str)

    @with_engine_connection
    def test_reflect_schemas(self, engine, connection):
        """The inspector lists both ``pyhive_test_database`` and ``default`` as schemas."""
        schema_names = sqlalchemy.inspect(engine).get_schema_names()
        for expected_schema in ('pyhive_test_database', 'default'):
            self.assertIn(expected_schema, schema_names)

    @with_engine_connection
    def test_get_table_names(self, engine, connection):
        """Reflected metadata and the inspector both report the expected table names."""
        metadata = MetaData()
        metadata.reflect(bind=engine)
        for table_name in ('one_row', 'one_row_complex'):
            self.assertIn(table_name, metadata.tables)

        inspector = sqlalchemy.inspect(engine)
        schema_tables = inspector.get_table_names(schema='pyhive_test_database')
        self.assertIn('dummy_table', schema_tables)

    @with_engine_connection
    def test_has_table(self, engine, connection):
        """Table existence checks are true for ``one_row`` and false for a missing table."""
        if sqlalchemy_version >= 1.4:
            table_exists = sqlalchemy.inspect(engine).has_table
        else:
            def table_exists(name):
                # Pre-1.4 style: ask a Table bound to the engine.
                return Table(name, MetaData(bind=engine)).exists()

        self.assertTrue(table_exists("one_row"))
        self.assertFalse(table_exists("this_table_does_not_exist"))

    @with_engine_connection
    def test_char_length(self, engine, connection):
        """``char_length`` of ``one_row_complex.string`` equals ``len('a string')``."""
        one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine)
        length_expr = sqlalchemy.func.char_length(one_row_complex.c.string)

        # select() takes positional columns from 1.4 on, a list before that.
        if sqlalchemy_version >= 1.4:
            query = sqlalchemy.select(length_expr)
        else:
            query = sqlalchemy.select([length_expr])

        self.assertEqual(connection.execute(query).scalar(), len('a string'))