# blob: eda2d10928ba9539b0dd6e199c7adfcc85ed0eb6 [file] [log] [blame]
# encoding: utf-8
"""Shared DB-API test cases"""
from __future__ import absolute_import
from __future__ import unicode_literals
from builtins import object
from builtins import range
from future.utils import with_metaclass
from pyhive import exc
import abc
import contextlib
import functools
def with_cursor(fn):
    """Pass a cursor to the given function and handle cleanup.

    The decorated function is called as ``fn(self, cursor, *args, **kwargs)``.
    ``cursor`` comes from a new connection obtained via ``self.connect()``.
    The cursor and then the connection are closed when ``fn`` returns or
    raises.

    Returns whatever ``fn`` returns, so the decorator also works on helper
    methods that produce a value, not only on test methods.
    """
    @functools.wraps(fn)
    def wrapped_fn(self, *args, **kwargs):
        with contextlib.closing(self.connect()) as connection:
            with contextlib.closing(connection.cursor()) as cursor:
                # Propagate the result instead of silently discarding it.
                return fn(self, cursor, *args, **kwargs)
    return wrapped_fn
class DBAPITestCase(with_metaclass(abc.ABCMeta, object)):
    """Backend-independent DB-API tests shared by concrete test cases.

    Subclasses implement :meth:`connect`. The ``assert*`` helpers used below
    are not defined here, so subclasses presumably also inherit from
    ``unittest.TestCase`` (TODO confirm against the concrete test modules).

    The assertions imply two fixture tables:
    ``one_row``, whose ``SELECT *`` yields exactly the row ``(1,)``; and
    ``many_rows``, with a column ``a`` holding the integers 0..9999 and a
    column ``b`` in which at least ten rows equal ``'blah'``.
    """

    @abc.abstractmethod
    def connect(self):
        """Return a new DB-API connection to the backend under test."""
        raise NotImplementedError  # pragma: no cover

    @with_cursor
    def test_fetchone(self, cursor):
        cursor.execute('SELECT * FROM one_row')
        # rownumber starts at 0 and advances by one per fetched row.
        self.assertEqual(cursor.rownumber, 0)
        first = cursor.fetchone()
        self.assertEqual(first, (1,))
        self.assertEqual(cursor.rownumber, 1)
        # An exhausted result set yields None rather than raising.
        self.assertIsNone(cursor.fetchone())

    @with_cursor
    def test_fetchall(self, cursor):
        cursor.execute('SELECT * FROM one_row')
        self.assertEqual(cursor.fetchall(), [(1,)])
        cursor.execute('SELECT a FROM many_rows ORDER BY a')
        rows = cursor.fetchall()
        self.assertEqual(rows, [(n,) for n in range(10000)])

    @with_cursor
    def test_null_param(self, cursor):
        cursor.execute('SELECT %s FROM one_row', (None,))
        rows = cursor.fetchall()
        self.assertEqual(rows, [(None,)])

    @with_cursor
    def test_iterator(self, cursor):
        cursor.execute('SELECT * FROM one_row')
        self.assertEqual(list(cursor), [(1,)])
        # Once drained, the iterator keeps signalling exhaustion.
        with self.assertRaises(StopIteration):
            cursor.__next__()

    @with_cursor
    def test_description_initial(self, cursor):
        # No query has run yet, so there is no result description.
        self.assertIsNone(cursor.description)

    @with_cursor
    def test_description_failed(self, cursor):
        # The assertion stays inside the try: a backend may raise
        # DatabaseError from execute() or only when description is read.
        # In both cases the error is tolerated.
        try:
            cursor.execute('blah_blah')
            self.assertIsNone(cursor.description)
        except exc.DatabaseError:
            pass

    @with_cursor
    def test_bad_query(self, cursor):
        # The error may surface at execute() or at the first fetch.
        with self.assertRaises(exc.DatabaseError):
            cursor.execute('SELECT does_not_exist FROM this_really_does_not_exist')
            cursor.fetchone()

    @with_cursor
    def test_concurrent_execution(self, cursor):
        query = 'SELECT * FROM one_row'
        # Executing twice before fetching must still leave exactly one row.
        cursor.execute(query)
        cursor.execute(query)
        self.assertEqual(cursor.fetchall(), [(1,)])

    @with_cursor
    def test_executemany(self, cursor):
        for count in (1, 2):
            param_sets = [{'x': value} for value in range(1, count + 1)]
            cursor.executemany('SELECT %(x)d FROM one_row', param_sets)
            # Only the last parameter set's result is left to fetch.
            self.assertEqual(cursor.fetchall(), [(count,)])

    @with_cursor
    def test_executemany_none(self, cursor):
        # With no parameter sets, the query is never run: no description,
        # and fetching is a programming error.
        cursor.executemany('should_never_get_used', [])
        self.assertIsNone(cursor.description)
        with self.assertRaises(exc.ProgrammingError):
            cursor.fetchone()

    @with_cursor
    def test_fetchone_no_data(self, cursor):
        # Fetching before any execute() is a programming error.
        with self.assertRaises(exc.ProgrammingError):
            cursor.fetchone()

    @with_cursor
    def test_fetchmany(self, cursor):
        cursor.execute('SELECT * FROM many_rows LIMIT 15')
        self.assertEqual(cursor.fetchmany(0), [])
        first_batch = cursor.fetchmany(10)
        self.assertEqual(len(first_batch), 10)
        # Only 5 of the 15 rows remain for the second batch.
        second_batch = cursor.fetchmany(10)
        self.assertEqual(len(second_batch), 5)

    @with_cursor
    def test_arraysize(self, cursor):
        # fetchmany() with no argument uses cursor.arraysize as its size.
        cursor.arraysize = 5
        cursor.execute('SELECT * FROM many_rows LIMIT 20')
        batch = cursor.fetchmany()
        self.assertEqual(len(batch), 5)

    @with_cursor
    def test_polling_loop(self, cursor):
        """Try to trigger the polling logic in fetchone()."""
        # Private backend attribute; presumably the delay between status
        # polls, so zero makes the poll loop spin (TODO confirm per backend).
        cursor._poll_interval = 0
        cursor.execute('SELECT COUNT(*) FROM many_rows')
        self.assertEqual(cursor.fetchone(), (10000,))

    @with_cursor
    def test_no_params(self, cursor):
        # With no parameters supplied, '%(x)s' must not be treated as a
        # placeholder; the literal comes back verbatim.
        cursor.execute("SELECT '%(x)s' FROM one_row")
        self.assertEqual(cursor.fetchall(), [('%(x)s',)])

    def test_escape(self):
        """Verify that awkward characters survive escaping and SELECT back intact."""
        tricky = '''`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\n\r\t '''
        self.run_escape_case(tricky)

    @with_cursor
    def run_escape_case(self, cursor, bad_str):
        """Round-trip ``bad_str`` through positional and named parameters.

        Subclasses can call this with other strings that need escaping.
        """
        cases = (
            ('SELECT %d, %s FROM one_row', (1, bad_str)),
            ('SELECT %(a)d, %(b)s FROM one_row', {'a': 1, 'b': bad_str}),
        )
        for query, params in cases:
            cursor.execute(query, params)
            self.assertEqual(cursor.fetchall(), [(1, bad_str)])

    @with_cursor
    def test_invalid_params(self, cursor):
        # Neither a bare string nor a list of unsupported objects is an
        # acceptable parameters value.
        for bad_params in ('hi', [object]):
            with self.assertRaises(exc.ProgrammingError):
                cursor.execute('', bad_params)

    def test_open_close(self):
        # A connection can be closed without ever opening a cursor ...
        with contextlib.closing(self.connect()):
            pass
        # ... and a cursor can be closed before its connection.
        with contextlib.closing(self.connect()) as conn, contextlib.closing(conn.cursor()):
            pass

    @with_cursor
    def test_unicode(self, cursor):
        text = "王兢"
        cursor.execute('SELECT %s FROM one_row', (text,))
        self.assertEqual(cursor.fetchall(), [(text,)])

    @with_cursor
    def test_null(self, cursor):
        cursor.execute('SELECT null FROM many_rows')
        self.assertEqual(cursor.fetchall(), [(None,)] * 10000)
        cursor.execute('SELECT IF(a % 11 = 0, null, a) FROM many_rows')
        # Every multiple of 11 is replaced by NULL; other values pass through.
        expected = [(a if a % 11 else None,) for a in range(10000)]
        self.assertEqual(cursor.fetchall(), expected)

    @with_cursor
    def test_sql_where_in(self, cursor):
        # A list parameter expands into an SQL IN (...) list.
        cursor.execute('SELECT * FROM many_rows where a in %s', ([1, 2, 3],))
        matched = cursor.fetchall()
        self.assertEqual(len(matched), 3)
        cursor.execute('SELECT * FROM many_rows where b in %s limit 10',
                       (['blah'],))
        limited = cursor.fetchall()
        self.assertEqual(len(limited), 10)