# blob: dd60a7cdf1802504e3ccb7ce08b1d7e4dbe70e54 [file] [log] [blame]
"""Presto integration tests.
These rely on having a Presto+Hadoop cluster set up.
They also require the tables created by make_test_tables.sh.
"""
from __future__ import absolute_import
from __future__ import unicode_literals
import contextlib
import os
from decimal import Decimal
import requests
import pytest
from pyhive import exc
from pyhive import presto
from pyhive.tests.dbapi_test_case import DBAPITestCase
from pyhive.tests.dbapi_test_case import with_cursor
import mock
import unittest
import datetime
# Host and port handed to presto.connect() by the tests below.
# The port is kept as a string and passed through unchanged.
_HOST = 'localhost'
_PORT = '8080'
class TestPresto(unittest.TestCase, DBAPITestCase):
    """Integration tests for the ``pyhive.presto`` DB-API driver.

    Most tests talk to a live Presto server at ``_HOST:_PORT`` and read the
    ``one_row``, ``one_row_complex`` and ``many_rows`` tables. Tests decorated
    with ``with_cursor`` receive a ready-made cursor as their second argument.
    """

    # NOTE(review): presumably re-enables collection that the shared
    # DBAPITestCase base switches off -- confirm against dbapi_test_case.
    __test__ = True

    def connect(self):
        """Return a connection to the test server, tagged with this test's id as source."""
        return presto.connect(host=_HOST, port=_PORT, source=self.id())

    def _assert_raises_regex(self, expected_exception, expected_regex, func):
        """Assert that calling ``func`` raises ``expected_exception`` matching ``expected_regex``.

        ``assertRaisesRegexp`` is the only spelling on Python 2.7 but was
        removed in Python 3.12, where only ``assertRaisesRegex`` exists, so
        pick whichever one this interpreter provides.
        """
        checker = getattr(self, 'assertRaisesRegex', None) or self.assertRaisesRegexp
        checker(expected_exception, expected_regex, func)

    def test_bad_protocol(self):
        """An unknown protocol name is rejected with ValueError."""
        self._assert_raises_regex(
            ValueError, 'Protocol must be',
            lambda: presto.connect('localhost', protocol='nonsense').cursor())

    def test_escape_args(self):
        """Dates and datetimes are rendered as Presto date/timestamp literals."""
        escaper = presto.PrestoParamEscaper()
        self.assertEqual(
            escaper.escape_args((datetime.date(2020, 4, 17),)),
            ("date '2020-04-17'",))
        # The expected literal keeps only millisecond precision (123456 us -> .123).
        self.assertEqual(
            escaper.escape_args((datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),)),
            ("timestamp '2020-04-17 12:00:00.123'",))

    @with_cursor
    def test_description(self, cursor):
        """Executing a query fills in ``description`` and ``last_query_id``."""
        cursor.execute('SELECT 1 AS foobar FROM one_row')
        self.assertEqual(cursor.description, [('foobar', 'integer', None, None, None, None, True)])
        self.assertIsNotNone(cursor.last_query_id)

    @with_cursor
    def test_complex(self, cursor):
        """Check column metadata and fetched values for every column of one_row_complex."""
        cursor.execute('SELECT * FROM one_row_complex')
        # TODO Presto drops the union field
        if os.environ.get('PRESTO') == '0.147':
            tinyint_type = 'integer'
            smallint_type = 'integer'
            float_type = 'double'
        else:
            # some later version made these map to more specific types
            tinyint_type = 'tinyint'
            smallint_type = 'smallint'
            float_type = 'real'
        self.assertEqual(cursor.description, [
            ('boolean', 'boolean', None, None, None, None, True),
            ('tinyint', tinyint_type, None, None, None, None, True),
            ('smallint', smallint_type, None, None, None, None, True),
            ('int', 'integer', None, None, None, None, True),
            ('bigint', 'bigint', None, None, None, None, True),
            ('float', float_type, None, None, None, None, True),
            ('double', 'double', None, None, None, None, True),
            ('string', 'varchar', None, None, None, None, True),
            ('timestamp', 'timestamp', None, None, None, None, True),
            ('binary', 'varbinary', None, None, None, None, True),
            ('array', 'array(integer)', None, None, None, None, True),
            ('map', 'map(integer,integer)', None, None, None, None, True),
            ('struct', 'row(a integer,b integer)', None, None, None, None, True),
            # ('union', 'varchar', None, None, None, None, True),
            ('decimal', 'decimal(10,1)', None, None, None, None, True),
        ])
        rows = cursor.fetchall()
        expected = [(
            True,
            127,
            32767,
            2147483647,
            9223372036854775807,
            0.5,
            0.25,
            'a string',
            '1970-01-01 00:00:00.000',
            b'123',
            [1, 2],
            {"1": 2, "3": 4},  # Presto converts all keys to strings so that they're valid JSON
            [1, 2],  # struct is returned as a list of elements
            # '{0:1}',
            Decimal('0.1'),
        )]
        self.assertEqual(rows, expected)
        # catch unicode/str
        self.assertEqual(list(map(type, rows[0])), list(map(type, expected[0])))

    @with_cursor
    def test_cancel(self, cursor):
        """A query that is still in flight can be cancelled; poll() then returns None."""
        # Cross join of many_rows with itself, so the query is still active when polled.
        cursor.execute(
            "SELECT a.a * rand(), b.a * rand() "
            "FROM many_rows a "
            "CROSS JOIN many_rows b "
        )
        self.assertIn(cursor.poll()['stats']['state'], (
            'STARTING', 'PLANNING', 'RUNNING', 'WAITING_FOR_RESOURCES', 'QUEUED'))
        cursor.cancel()
        self.assertIsNotNone(cursor.last_query_id)
        self.assertIsNone(cursor.poll())

    def test_noops(self):
        """The DB-API specification requires that certain actions exist, even though they might not
        be applicable."""
        # Wohoo inflating coverage stats!
        with contextlib.closing(self.connect()) as connection:
            with contextlib.closing(connection.cursor()) as cursor:
                self.assertEqual(cursor.rowcount, -1)
                cursor.setinputsizes([])
                cursor.setoutputsize(1, 'blah')
                # Nothing has been executed yet, so there is no query id.
                self.assertIsNone(cursor.last_query_id)
                connection.commit()

    @mock.patch('requests.post')
    def test_non_200(self, post):
        """A non-200 response to the initial POST surfaces as OperationalError."""
        with contextlib.closing(self.connect()) as connection:
            with contextlib.closing(connection.cursor()) as cursor:
                post.return_value.status_code = 404
                self.assertRaises(exc.OperationalError, lambda: cursor.execute('show tables'))

    @with_cursor
    def test_poll(self, cursor):
        """poll() raises before execute, reports stats while running, and None once done."""
        self.assertRaises(presto.ProgrammingError, cursor.poll)
        cursor.execute('SELECT * FROM one_row')
        while True:
            status = cursor.poll()
            if status is None:
                break
            self.assertIn('stats', status)

        def fail(*args, **kwargs):
            self.fail("Should not need requests.get after done polling")  # pragma: no cover

        # Once polling has finished, fetchall() must not issue another GET.
        with mock.patch('requests.get', fail):
            self.assertEqual(cursor.fetchall(), [(1,)])

    @with_cursor
    def test_set_session(self, cursor):
        """SET/RESET SESSION take effect, and every statement gets its own query id."""
        self.assertIsNone(cursor.last_query_id)

        cursor.execute("SET SESSION query_max_run_time = '1234m'")
        self.assertIsNotNone(cursor.last_query_id)
        query_id = cursor.last_query_id
        cursor.fetchall()
        # Fetching results must not change the id of the current query.
        self.assertEqual(query_id, cursor.last_query_id)

        cursor.execute('SHOW SESSION')
        self.assertIsNotNone(cursor.last_query_id)
        self.assertNotEqual(query_id, cursor.last_query_id)
        query_id = cursor.last_query_id
        rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time']
        self.assertEqual(len(rows), 1)
        session_prop = rows[0]
        self.assertEqual(session_prop[1], '1234m')
        self.assertEqual(query_id, cursor.last_query_id)

        cursor.execute('RESET SESSION query_max_run_time')
        self.assertIsNotNone(cursor.last_query_id)
        self.assertNotEqual(query_id, cursor.last_query_id)
        query_id = cursor.last_query_id
        cursor.fetchall()
        self.assertEqual(query_id, cursor.last_query_id)

        cursor.execute('SHOW SESSION')
        self.assertIsNotNone(cursor.last_query_id)
        self.assertNotEqual(query_id, cursor.last_query_id)
        query_id = cursor.last_query_id
        rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time']
        self.assertEqual(len(rows), 1)
        session_prop = rows[0]
        self.assertNotEqual(session_prop[1], '1234m')
        self.assertEqual(query_id, cursor.last_query_id)

    def test_set_session_in_constructor(self):
        """Session properties passed to connect() are visible and can be reset."""
        conn = presto.connect(
            host=_HOST, port=_PORT, source=self.id(),
            session_props={'query_max_run_time': '1234m'}
        )
        with contextlib.closing(conn):
            with contextlib.closing(conn.cursor()) as cursor:
                cursor.execute('SHOW SESSION')
                rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time']
                self.assertEqual(len(rows), 1)
                session_prop = rows[0]
                self.assertEqual(session_prop[1], '1234m')

                cursor.execute('RESET SESSION query_max_run_time')
                cursor.fetchall()

                cursor.execute('SHOW SESSION')
                rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time']
                self.assertEqual(len(rows), 1)
                session_prop = rows[0]
                self.assertNotEqual(session_prop[1], '1234m')

    def test_invalid_protocol_config(self):
        """protocol should be https when passing password"""
        self._assert_raises_regex(
            ValueError, 'Protocol.*https.*password', lambda: presto.connect(
                host=_HOST, username='user', password='secret', protocol='http').cursor()
        )

    def test_invalid_password_and_kwargs(self):
        """password and requests_kwargs authentication are incompatible"""
        self._assert_raises_regex(
            ValueError, 'Cannot use both', lambda: presto.connect(
                host=_HOST, username='user', password='secret', protocol='https',
                requests_kwargs={'auth': requests.auth.HTTPBasicAuth('user', 'secret')}
            ).cursor()
        )

    def test_invalid_kwargs(self):
        """some kwargs are reserved"""
        self._assert_raises_regex(
            ValueError, 'Cannot override', lambda: presto.connect(
                host=_HOST, username='user', requests_kwargs={'url': 'test'}
            ).cursor()
        )

    @pytest.mark.skip(reason='This test requires a proxy server running on localhost:9999')
    def test_requests_kwargs(self):
        """requests_kwargs are used for the HTTP calls (checked via a proxy setting)."""
        connection = presto.connect(
            host=_HOST, port=_PORT, source=self.id(),
            requests_kwargs={'proxies': {'http': 'localhost:9999'}},
        )
        with contextlib.closing(connection):
            with contextlib.closing(connection.cursor()) as cursor:
                self.assertRaises(requests.exceptions.ProxyError,
                                  lambda: cursor.execute('SELECT * FROM one_row'))

    def test_requests_session(self):
        """A caller-supplied requests.Session can be used to run a query."""
        with requests.Session() as session:
            connection = presto.connect(
                host=_HOST, port=_PORT, source=self.id(), requests_session=session
            )
            with contextlib.closing(connection):
                with contextlib.closing(connection.cursor()) as cursor:
                    cursor.execute('SELECT * FROM one_row')
                    self.assertEqual(cursor.fetchall(), [(1,)])