# blob: a29a661a23fdc3c9e9ec38e3f4a9dc45c86d9fbe [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pandas
import pyarrow
import pytest
from pandas.testing import assert_frame_equal
from adbc_driver_manager import dbapi
@pytest.fixture
def sqlite():
    """Yield a DB-API connection opened with the ``adbc_driver_sqlite`` driver.

    The connection is used as a context manager, so the ``with`` block is
    exited once the consuming test has finished.
    """
    with dbapi.connect(driver="adbc_driver_sqlite") as connection:
        yield connection
def test_type_objects():
    """DB-API type objects compare equal to the Arrow types they stand for."""
    int64_type = pyarrow.int64()
    string_type = pyarrow.string()

    # Equality is exercised with the operands in both orders.
    assert dbapi.NUMBER == int64_type
    assert int64_type == dbapi.NUMBER
    assert dbapi.STRING == string_type
    assert string_type == dbapi.STRING

    # Distinct type objects must not compare equal to one another ...
    assert dbapi.STRING != dbapi.NUMBER
    assert dbapi.NUMBER != dbapi.DATETIME
    # ... while NUMBER and ROWID are expected to be equal.
    assert dbapi.NUMBER == dbapi.ROWID
@pytest.mark.sqlite
def test_attrs(sqlite):
    """Check connection exception attributes and a fresh cursor's attributes."""
    exception_names = (
        "Warning",
        "Error",
        "InterfaceError",
        "DatabaseError",
        "DataError",
        "OperationalError",
        "IntegrityError",
        "InternalError",
        "ProgrammingError",
        "NotSupportedError",
    )
    # Each exception class must be reachable from the connection object and
    # equal to the one exported by the ``dbapi`` module.
    for name in exception_names:
        assert getattr(sqlite, name) == getattr(dbapi, name)

    # Attribute values expected on a cursor before anything is executed.
    with sqlite.cursor() as cursor:
        assert cursor.arraysize == 1
        assert cursor.connection is sqlite
        assert cursor.description is None
        assert cursor.rowcount == -1
@pytest.mark.sqlite
def test_info(sqlite):
    """``adbc_get_info`` returns the expected keys and driver/vendor names."""
    expected_keys = {
        "driver_arrow_version",
        "driver_name",
        "driver_version",
        "vendor_name",
        "vendor_version",
    }

    info = sqlite.adbc_get_info()

    assert set(info.keys()) == expected_keys
    assert info["driver_name"] == "ADBC SQLite Driver"
    assert info["vendor_name"] == "SQLite"
@pytest.mark.sqlite
def test_get_underlying(sqlite):
    """The low-level ADBC handles on connection and cursor are truthy."""
    for handle_name in ("adbc_database", "adbc_connection"):
        assert getattr(sqlite, handle_name)

    with sqlite.cursor() as cursor:
        assert cursor.adbc_statement
@pytest.mark.sqlite
def test_clone(sqlite):
    """Data committed via ``adbc_clone`` is readable from the original."""
    with sqlite.adbc_clone() as cloned:
        # Create and populate the table through the cloned connection.
        with cloned.cursor() as writer:
            writer.execute("CREATE TABLE temporary (ints)")
            writer.execute("INSERT INTO temporary VALUES (1)")
        cloned.commit()

        # Read it back through the connection the clone was made from.
        with sqlite.cursor() as reader:
            reader.execute("SELECT * FROM temporary")
            first_row = reader.fetchone()
            assert first_row == (1,)
@pytest.mark.sqlite
def test_get_objects(sqlite):
    """``adbc_get_objects`` describes a newly created single-column table."""
    with sqlite.cursor() as cursor:
        cursor.execute("CREATE TABLE temporary (ints)")
        cursor.execute("INSERT INTO temporary VALUES (1)")

    reader = sqlite.adbc_get_objects(table_name_filter="temporary")
    catalogs = reader.read_all().to_pylist()

    # Exactly one catalog is expected, named "main".
    assert len(catalogs) == 1
    catalog = catalogs[0]
    assert catalog["catalog_name"] == "main"

    # It holds a single schema whose name is None.
    db_schemas = catalog["catalog_db_schemas"]
    assert len(db_schemas) == 1
    db_schema = db_schemas[0]
    assert db_schema["db_schema_name"] is None

    # That schema holds only the table matched by the filter.
    db_tables = db_schema["db_schema_tables"]
    assert len(db_tables) == 1
    table = db_tables[0]
    assert table["table_name"] == "temporary"
    assert table["table_type"] == "table"

    first_column = table["table_columns"][0]
    assert first_column["column_name"] == "ints"
    assert first_column["ordinal_position"] == 1
    assert table["table_constraints"] == []
@pytest.mark.sqlite
def test_get_table_schema(sqlite):
    """``adbc_get_table_schema`` reports an int64 ``ints`` column."""
    with sqlite.cursor() as cursor:
        cursor.execute("CREATE TABLE temporary (ints)")
        cursor.execute("INSERT INTO temporary VALUES (1)")

    actual_schema = sqlite.adbc_get_table_schema("temporary")
    expected_schema = pyarrow.schema([("ints", pyarrow.int64())])
    assert actual_schema == expected_schema
