| # encoding: utf-8 |
| """Shared DB-API test cases""" |
| |
| from __future__ import absolute_import |
| from __future__ import unicode_literals |
| from builtins import object |
| from builtins import range |
| from future.utils import with_metaclass |
| from pyhive import exc |
| import abc |
| import contextlib |
| import functools |
| |
| |
def with_cursor(fn):
    """Pass a cursor to the given function and handle cleanup.

    The decorated function is called as ``fn(self, cursor, *args, **kwargs)``.
    ``cursor`` comes from a fresh ``self.connect()`` connection. The cursor
    and the connection are both closed when ``fn`` finishes, whether it
    returns normally or raises. Whatever ``fn`` returns is passed back to
    the caller, so the decorator is transparent.
    """
    @functools.wraps(fn)
    def wrapped_fn(self, *args, **kwargs):
        # The nested ``with`` blocks close the cursor first, then its
        # connection (reverse order of acquisition).
        with contextlib.closing(self.connect()) as connection:
            with contextlib.closing(connection.cursor()) as cursor:
                return fn(self, cursor, *args, **kwargs)
    return wrapped_fn
| |
| |
class DBAPITestCase(with_metaclass(abc.ABCMeta, object)):
    """Driver-independent DB-API test cases.

    Subclasses must implement :meth:`connect`. They must also provide the
    ``unittest``-style assertion methods used here (``assertEqual``,
    ``assertIsNone``, ``assertRaises``); presumably they do this by also
    inheriting from ``unittest.TestCase``.

    Fixture tables the tests expect to exist:

    * ``one_row``: a single row with a single column whose value is ``1``.
    * ``many_rows``: ``_MANY_ROWS_COUNT`` rows.
      - Column ``a`` holds ``0 .. _MANY_ROWS_COUNT - 1``.
      - Column ``b`` has at least ten rows equal to ``'blah'``.
    """

    # Row count of the ``many_rows`` fixture. Override in a subclass whose
    # fixture differs.
    _MANY_ROWS_COUNT = 10000

    @abc.abstractmethod
    def connect(self):
        """Return a new connection to the database under test.

        Callers close the returned connection when they are done with it.
        """
        raise NotImplementedError  # pragma: no cover

    @with_cursor
    def test_fetchone(self, cursor):
        """fetchone() returns one row at a time and advances ``rownumber``."""
        cursor.execute('SELECT * FROM one_row')
        self.assertEqual(cursor.rownumber, 0)
        self.assertEqual(cursor.fetchone(), (1,))
        self.assertEqual(cursor.rownumber, 1)
        # An exhausted result set yields None rather than raising.
        self.assertIsNone(cursor.fetchone())

    @with_cursor
    def test_fetchall(self, cursor):
        """fetchall() returns every remaining row, for small and large results."""
        cursor.execute('SELECT * FROM one_row')
        self.assertEqual(cursor.fetchall(), [(1,)])
        cursor.execute('SELECT a FROM many_rows ORDER BY a')
        self.assertEqual(cursor.fetchall(),
                         [(i,) for i in range(self._MANY_ROWS_COUNT)])

    @with_cursor
    def test_null_param(self, cursor):
        """A None parameter is read back as None."""
        cursor.execute('SELECT %s FROM one_row', (None,))
        self.assertEqual(cursor.fetchall(), [(None,)])

    @with_cursor
    def test_iterator(self, cursor):
        """The cursor is iterable and raises StopIteration once drained."""
        cursor.execute('SELECT * FROM one_row')
        self.assertEqual(list(cursor), [(1,)])
        self.assertRaises(StopIteration, cursor.__next__)

    @with_cursor
    def test_description_initial(self, cursor):
        """``description`` is None before any statement has been executed."""
        self.assertIsNone(cursor.description)

    @with_cursor
    def test_description_failed(self, cursor):
        """An invalid statement must not leave a non-None ``description``.

        A DatabaseError is tolerated whether it comes from ``execute()`` or
        from reading ``description``. That is why the assertion deliberately
        sits inside the ``try``.
        """
        try:
            cursor.execute('blah_blah')
            self.assertIsNone(cursor.description)
        except exc.DatabaseError:
            pass

    @with_cursor
    def test_bad_query(self, cursor):
        """Querying a missing table raises DatabaseError by the first fetch."""
        with self.assertRaises(exc.DatabaseError):
            cursor.execute('SELECT does_not_exist FROM this_really_does_not_exist')
            # fetchone() is included so that an error reported only once
            # results are requested still satisfies the test.
            cursor.fetchone()

    @with_cursor
    def test_concurrent_execution(self, cursor):
        """Executing again before fetching leaves only the latest result set."""
        cursor.execute('SELECT * FROM one_row')
        cursor.execute('SELECT * FROM one_row')
        self.assertEqual(cursor.fetchall(), [(1,)])

    @with_cursor
    def test_executemany(self, cursor):
        """After executemany(), the final parameter set's result is fetchable."""
        for length in 1, 2:
            cursor.executemany(
                'SELECT %(x)d FROM one_row',
                [{'x': i} for i in range(1, length + 1)]
            )
            self.assertEqual(cursor.fetchall(), [(length,)])

    @with_cursor
    def test_executemany_none(self, cursor):
        """executemany() with no parameter sets never runs the operation."""
        cursor.executemany('should_never_get_used', [])
        self.assertIsNone(cursor.description)
        with self.assertRaises(exc.ProgrammingError):
            cursor.fetchone()

    @with_cursor
    def test_fetchone_no_data(self, cursor):
        """fetchone() before any execute() raises ProgrammingError."""
        with self.assertRaises(exc.ProgrammingError):
            cursor.fetchone()

    @with_cursor
    def test_fetchmany(self, cursor):
        """fetchmany(n) returns at most n rows; fetchmany(0) returns none."""
        cursor.execute('SELECT * FROM many_rows LIMIT 15')
        self.assertEqual(cursor.fetchmany(0), [])
        self.assertEqual(len(cursor.fetchmany(10)), 10)
        # Only 5 of the 15 rows remain.
        self.assertEqual(len(cursor.fetchmany(10)), 5)

    @with_cursor
    def test_arraysize(self, cursor):
        """fetchmany() without an argument fetches ``arraysize`` rows."""
        cursor.arraysize = 5
        cursor.execute('SELECT * FROM many_rows LIMIT 20')
        self.assertEqual(len(cursor.fetchmany()), 5)

    @with_cursor
    def test_polling_loop(self, cursor):
        """Try to trigger the polling logic in fetchone()"""
        # Private cursor knob. NOTE(review): 0 presumably removes the wait
        # between status polls; confirm against the cursor implementation.
        cursor._poll_interval = 0
        cursor.execute('SELECT COUNT(*) FROM many_rows')
        self.assertEqual(cursor.fetchone(), (self._MANY_ROWS_COUNT,))

    @with_cursor
    def test_no_params(self, cursor):
        """Without parameters, ``%`` sequences in the operation are left untouched."""
        cursor.execute("SELECT '%(x)s' FROM one_row")
        self.assertEqual(cursor.fetchall(), [('%(x)s',)])

    def test_escape(self):
        """Verify that funny characters can be escaped as strings and SELECTed back"""
        bad_str = '''`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\n\r\t '''
        self.run_escape_case(bad_str)

    @with_cursor
    def run_escape_case(self, cursor, bad_str):
        """Round-trip ``bad_str`` through positional and named parameters."""
        # Positional (``%s``-style) parameters.
        cursor.execute(
            'SELECT %d, %s FROM one_row',
            (1, bad_str)
        )
        self.assertEqual(cursor.fetchall(), [(1, bad_str,)])
        # Named (``%(name)s``-style) parameters.
        cursor.execute(
            'SELECT %(a)d, %(b)s FROM one_row',
            {'a': 1, 'b': bad_str}
        )
        self.assertEqual(cursor.fetchall(), [(1, bad_str)])

    @with_cursor
    def test_invalid_params(self, cursor):
        """A bare string, or an unescapable value, as parameters is rejected."""
        with self.assertRaises(exc.ProgrammingError):
            cursor.execute('', 'hi')
        with self.assertRaises(exc.ProgrammingError):
            cursor.execute('', [object])

    def test_open_close(self):
        """Connections and cursors can be opened and closed without being used."""
        with contextlib.closing(self.connect()):
            pass
        with contextlib.closing(self.connect()) as connection:
            with contextlib.closing(connection.cursor()):
                pass

    @with_cursor
    def test_unicode(self, cursor):
        """Non-ASCII text parameters are read back unchanged."""
        unicode_str = "王兢"
        cursor.execute(
            'SELECT %s FROM one_row',
            (unicode_str,)
        )
        self.assertEqual(cursor.fetchall(), [(unicode_str,)])

    @with_cursor
    def test_null(self, cursor):
        """NULLs come back as None, both alone and mixed with real values."""
        row_count = self._MANY_ROWS_COUNT
        cursor.execute('SELECT null FROM many_rows')
        self.assertEqual(cursor.fetchall(), [(None,)] * row_count)
        # No parameters are passed, so '%' here is SQL modulo, not a placeholder.
        cursor.execute('SELECT IF(a % 11 = 0, null, a) FROM many_rows')
        # NOTE(review): there is no ORDER BY, so this relies on rows coming
        # back in ascending ``a`` order.
        self.assertEqual(cursor.fetchall(),
                         [(None if a % 11 == 0 else a,) for a in range(row_count)])

    @with_cursor
    def test_sql_where_in(self, cursor):
        """A list parameter can be used as the right-hand side of ``IN``."""
        cursor.execute('SELECT * FROM many_rows where a in %s', ([1, 2, 3],))
        self.assertEqual(len(cursor.fetchall()), 3)
        cursor.execute('SELECT * FROM many_rows where b in %s limit 10',
                       (['blah'],))
        self.assertEqual(len(cursor.fetchall()), 10)