# blob: fe85e4f22a969b4660700711b4b609d930b488ae [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
import re
import secrets
import threading
import time
import google.protobuf.any_pb2 as any_pb2
import google.protobuf.wrappers_pb2 as wrappers_pb2
import pytest
def assert_detail(e):
    """Assert that ``e.details`` carries exactly the expected error details.

    Each entry of ``e.details`` is unpacked as a two-item pair; the first
    item is ignored and the second is parsed as a serialized
    ``google.protobuf.Any`` that is then unpacked into a ``StringValue``.
    The set of decoded strings must equal ``{"detail1", "detail2"}``.
    """

    def _decode(payload):
        # Any -> StringValue -> str
        wrapper = any_pb2.Any()
        wrapper.ParseFromString(payload)
        text = wrappers_pb2.StringValue()
        wrapper.Unpack(text)
        return text.value

    values = {_decode(payload) for _, payload in e.details}
    assert values == {"detail1", "detail2"}
def test_query_cancel(test_dbapi):
    """Cancel right after execute; the next fetch must report CANCELLED."""
    expected = re.escape(
        "CANCELLED: [FlightSQL] context canceled"
        " (Canceled; DoGet: endpoint 0: []). Vendor code: 1"
    )
    with test_dbapi.cursor() as cursor:
        cursor.execute("forever")
        cursor.adbc_cancel()
        with pytest.raises(test_dbapi.OperationalError, match=expected):
            cursor.fetchone()
def test_query_cancel_async(test_dbapi):
    """Cancel a query from another thread while the main thread fetches.

    A helper thread sleeps for two seconds and then calls ``adbc_cancel()``
    on the cursor; the main thread's ``fetchone()`` must raise
    ``OperationalError`` with the CANCELLED message.
    """
    with test_dbapi.cursor() as cur:
        cur.execute("forever")

        def _cancel():
            # Delay the cancellation so the main thread has (most likely)
            # already entered fetchone() by the time it fires.
            time.sleep(2)
            cur.adbc_cancel()

        t = threading.Thread(target=_cancel, daemon=True)
        t.start()
        try:
            with pytest.raises(
                test_dbapi.OperationalError,
                match=re.escape(
                    "CANCELLED: [FlightSQL] context canceled"
                    " (Canceled; DoGet: endpoint 0: []). Vendor code: 1"
                ),
            ):
                cur.fetchone()
        finally:
            # Always wait for the canceller before leaving the cursor's
            # ``with`` block, so it can never call adbc_cancel() on a cursor
            # that has already been closed (e.g. if the fetch failed early).
            t.join()
def test_query_error_fetch(test_dbapi):
    """Errors raised while fetching surface as ProgrammingError.

    Covers the plain error (message checked including the vendor code) and
    the variant whose exception additionally carries error details.
    """
    # Match more exactly to make sure there's not unexpected junk in the string
    plain_pattern = re.escape(
        "INVALID_ARGUMENT: [FlightSQL] expected error (DoGet)"
        " (InvalidArgument; DoGet: endpoint 0: []). Vendor code: 3"
    )
    detail_pattern = re.escape(
        "INVALID_ARGUMENT: [FlightSQL] expected error (DoGet)"
        " (InvalidArgument; DoGet: endpoint 0: [])"
    )
    with test_dbapi.cursor() as cursor:
        cursor.execute("error_do_get")
        with pytest.raises(test_dbapi.ProgrammingError, match=plain_pattern):
            cursor.fetch_arrow_table()

        cursor.execute("error_do_get_detail")
        with pytest.raises(
            test_dbapi.ProgrammingError, match=detail_pattern
        ) as exc_info:
            cursor.fetch_arrow_table()
        assert_detail(exc_info.value)
def test_query_error_vendor_code(test_dbapi):
    """The raised exception exposes the vendor code as ``vendor_code``."""
    expected = re.escape(
        "INVALID_ARGUMENT: [FlightSQL] expected error (DoGet)"
        " (InvalidArgument; DoGet: endpoint 0: []). Vendor code: 3"
    )
    with test_dbapi.cursor() as cursor:
        cursor.execute("error_do_get")
        with pytest.raises(
            test_dbapi.ProgrammingError, match=expected
        ) as exc_info:
            cursor.fetch_arrow_table()

        error = exc_info.value
        assert error.vendor_code == 3