@pytest.mark.sqlite
def test_get_table_types(sqlite):
    """``adbc_get_table_types`` lists exactly ``table`` and ``view``."""
    table_types = sqlite.adbc_get_table_types()
    assert table_types == ["table", "view"]
@pytest.mark.parametrize(
    "data",
    [
        # Each entry is a factory so that every ingest call receives a
        # freshly built Arrow object.
        lambda: pyarrow.record_batch([[1, 2], ["foo", ""]], names=["ints", "strs"]),
        lambda: pyarrow.table([[1, 2], ["foo", ""]], names=["ints", "strs"]),
        lambda: pyarrow.table(
            [[1, 2], ["foo", ""]], names=["ints", "strs"]
        ).to_reader(),
    ],
)
@pytest.mark.sqlite
def test_ingest(data, sqlite):
    """Bulk ingest in default and append modes, including the error cases."""
    with sqlite.cursor() as cursor:
        # First ingest with the default mode succeeds ...
        cursor.adbc_ingest("bulk_ingest", data())

        # ... but repeating it against the same table name must fail.
        with pytest.raises(dbapi.Error):
            cursor.adbc_ingest("bulk_ingest", data())

        # Appending to the existing table works.
        cursor.adbc_ingest("bulk_ingest", data(), mode="append")

        # Appending to a table that was never created must fail.
        with pytest.raises(dbapi.Error):
            cursor.adbc_ingest("nonexistent", data(), mode="append")

        # An unknown mode is rejected with ValueError.
        with pytest.raises(ValueError):
            cursor.adbc_ingest("bulk_ingest", data(), mode="invalid")

    # The two successful ingests should each have contributed both rows.
    expected_rows = [(1, "foo"), (2, "")] * 2
    with sqlite.cursor() as cursor:
        cursor.execute("SELECT * FROM bulk_ingest")
        for expected_row in expected_rows:
            assert cursor.fetchone() == expected_row
@pytest.mark.sqlite
def test_partitions(sqlite):
    """``adbc_execute_partitions`` is expected to raise NotSupportedError."""
    with pytest.raises(dbapi.NotSupportedError), sqlite.cursor() as cursor:
        cursor.adbc_execute_partitions("SELECT 1")
@pytest.mark.sqlite
def test_query_fetch_py(sqlite):
    """Fetch one result row via fetchone, fetchmany, fetchall and iteration."""
    query = 'SELECT 1, "foo", 2.0'
    expected_row = (1, "foo", 2.0)
    # The five description fields after name and type are expected to be None.
    unset_fields = (None,) * 5

    with sqlite.cursor() as cursor:
        cursor.execute(query)
        assert cursor.description == [
            ("1", dbapi.NUMBER) + unset_fields,
            ('"foo"', dbapi.STRING) + unset_fields,
            ("2.0", dbapi.NUMBER) + unset_fields,
        ]

        # fetchone advances rownumber and returns None once exhausted.
        assert cursor.rownumber == 0
        assert cursor.fetchone() == expected_row
        assert cursor.rownumber == 1
        assert cursor.fetchone() is None

        # fetchmany and fetchall both return an empty list once exhausted.
        cursor.execute(query)
        assert cursor.fetchmany() == [expected_row]
        assert cursor.fetchmany() == []

        cursor.execute(query)
        assert cursor.fetchall() == [expected_row]
        assert cursor.fetchall() == []

        # The cursor can also be consumed as an iterable.
        cursor.execute(query)
        assert list(cursor) == [expected_row]
@pytest.mark.sqlite
def test_query_fetch_arrow(sqlite):
    """``fetch_arrow_table`` returns the query result as an Arrow table."""
    with sqlite.cursor() as cursor:
        cursor.execute('SELECT 1, "foo", 2.0')
        actual = cursor.fetch_arrow_table()
        # Expected column names are the select-list expressions as written.
        expected = pyarrow.table(
            {
                "1": [1],
                '"foo"': ["foo"],
                "2.0": [2.0],
            }
        )
        assert actual == expected
@pytest.mark.sqlite
def test_query_fetch_df(sqlite):
    """``fetch_df`` returns the query result as a pandas DataFrame."""
    with sqlite.cursor() as cursor:
        cursor.execute('SELECT 1, "foo", 2.0')
        actual = cursor.fetch_df()
        expected = pandas.DataFrame(
            {
                "1": [1],
                '"foo"': ["foo"],
                "2.0": [2.0],
            }
        )
        assert_frame_equal(actual, expected)
@pytest.mark.sqlite
@pytest.mark.parametrize(
    "parameters",
    [
        # The same two bind values as a tuple, a record batch and a table.
        (1.0, 2),
        pyarrow.record_batch([[1.0], [2]], names=["float", "int"]),
        pyarrow.table([[1.0], [2]], names=["float", "int"]),
    ],
)
def test_execute_parameters(sqlite, parameters):
    """``execute`` binds parameters given as Python or Arrow values."""
    with sqlite.cursor() as cursor:
        cursor.execute("SELECT ? + 1, ?", parameters)
        rows = cursor.fetchall()
        assert rows == [(2.0, 2)]
