| from __future__ import absolute_import |
| from __future__ import unicode_literals |
| from builtins import str |
| from pyhive.sqlalchemy_hive import HiveDate |
| from pyhive.sqlalchemy_hive import HiveDecimal |
| from pyhive.sqlalchemy_hive import HiveTimestamp |
| from sqlalchemy.exc import NoSuchTableError, OperationalError |
| import pytest |
| from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase |
| from pyhive.tests.sqlalchemy_test_case import with_engine_connection |
| from sqlalchemy import types |
| from sqlalchemy.engine import create_engine |
| from sqlalchemy.schema import Column |
| from sqlalchemy.schema import MetaData |
| from sqlalchemy.schema import Table |
| from sqlalchemy.sql import text |
| import contextlib |
| import datetime |
| import decimal |
| import sqlalchemy.types |
| import unittest |
| import re |
| |
| sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1)) |
| |
| _ONE_ROW_COMPLEX_CONTENTS = [ |
| True, |
| 127, |
| 32767, |
| 2147483647, |
| 9223372036854775807, |
| 0.5, |
| 0.25, |
| 'a string', |
| datetime.datetime(1970, 1, 1), |
| b'123', |
| '[1,2]', |
| '{1:2,3:4}', |
| '{"a":1,"b":2}', |
| '{0:1}', |
| decimal.Decimal('0.1'), |
| ] |
| |
| |
| # [ |
| # ('boolean', 'boolean', ''), |
| # ('tinyint', 'tinyint', ''), |
| # ('smallint', 'smallint', ''), |
| # ('int', 'int', ''), |
| # ('bigint', 'bigint', ''), |
| # ('float', 'float', ''), |
| # ('double', 'double', ''), |
| # ('string', 'string', ''), |
| # ('timestamp', 'timestamp', ''), |
| # ('binary', 'binary', ''), |
| # ('array', 'array<int>', ''), |
| # ('map', 'map<int,int>', ''), |
| # ('struct', 'struct<a:int,b:int>', ''), |
| # ('union', 'uniontype<int,string>', ''), |
| # ('decimal', 'decimal(10,1)', '') |
| # ] |
| |
| |
| class TestSqlAlchemyHive(unittest.TestCase, SqlAlchemyTestCase): |
| def create_engine(self): |
| return create_engine('hive://localhost:10000/default') |
| |
| @with_engine_connection |
| def test_dotted_column_names(self, engine, connection): |
| """When Hive returns a dotted column name, both the non-dotted version should be available |
| as an attribute, and the dotted version should remain available as a key. |
| """ |
| row = connection.execute(text('SELECT * FROM one_row')).fetchone() |
| |
| if sqlalchemy_version >= 1.4: |
| row = row._mapping |
| |
| assert row.keys() == ['number_of_rows'] |
| assert 'number_of_rows' in row |
| assert row.number_of_rows == 1 |
| assert row['number_of_rows'] == 1 |
| assert getattr(row, 'one_row.number_of_rows') == 1 |
| assert row['one_row.number_of_rows'] == 1 |
| |
| @with_engine_connection |
| def test_dotted_column_names_raw(self, engine, connection): |
| """When Hive returns a dotted column name, and raw mode is on, nothing should be modified. |
| """ |
| row = connection.execution_options(hive_raw_colnames=True).execute(text('SELECT * FROM one_row')).fetchone() |
| |
| if sqlalchemy_version >= 1.4: |
| row = row._mapping |
| |
| assert row.keys() == ['one_row.number_of_rows'] |
| assert 'number_of_rows' not in row |
| assert getattr(row, 'one_row.number_of_rows') == 1 |
| assert row['one_row.number_of_rows'] == 1 |
| |
| @with_engine_connection |
| def test_reflect_no_such_table(self, engine, connection): |
| """reflecttable should throw an exception on an invalid table""" |
| self.assertRaises( |
| NoSuchTableError, |
| lambda: Table('this_does_not_exist', MetaData(), autoload_with=engine)) |
| self.assertRaises( |
| OperationalError, |
| lambda: Table('this_does_not_exist', MetaData(schema="also_does_not_exist"), autoload_with=engine)) |
| |
| @with_engine_connection |
| def test_reflect_select(self, engine, connection): |
| """reflecttable should be able to fill in a table from the name""" |
| one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine) |
| self.assertEqual(len(one_row_complex.c), 15) |
| self.assertIsInstance(one_row_complex.c.string, Column) |
| row = connection.execute(one_row_complex.select()).fetchone() |
| self.assertEqual(list(row), _ONE_ROW_COMPLEX_CONTENTS) |
| |
| # TODO some of these types could be filled in better |
| self.assertIsInstance(one_row_complex.c.boolean.type, types.Boolean) |
| self.assertIsInstance(one_row_complex.c.tinyint.type, types.Integer) |
| self.assertIsInstance(one_row_complex.c.smallint.type, types.Integer) |
| self.assertIsInstance(one_row_complex.c.int.type, types.Integer) |
| self.assertIsInstance(one_row_complex.c.bigint.type, types.BigInteger) |
| self.assertIsInstance(one_row_complex.c.float.type, types.Float) |
| self.assertIsInstance(one_row_complex.c.double.type, types.Float) |
| self.assertIsInstance(one_row_complex.c.string.type, types.String) |
| self.assertIsInstance(one_row_complex.c.timestamp.type, HiveTimestamp) |
| self.assertIsInstance(one_row_complex.c.binary.type, types.String) |
| self.assertIsInstance(one_row_complex.c.array.type, types.String) |
| self.assertIsInstance(one_row_complex.c.map.type, types.String) |
| self.assertIsInstance(one_row_complex.c.struct.type, types.String) |
| self.assertIsInstance(one_row_complex.c.union.type, types.String) |
| self.assertIsInstance(one_row_complex.c.decimal.type, HiveDecimal) |
| |
| @with_engine_connection |
| def test_type_map(self, engine, connection): |
| """sqlalchemy should use the dbapi_type_map to infer types from raw queries""" |
| row = connection.execute(text('SELECT * FROM one_row_complex')).fetchone() |
| self.assertListEqual(list(row), _ONE_ROW_COMPLEX_CONTENTS) |
| |
| @with_engine_connection |
| def test_reserved_words(self, engine, connection): |
| """Hive uses backticks""" |
| # Use keywords for the table/column name |
| fake_table = Table('select', MetaData(), Column('map', sqlalchemy.types.String)) |
| query = str(fake_table.select().where(fake_table.c.map == 'a').compile(engine)) |
| self.assertIn('`select`', query) |
| self.assertIn('`map`', query) |
| self.assertNotIn('"select"', query) |
| self.assertNotIn('"map"', query) |
| |
| def test_switch_database(self): |
| engine = create_engine('hive://localhost:10000/pyhive_test_database') |
| try: |
| with contextlib.closing(engine.connect()) as connection: |
| self.assertIn( |
| ('dummy_table',), |
| connection.execute(text('SHOW TABLES')).fetchall() |
| ) |
| connection.execute(text('USE default')) |
| self.assertIn( |
| ('one_row',), |
| connection.execute(text('SHOW TABLES')).fetchall() |
| ) |
| finally: |
| engine.dispose() |
| |
| @with_engine_connection |
| def test_lots_of_types(self, engine, connection): |
| # Presto doesn't have raw CREATE TABLE support, so we ony test hive |
| # take type list from sqlalchemy.types |
| types = [ |
| 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'TEXT', 'Text', 'FLOAT', |
| 'NUMERIC', 'DECIMAL', 'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB', |
| 'BOOLEAN', 'SMALLINT', 'DATE', 'TIME', |
| 'String', 'Integer', 'SmallInteger', |
| 'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'LargeBinary', |
| 'Boolean', 'Unicode', 'UnicodeText', |
| ] |
| cols = [] |
| for i, t in enumerate(types): |
| cols.append(Column(str(i), getattr(sqlalchemy.types, t))) |
| cols.append(Column('hive_date', HiveDate)) |
| cols.append(Column('hive_decimal', HiveDecimal)) |
| cols.append(Column('hive_timestamp', HiveTimestamp)) |
| table = Table('test_table', MetaData(schema='pyhive_test_database'), *cols,) |
| table.drop(checkfirst=True, bind=connection) |
| table.create(bind=connection) |
| connection.execute(text('SET mapred.job.tracker=local')) |
| connection.execute(text('USE pyhive_test_database')) |
| big_number = 10 ** 10 - 1 |
| connection.execute(text(""" |
| INSERT OVERWRITE TABLE test_table |
| SELECT |
| 1, "a", "a", "a", "a", "a", 0.1, |
| 0.1, 0.1, 0, 0, "a", "a", |
| false, 1, 0, 0, |
| "a", 1, 1, |
| 0.1, 0.1, 0, 0, 0, "a", |
| false, "a", "a", |
| 0, :big_number, 123 + 2000 |
| FROM default.one_row |
| """), {"big_number": big_number}) |
| row = connection.execute(text("select * from test_table")).fetchone() |
| self.assertEqual(row.hive_date, datetime.datetime(1970, 1, 1, 0, 0)) |
| self.assertEqual(row.hive_decimal, decimal.Decimal(big_number)) |
| self.assertEqual(row.hive_timestamp, datetime.datetime(1970, 1, 1, 0, 0, 2, 123000)) |
| table.drop(bind=connection) |
| |
| @with_engine_connection |
| def test_insert_select(self, engine, connection): |
| one_row = Table('one_row', MetaData(), autoload_with=engine) |
| table = Table('insert_test', MetaData(schema='pyhive_test_database'), |
| Column('a', sqlalchemy.types.Integer)) |
| table.drop(checkfirst=True, bind=connection) |
| table.create(bind=connection) |
| connection.execute(text('SET mapred.job.tracker=local')) |
| # NOTE(jing) I'm stuck on a version of Hive without INSERT ... VALUES |
| connection.execute(table.insert().from_select(['a'], one_row.select())) |
| |
| result = connection.execute(table.select()).fetchall() |
| expected = [(1,)] |
| self.assertEqual(result, expected) |
| |
| @with_engine_connection |
| def test_insert_values(self, engine, connection): |
| table = Table('insert_test', MetaData(schema='pyhive_test_database'), |
| Column('a', sqlalchemy.types.Integer),) |
| table.drop(checkfirst=True, bind=connection) |
| table.create(bind=connection) |
| connection.execute(table.insert().values([{'a': 1}, {'a': 2}])) |
| |
| result = connection.execute(table.select()).fetchall() |
| expected = [(1,), (2,)] |
| self.assertEqual(result, expected) |
| |
| @with_engine_connection |
| def test_supports_san_rowcount(self, engine, connection): |
| self.assertFalse(engine.dialect.supports_sane_rowcount_returning) |