def test_query_error_stream(test_dbapi):
    """Errors raised while reading the result stream surface on fetch.

    ``fetchone()`` is called twice inside each ``pytest.raises`` block, so
    the expected error may come from either call.
    """
    plain_pattern = re.escape(
        "INVALID_ARGUMENT: [FlightSQL] expected stream error (DoGet)"
        " (InvalidArgument; DoGet: endpoint 0: []). Vendor code: 3"
    )
    detail_pattern = re.escape(
        "INVALID_ARGUMENT: [FlightSQL] expected stream error (DoGet)"
        " (InvalidArgument; DoGet: endpoint 0: [])"
    )
    with test_dbapi.cursor() as cursor:
        cursor.execute("error_do_get_stream")
        with pytest.raises(test_dbapi.ProgrammingError, match=plain_pattern):
            cursor.fetchone()
            cursor.fetchone()

        cursor.execute("error_do_get_stream_detail")
        with pytest.raises(
            test_dbapi.ProgrammingError, match=detail_pattern
        ) as exc_info:
            cursor.fetchone()
            cursor.fetchone()
        assert_detail(exc_info.value)
def test_query_error_bind(test_dbapi):
    """Errors raised when executing with bound parameters.

    Both the plain and the detail-carrying variant are expected to raise
    OperationalError with the same message; the second one must also carry
    the expected error details.
    """
    expected = re.escape(
        "UNKNOWN: [FlightSQL] expected error (DoPut)"
        " (Unknown; ExecuteQuery)"
    )
    bind_values = (1, "a")
    with test_dbapi.cursor() as cursor:
        cursor.adbc_prepare("error_do_put")
        with pytest.raises(test_dbapi.OperationalError, match=expected):
            cursor.execute("error_do_put", parameters=bind_values)

        cursor.adbc_prepare("error_do_put_detail")
        with pytest.raises(
            test_dbapi.OperationalError, match=expected
        ) as exc_info:
            cursor.execute("error_do_put_detail", parameters=bind_values)
        assert_detail(exc_info.value)
def test_query_error_create_prepared_statement(test_dbapi):
    """Errors raised while preparing a statement surface from adbc_prepare.

    Covers the plain error (message checked including the vendor code) and
    the variant whose exception additionally carries error details.
    """
    plain_pattern = re.escape(
        "INVALID_ARGUMENT: [FlightSQL] expected error (DoAction)"
        " (InvalidArgument; Prepare). Vendor code: 3"
    )
    detail_pattern = re.escape(
        "INVALID_ARGUMENT: [FlightSQL] expected error (DoAction)"
        " (InvalidArgument; Prepare)"
    )
    with test_dbapi.cursor() as cursor:
        with pytest.raises(test_dbapi.ProgrammingError, match=plain_pattern):
            cursor.adbc_prepare("error_create_prepared_statement")

        with pytest.raises(
            test_dbapi.ProgrammingError, match=detail_pattern
        ) as exc_info:
            cursor.adbc_prepare("error_create_prepared_statement_detail")
        assert_detail(exc_info.value)
def test_query_error_getflightinfo(test_dbapi):
    """Errors raised while executing surface from execute/execute_partitions.

    Three cases are checked: plain ``execute``, ``execute`` with error
    details attached, and ``adbc_execute_partitions`` after a prepare.
    """
    execute_pattern = re.escape(
        "INVALID_ARGUMENT: [FlightSQL] expected error (GetFlightInfo)"
        " (InvalidArgument; ExecuteQuery). Vendor code: 3"
    )
    detail_pattern = re.escape(
        "INVALID_ARGUMENT: [FlightSQL] expected error"
        " (GetFlightInfo) (InvalidArgument; ExecuteQuery)"
    )
    partitions_pattern = re.escape(
        "INVALID_ARGUMENT: [FlightSQL] expected error (GetFlightInfo)"
        " (InvalidArgument; ExecutePartitions). Vendor code: 3"
    )
    with test_dbapi.cursor() as cursor:
        with pytest.raises(Exception, match=execute_pattern):
            cursor.execute("error_get_flight_info")

        with pytest.raises(Exception, match=detail_pattern) as exc_info:
            cursor.execute("error_get_flight_info_detail")
        assert_detail(exc_info.value)

        cursor.adbc_prepare("error_get_flight_info")
        with pytest.raises(Exception, match=partitions_pattern):
            cursor.adbc_execute_partitions("error_get_flight_info")