@pytest.mark.sqlite
@pytest.mark.parametrize(
    "parameters",
    [
        # The same two parameter rows in each supported container.
        [(1, "a"), (3, None)],
        pyarrow.record_batch([[1, 3], ["a", None]], names=["float", "str"]),
        pyarrow.table([[1, 3], ["a", None]], names=["float", "str"]),
        pyarrow.table([[1, 3], ["a", None]], names=["float", "str"]).to_batches()[0],
    ],
)
def test_executemany_parameters(sqlite, parameters):
    """``executemany`` inserts one row per parameter set."""
    with sqlite.cursor() as cursor:
        # Start from a clean table on every parametrized run.
        cursor.execute("DROP TABLE IF EXISTS executemany")
        cursor.execute("CREATE TABLE executemany (int, str)")

        cursor.executemany("INSERT INTO executemany VALUES (? * 2, ?)", parameters)

        cursor.execute("SELECT * FROM executemany ORDER BY int ASC")
        inserted = cursor.fetchall()
        assert inserted == [(2, "a"), (6, None)]
@pytest.mark.sqlite
def test_query_substrait(sqlite):
    """Executing a ``bytes`` query is expected to raise NotSupportedError."""
    with sqlite.cursor() as cursor, pytest.raises(dbapi.NotSupportedError):
        cursor.execute(b"Substrait plan")
@pytest.mark.sqlite
def test_executemany(sqlite):
    """``executemany`` inserts all rows and iteration tracks ``rownumber``."""
    rows = [
        (1, 2),
        (3, 4),
        (5, 6),
    ]
    with sqlite.cursor() as cursor:
        cursor.execute("CREATE TABLE foo (a, b)")
        cursor.executemany("INSERT INTO foo VALUES (?, ?)", rows)

        cursor.execute("SELECT COUNT(*) FROM foo")
        assert cursor.fetchone() == (3,)

        cursor.execute("SELECT * FROM foo ORDER BY a ASC")
        # Before each ``next`` call, rownumber equals the rows consumed so far.
        for consumed, expected_row in enumerate(rows):
            assert cursor.rownumber == consumed
            assert next(cursor) == expected_row
@pytest.mark.sqlite
def test_fetch_record_batch(sqlite):
    """Rows written with ``executemany`` come back via ``fetch_record_batch``.

    The query orders by ``a`` because the result is compared positionally
    with ``dataset``; without ``ORDER BY`` the row order is not guaranteed.
    ``dataset`` is already ascending in ``a``, so the expected value is the
    list itself.
    """
    dataset = [
        [1, 2],
        [3, 4],
        [5, 6],
        [7, 8],
        [9, 10],
    ]
    with sqlite.cursor() as cur:
        cur.execute("CREATE TABLE foo (a, b)")
        cur.executemany(
            "INSERT INTO foo VALUES (?, ?)",
            dataset,
        )
        cur.execute("SELECT * FROM foo ORDER BY a ASC")
        reader = cur.fetch_record_batch()
        # Convert through pandas to plain nested lists for the comparison.
        assert reader.read_pandas().values.tolist() == dataset
@pytest.mark.sqlite
def test_fetch_empty(sqlite):
    """``fetchall`` on a query over an empty table returns an empty list."""
    with sqlite.cursor() as cursor:
        cursor.execute("CREATE TABLE foo (bar)")
        cursor.execute("SELECT * FROM foo")
        rows = cursor.fetchall()
        assert rows == []
@pytest.mark.sqlite
def test_prepare(sqlite):
    """Check the schemas ``adbc_prepare`` returns, then execute with a bind."""
    with sqlite.cursor() as cursor:
        # A statement without placeholders is expected to give an empty schema.
        no_placeholder_schema = cursor.adbc_prepare("SELECT 1")
        assert no_placeholder_schema == pyarrow.schema([])

        # One placeholder: a single field named "0" of null type is expected.
        one_placeholder_schema = cursor.adbc_prepare("SELECT 1 + ?")
        assert one_placeholder_schema == pyarrow.schema([("0", "null")])

        cursor.execute("SELECT 1 + ?", (1,))
        result_row = cursor.fetchone()
        assert result_row == (2,)
@pytest.mark.sqlite
def test_close_warning(sqlite):
    """Dropping an unclosed cursor or connection emits a ResourceWarning.

    The dots in the qualified class names are escaped so the ``match``
    patterns require the literal dotted path instead of accepting any
    character in those positions.
    """
    # NOTE(review): this relies on ``del`` releasing the last reference so
    # the warning is raised inside the ``pytest.warns`` block -- presumably
    # from the object's finalizer; confirm for non-refcounting interpreters.
    with pytest.warns(
        ResourceWarning,
        match=r"A adbc_driver_manager\.dbapi\.Cursor was not explicitly close\(\)d",
    ):
        cur = sqlite.cursor()
        del cur

    with pytest.warns(
        ResourceWarning,
        match=r"A adbc_driver_manager\.dbapi\.Connection was not explicitly close\(\)d",
    ):
        conn = dbapi.connect(driver="adbc_driver_sqlite")
        del conn