def test_stateless_prepared_statement(test_dbapi) -> None:
    """Prepare and then execute the same query with one bound row.

    The test passes as long as neither call raises.
    """
    query = "stateless_prepared_statement"
    bound_rows = [(1,)]
    with test_dbapi.cursor() as cursor:
        cursor.adbc_prepare(query)
        cursor.execute(query, parameters=bound_rows)
def test_header_propagation(test_dbapi) -> None:
    """Check that a connection-level call header reaches every RPC.

    For each group of operations (get-objects, a plain statement, a prepared
    statement, a transaction toggle, session options) the ``x-trace`` call
    header is set to a fresh random token while the operations run. The
    test then queries ``recorded_headers`` and asserts that, for every
    expected method name, a row equal to ``(method, header, token)`` exists.
    """
    header = "x-trace"
    option = f"adbc.flight.sql.rpc.call_header.{header}"

    @contextlib.contextmanager
    def _trace(value):
        """Set the call header to ``value`` for the duration of the block."""
        test_dbapi.adbc_connection.set_options(**{option: value})
        try:
            yield
        finally:
            # Reset the header even when the traced body raises, so a failure
            # here does not leave the header set on the shared connection.
            test_dbapi.adbc_connection.set_options(**{option: ""})

    getobjects = secrets.token_hex(16)
    with _trace(getobjects):
        with test_dbapi.adbc_get_objects():
            pass

    stmt = secrets.token_hex(16)
    with _trace(stmt):
        with test_dbapi.cursor() as cur:
            cur.execute("foo")
            cur.fetchall()

    prepared = secrets.token_hex(16)
    with _trace(prepared):
        with test_dbapi.cursor() as cur:
            cur.adbc_prepare("stateless_prepared_statement")
            cur.execute("stateless_prepared_statement", parameters=[(1,)])

    txn = secrets.token_hex(16)
    with _trace(txn):
        test_dbapi.adbc_connection.set_autocommit(False)
        test_dbapi.adbc_connection.set_autocommit(True)

    sess = secrets.token_hex(16)
    with _trace(sess):
        test_dbapi.adbc_connection.get_option("adbc.flight.sql.session.options")
        test_dbapi.adbc_connection.set_options(
            **{
                "adbc.flight.sql.session.option.foo": 2,
            }
        )

    with test_dbapi.cursor() as cur:
        cur.execute("recorded_headers")
        # Keep only the rows recorded for our header (second column).
        headers = [x for x in cur.fetchall() if x[1] == header]

    # (token, methods that must have been called with that token), checked
    # in the same order in which the operations were issued above.
    expectations = [
        (
            getobjects,
            [
                "GetFlightInfoCatalogs",
                "DoGetCatalogs",
                "GetFlightInfoDBSchemas",
                "DoGetDBSchemas",
                "GetFlightInfoTables",
                "DoGetTables",
            ],
        ),
        (
            stmt,
            [
                "CreatePreparedStatement",
                "ClosePreparedStatement",
                "GetFlightInfoPreparedStatement",
                "DoGetPreparedStatement",
            ],
        ),
        (
            prepared,
            [
                "CreatePreparedStatement",
                "ClosePreparedStatement",
                "GetFlightInfoPreparedStatement",
                "DoGetPreparedStatement",
                "DoPutPreparedStatementQuery",
            ],
        ),
        (
            txn,
            [
                "BeginTransaction",
                "EndTransaction",
            ],
        ),
        (
            sess,
            [
                "GetSessionOptions",
                "SetSessionOptions",
            ],
        ),
    ]
    for token, methods in expectations:
        for method in methods:
            assert (method, header, token) in headers