| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import pathlib |
| |
| import pandas |
| import pyarrow |
| import pyarrow.dataset |
| import pytest |
| from pandas.testing import assert_frame_equal |
| |
| from adbc_driver_manager import dbapi |
| |
| |
def test_type_objects():
    """DB-API type objects compare equal to their Arrow types in both directions."""
    for type_object, arrow_type in (
        (dbapi.NUMBER, pyarrow.int64()),
        (dbapi.STRING, pyarrow.string()),
    ):
        # Equality must hold regardless of which operand is on the left.
        assert type_object == arrow_type
        assert arrow_type == type_object

    # Distinct categories differ; ROWID is an alias of NUMBER.
    assert dbapi.STRING != dbapi.NUMBER
    assert dbapi.NUMBER != dbapi.DATETIME
    assert dbapi.NUMBER == dbapi.ROWID
| |
| |
@pytest.mark.sqlite
def test_attrs(sqlite):
    """The connection re-exports the DB-API exception classes, and a new
    cursor starts with the expected initial attribute values."""
    exception_names = (
        "Warning",
        "Error",
        "InterfaceError",
        "DatabaseError",
        "DataError",
        "OperationalError",
        "IntegrityError",
        "InternalError",
        "ProgrammingError",
        "NotSupportedError",
    )
    for name in exception_names:
        assert getattr(sqlite, name) == getattr(dbapi, name), name

    with sqlite.cursor() as cursor:
        # Nothing has been executed on this cursor yet.
        assert cursor.arraysize == 1
        assert cursor.connection is sqlite
        assert cursor.description is None
        assert cursor.rowcount == -1
| |
| |
@pytest.mark.sqlite
def test_info(sqlite):
    """adbc_get_info reports exactly the driver/vendor identification keys."""
    info = sqlite.adbc_get_info()
    expected_keys = {
        "driver_arrow_version",
        "driver_name",
        "driver_version",
        "vendor_name",
        "vendor_version",
    }
    assert set(info.keys()) == expected_keys
    assert (info["driver_name"], info["vendor_name"]) == (
        "ADBC SQLite Driver",
        "SQLite",
    )
| |
| |
@pytest.mark.sqlite
def test_get_underlying(sqlite):
    """The low-level ADBC database/connection/statement handles are exposed."""
    for handle in (sqlite.adbc_database, sqlite.adbc_connection):
        assert handle
    with sqlite.cursor() as cursor:
        assert cursor.adbc_statement
| |
| |
@pytest.mark.sqlite
def test_clone(sqlite):
    """Data committed through a cloned connection is visible to the original."""
    with sqlite.adbc_clone() as clone:
        with clone.cursor() as writer:
            for statement in (
                "CREATE TABLE temporary (ints)",
                "INSERT INTO temporary VALUES (1)",
            ):
                writer.execute(statement)
        clone.commit()

    with sqlite.cursor() as reader:
        reader.execute("SELECT * FROM temporary")
        assert reader.fetchone() == (1,)
| |
| |
@pytest.mark.sqlite
def test_get_objects(sqlite):
    """adbc_get_objects describes the catalog → schema → table → column tree."""
    with sqlite.cursor() as cur:
        cur.execute("CREATE TABLE temporary (ints)")
        cur.execute("INSERT INTO temporary VALUES (1)")

    reader = sqlite.adbc_get_objects(table_name_filter="temporary")
    metadata = reader.read_all().to_pylist()

    assert len(metadata) == 1
    catalog = metadata[0]
    assert catalog["catalog_name"] == "main"

    schemas = catalog["catalog_db_schemas"]
    assert len(schemas) == 1
    schema = schemas[0]
    assert schema["db_schema_name"] == ""

    tables = schema["db_schema_tables"]
    assert len(tables) == 1
    table = tables[0]
    assert table["table_name"] == "temporary"
    assert table["table_type"] == "table"

    first_column = table["table_columns"][0]
    assert first_column["column_name"] == "ints"
    assert first_column["ordinal_position"] == 1
    assert table["table_constraints"] == []
| |
| |
@pytest.mark.sqlite
def test_get_table_schema(sqlite):
    """A column declared without a type is reported as int64 after inserting ints."""
    with sqlite.cursor() as cur:
        for statement in (
            "CREATE TABLE temporary (ints)",
            "INSERT INTO temporary VALUES (1)",
        ):
            cur.execute(statement)

    expected = pyarrow.schema([("ints", pyarrow.int64())])
    assert sqlite.adbc_get_table_schema("temporary") == expected
| |
| |
@pytest.mark.sqlite
def test_get_table_types(sqlite):
    """The SQLite driver advertises exactly the 'table' and 'view' types."""
    table_types = sqlite.adbc_get_table_types()
    assert table_types == ["table", "view"]
| |
| |
class ArrayWrapper:
    """Object that exposes only the Arrow PyCapsule array protocol.

    The tests pass it where Arrow data is accepted, to check that any
    ``__arrow_c_array__`` producer works, not just concrete pyarrow types.
    """

    def __init__(self, array):
        # Wrapped object; expected to implement ``__arrow_c_array__``.
        self.array = array

    def __arrow_c_array__(self, requested_schema=None):
        """Forward the capsule export to the wrapped object unchanged."""
        export = self.array.__arrow_c_array__
        return export(requested_schema=requested_schema)
| |
| |
class StreamWrapper:
    """Object that exposes only the Arrow PyCapsule stream protocol.

    The tests pass it where Arrow data is accepted, to check that any
    ``__arrow_c_stream__`` producer works, not just concrete pyarrow types.
    """

    def __init__(self, stream):
        # Wrapped object; expected to implement ``__arrow_c_stream__``.
        self.stream = stream

    def __arrow_c_stream__(self, requested_schema=None):
        """Forward the capsule export to the wrapped object unchanged."""
        export = self.stream.__arrow_c_stream__
        return export(requested_schema=requested_schema)
| |
| |
@pytest.mark.parametrize(
    "data",
    [
        lambda: pyarrow.record_batch([[1, 2], ["foo", ""]], names=["ints", "strs"]),
        lambda: pyarrow.table([[1, 2], ["foo", ""]], names=["ints", "strs"]),
        lambda: pyarrow.table(
            [[1, 2], ["foo", ""]], names=["ints", "strs"]
        ).to_reader(),
        lambda: ArrayWrapper(
            pyarrow.record_batch([[1, 2], ["foo", ""]], names=["ints", "strs"])
        ),
        lambda: StreamWrapper(
            pyarrow.table([[1, 2], ["foo", ""]], names=["ints", "strs"])
        ),
        lambda: pyarrow.table(
            [[1, 2], ["foo", ""]], names=["ints", "strs"]
        ).__arrow_c_stream__(),
    ],
)
@pytest.mark.sqlite
def test_ingest(data, sqlite):
    """Exercise adbc_ingest's create/append modes and its error paths.

    ``data`` is a zero-argument factory, so each ingest gets a fresh input.
    Some inputs (a RecordBatchReader, a raw stream capsule) are presumably
    single-use.
    """
    with sqlite.cursor() as cur:
        # First ingest with the default mode creates the table.
        cur.adbc_ingest("bulk_ingest", data())

        # The default mode fails when the table already exists.
        with pytest.raises(dbapi.Error):
            cur.adbc_ingest("bulk_ingest", data())

        cur.adbc_ingest("bulk_ingest", data(), mode="append")

        # Appending to a table that does not exist fails.
        with pytest.raises(dbapi.Error):
            cur.adbc_ingest("nonexistent", data(), mode="append")

        # An unrecognised mode string raises ValueError.
        with pytest.raises(ValueError):
            cur.adbc_ingest("bulk_ingest", data(), mode="invalid")

    # Two successful ingests, so the two source rows appear twice, in order.
    expected_rows = [(1, "foo"), (2, "")] * 2
    with sqlite.cursor() as cur:
        cur.execute("SELECT * FROM bulk_ingest")
        for expected in expected_rows:
            assert cur.fetchone() == expected
| |
| |
@pytest.mark.sqlite
def test_partitions(sqlite):
    """The SQLite driver rejects partitioned execution with NotSupportedError."""
    with sqlite.cursor() as cur:
        # Scope the expectation to the partitioned call only. A
        # NotSupportedError raised while opening the cursor must fail the
        # test, not satisfy it.
        with pytest.raises(dbapi.NotSupportedError):
            cur.adbc_execute_partitions("SELECT 1")
| |
| |
@pytest.mark.sqlite
def test_query_fetch_py(sqlite):
    """The row-oriented fetch methods and iteration return Python tuples."""

    def column(name, type_code):
        # DB-API 7-item description entry; only name and type are populated.
        return (name, type_code, None, None, None, None, None)

    query = "SELECT 1, 'foo', 2.0"
    row = (1, "foo", 2.0)

    with sqlite.cursor() as cur:
        cur.execute("SELECT 1, 'foo' AS foo, 2.0")
        assert cur.description == [
            column("1", dbapi.NUMBER),
            column("foo", dbapi.STRING),
            column("2.0", dbapi.NUMBER),
        ]
        # rownumber advances as rows are consumed.
        assert cur.rownumber == 0
        assert cur.fetchone() == row
        assert cur.rownumber == 1
        assert cur.fetchone() is None

        cur.execute(query)
        assert cur.fetchmany() == [row]
        assert cur.fetchmany() == []

        cur.execute(query)
        assert cur.fetchall() == [row]
        assert cur.fetchall() == []

        cur.execute(query)
        assert list(cur) == [row]
| |
| |
@pytest.mark.sqlite
def test_query_fetch_arrow(sqlite):
    """fetch_arrow hands out the result stream once and only after execute()."""
    expected = pyarrow.table({"1": [1], "foo": ["foo"], "2.0": [2.0]})

    with sqlite.cursor() as cur:
        # Nothing has been executed yet, so there is no result to fetch.
        with pytest.raises(sqlite.ProgrammingError):
            cur.fetch_arrow()

        cur.execute("SELECT 1, 'foo' AS foo, 2.0")
        stream = cur.fetch_arrow()
        capsule = stream.__arrow_c_stream__()
        reader = pyarrow.RecordBatchReader._import_from_c_capsule(capsule)
        assert reader.read_all() == expected

        # The result was already handed out above; a second fetch fails.
        with pytest.raises(sqlite.ProgrammingError):
            cur.fetch_arrow()
| |
| |
@pytest.mark.sqlite
def test_query_fetch_arrow_3543(sqlite):
    """Reading ``description`` must not consume the pending result set."""
    # Regression test for https://github.com/apache/arrow-adbc/issues/3543
    expected_description = [
        (name, type_code, None, None, None, None, None)
        for name, type_code in (
            ("1", dbapi.NUMBER),
            ("foo", dbapi.STRING),
            ("2.0", dbapi.NUMBER),
        )
    ]
    expected_table = pyarrow.table({"1": [1], "foo": ["foo"], "2.0": [2.0]})

    with sqlite.cursor() as cur:
        cur.execute("SELECT 1, 'foo' AS foo, 2.0")
        assert cur.description == expected_description

        # The complete result must still be readable after inspecting it.
        capsule = cur.fetch_arrow().__arrow_c_stream__()
        reader = pyarrow.RecordBatchReader._import_from_c_capsule(capsule)
        assert reader.read_all() == expected_table
| |
| |
@pytest.mark.sqlite
def test_query_fetch_arrow_table(sqlite):
    """fetch_arrow_table materialises the whole result as a pyarrow.Table."""
    with sqlite.cursor() as cur:
        cur.execute("SELECT 1, 'foo' AS foo, 2.0")
        table = cur.fetch_arrow_table()
        assert table == pyarrow.table({"1": [1], "foo": ["foo"], "2.0": [2.0]})
| |
| |
@pytest.mark.sqlite
def test_query_fetch_df(sqlite):
    """fetch_df returns the result as a pandas DataFrame."""
    expected = pandas.DataFrame({"1": [1], "foo": ["foo"], "2.0": [2.0]})
    with sqlite.cursor() as cur:
        cur.execute("SELECT 1, 'foo' AS foo, 2.0")
        actual = cur.fetch_df()
        assert_frame_equal(actual, expected)
| |
| |
@pytest.mark.sqlite
@pytest.mark.parametrize(
    "parameters",
    [
        (1.0, 2),
        pyarrow.record_batch([[1.0], [2]], names=["float", "int"]),
        pyarrow.table([[1.0], [2]], names=["float", "int"]),
        ArrayWrapper(pyarrow.record_batch([[1.0], [2]], names=["float", "int"])),
        StreamWrapper(pyarrow.table([[1.0], [2]], names=["float", "int"])),
    ],
)
def test_execute_parameters(sqlite, parameters):
    """Each supported parameter container binds the same two positional values."""
    query = "SELECT ? + 1, ?"
    with sqlite.cursor() as cur:
        cur.execute(query, parameters)
        rows = cur.fetchall()
        assert rows == [(2.0, 2)]
| |
| |
@pytest.mark.sqlite
def test_execute_parameters_name(sqlite):
    """Named, numbered and repeated placeholders bind correctly in sequence."""
    with sqlite.cursor() as cur:
        # Named parameters bind by name, so the dict's key order is irrelevant.
        cur.execute("SELECT @a + 1, @b", {"@b": 2, "@a": 1})
        assert cur.fetchall() == [(2, 2)]

        # A following numbered-placeholder query must not inherit state from
        # the named binding above.
        cur.execute("SELECT ?2 + 1, ?1", [2, 1])
        assert cur.fetchall() == [(2, 2)]

        # The same name may appear more than once in one statement.
        cur.execute("SELECT @a + 1, @b + @b", {"@b": 2, "@a": 1})
        assert cur.fetchall() == [(2, 4)]

        # Bulk ingest on the same cursor still works afterwards.
        batch = pyarrow.record_batch([[1.0], [2]], names=["float", "int"])
        cur.adbc_ingest("ingest_tester", batch)
        cur.execute("SELECT * FROM ingest_tester")
        assert cur.fetchall() == [(1.0, 2)]
| |
| |
@pytest.mark.sqlite
def test_executemany_parameters_name(sqlite):
    """executemany accepts both named (dict) rows and positional (tuple) rows."""
    named_rows = [{"@b": 2, "@a": 1}, {"@b": 3, "@a": 2}]
    positional_rows = [(3, 4), (4, 5)]

    with sqlite.cursor() as cur:
        cur.execute("CREATE TABLE executemany_params (a, b)")
        cur.executemany("INSERT INTO executemany_params VALUES (@a, @b)", named_rows)
        cur.executemany("INSERT INTO executemany_params VALUES (?, ?)", positional_rows)

        cur.execute("SELECT * FROM executemany_params ORDER BY a ASC")
        assert cur.fetchall() == [(1, 2), (2, 3), (3, 4), (4, 5)]
| |
| |
@pytest.mark.sqlite
@pytest.mark.parametrize(
    "parameters",
    [
        lambda: [(1, "a"), (3, None)],
        lambda: pyarrow.record_batch([[1, 3], ["a", None]], names=["float", "str"]),
        lambda: pyarrow.table([[1, 3], ["a", None]], names=["float", "str"]),
        lambda: pyarrow.table(
            [[1, 3], ["a", None]], names=["float", "str"]
        ).to_batches()[0],
        lambda: ArrayWrapper(
            pyarrow.record_batch([[1, 3], ["a", None]], names=["float", "str"])
        ),
        lambda: StreamWrapper(
            pyarrow.table([[1, 3], ["a", None]], names=["float", "str"])
        ),
        lambda: ((x, y) for x, y in ((1, "a"), (3, None))),
    ],
)
def test_executemany_parameters(sqlite, parameters):
    """executemany binds every row of each supported parameter container.

    ``parameters`` is a zero-argument factory rather than a ready-made object.
    One case is a generator, which can be iterated only once. Building it a
    single time at import would leave it exhausted if the same test item ran
    again in the session (e.g. under a rerun/repeat plugin).
    """
    with sqlite.cursor() as cur:
        cur.execute("DROP TABLE IF EXISTS executemany")
        cur.execute("CREATE TABLE executemany (int, str)")
        cur.executemany("INSERT INTO executemany VALUES (? * 2, ?)", parameters())
        cur.execute("SELECT * FROM executemany ORDER BY int ASC")
        assert cur.fetchall() == [(2, "a"), (6, None)]
| |
| |
@pytest.mark.sqlite
@pytest.mark.parametrize(
    "parameters",
    [
        [],
        pyarrow.record_batch([[]], schema=pyarrow.schema([("v", pyarrow.int64())])),
        pyarrow.table([[]], schema=pyarrow.schema([("v", pyarrow.int64())])),
    ],
)
def test_executemany_empty(sqlite, parameters):
    """An empty parameter set turns executemany into a no-op."""
    # Regression test for https://github.com/apache/arrow-adbc/issues/3319
    with sqlite.cursor() as cur:
        for ddl in (
            "DROP TABLE IF EXISTS executemany",
            "CREATE TABLE executemany (v)",
        ):
            cur.execute(ddl)
        # Zero rows to bind: the INSERT should never actually run.
        cur.executemany("INSERT INTO executemany VALUES (?)", parameters)
        cur.execute("SELECT * FROM executemany")
        assert cur.fetchall() == []
| |
| |
@pytest.mark.sqlite
def test_executemany_none(sqlite):
    """executemany(..., None) behaves like a single execution with no parameters."""
    # Regression test for https://github.com/apache/arrow-adbc/issues/3319
    with sqlite.cursor() as cur:
        cur.execute("DROP TABLE IF EXISTS executemany")
        cur.execute("CREATE TABLE executemany (v)")
        # None means "execute once, binding nothing". This INSERT has a `?`
        # placeholder, so running it once without a value is expected to
        # raise a driver error rather than silently do nothing.
        insert = "INSERT INTO executemany VALUES (?)"
        with pytest.raises(sqlite.Error):
            cur.executemany(insert, None)
| |
| |
@pytest.mark.sqlite
def test_query_substrait(sqlite):
    """A bytes query (a Substrait plan, per the test name) is unsupported by SQLite."""
    plan = b"Substrait plan"
    with sqlite.cursor() as cur:
        with pytest.raises(dbapi.NotSupportedError):
            cur.execute(plan)
| |
| |
@pytest.mark.sqlite
def test_executemany(sqlite):
    """Rows inserted via executemany come back in order through iteration."""
    rows = [(1, 2), (3, 4), (5, 6)]
    with sqlite.cursor() as cur:
        cur.execute("CREATE TABLE foo (a, b)")
        cur.executemany("INSERT INTO foo VALUES (?, ?)", rows)

        cur.execute("SELECT COUNT(*) FROM foo")
        assert cur.fetchone() == (len(rows),)

        cur.execute("SELECT * FROM foo ORDER BY a ASC")
        # rownumber reports how many rows have been consumed so far.
        for position, expected in enumerate(rows):
            assert cur.rownumber == position
            assert next(cur) == expected
| |
| |
@pytest.mark.sqlite
def test_fetch_record_batch(sqlite):
    """fetch_record_batch yields a reader whose contents round-trip to pandas."""
    # [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
    rows = [[start, start + 1] for start in range(1, 10, 2)]
    with sqlite.cursor() as cur:
        cur.execute("CREATE TABLE foo (a, b)")
        cur.executemany("INSERT INTO foo VALUES (?, ?)", rows)
        cur.execute("SELECT * FROM foo")
        reader = cur.fetch_record_batch()
        frame = reader.read_pandas()
        assert frame.values.tolist() == rows
| |
| |
@pytest.mark.sqlite
def test_fetch_empty(sqlite):
    """Selecting from an empty table yields an empty list, not an error."""
    with sqlite.cursor() as cur:
        for statement in ("CREATE TABLE foo (bar)", "SELECT * FROM foo"):
            cur.execute(statement)
        rows = cur.fetchall()
        assert rows == []
| |
| |
@pytest.mark.sqlite
def test_reader(sqlite, tmp_path) -> None:
    """A cursor's record batch reader can be consumed by pyarrow.dataset."""
    # Regression test for https://github.com/apache/arrow-adbc/issues/1523
    with sqlite.cursor() as cur:
        cur.execute("SELECT 1")
        pyarrow.dataset.write_dataset(
            cur.fetch_record_batch(), tmp_path, format="parquet"
        )
| |
| |
@pytest.mark.sqlite
def test_prepare(sqlite):
    """adbc_prepare returns the bind-parameter schema (empty without placeholders)."""
    with sqlite.cursor() as cur:
        # No placeholders, so no parameters.
        assert cur.adbc_prepare("SELECT 1") == pyarrow.schema([])
        # One placeholder, reported as field "0" with the null type.
        assert cur.adbc_prepare("SELECT 1 + ?") == pyarrow.schema([("0", "null")])

        # The prepared statement can then be executed with a value.
        cur.execute("SELECT 1 + ?", (1,))
        assert cur.fetchone() == (2,)
| |
| |
@pytest.mark.sqlite
def test_close_warning(sqlite):
    """Dropping an unclosed cursor or connection emits a ResourceWarning."""
    cursor_pattern = r"A adbc_driver_manager.dbapi.Cursor was not explicitly close\(\)d"
    connection_pattern = (
        r"A adbc_driver_manager.dbapi.Connection was not explicitly close\(\)d"
    )

    # NOTE(review): presumably relies on CPython reference counting to
    # finalize the object as soon as its last reference is deleted.
    with pytest.warns(ResourceWarning, match=cursor_pattern):
        leaked_cursor = sqlite.cursor()
        del leaked_cursor

    with pytest.warns(ResourceWarning, match=connection_pattern):
        leaked_connection = dbapi.connect(driver="adbc_driver_sqlite")
        del leaked_connection
| |
| |
| def _execute_schema(cursor): |
| try: |
| cursor.adbc_execute_schema("select 1") |
| except dbapi.NotSupportedError: |
| pass |
| |
| |
@pytest.mark.sqlite
@pytest.mark.parametrize(
    "op",
    [
        pytest.param(lambda cursor: cursor.execute("SELECT 1"), id="execute"),
        pytest.param(
            lambda cursor: cursor.executemany("SELECT ?", [[1]]), id="executemany"
        ),
        pytest.param(
            lambda cursor: cursor.adbc_ingest(
                "test_release",
                pyarrow.table([[1]], names=["ints"]),
                mode="create_append",
            ),
            id="ingest",
        ),
        pytest.param(_execute_schema, id="execute_schema"),
        pytest.param(lambda cursor: cursor.adbc_prepare("select 1"), id="prepare"),
        pytest.param(
            lambda cursor: cursor.executescript("select 1"), id="executescript"
        ),
    ],
)
def test_release(sqlite, op) -> None:
    """Any follow-up cursor operation must release the previous, unread result.

    Regression test: the first query's result is never fetched, so its handle
    is only freed if the next operation cleans it up.
    """
    with sqlite.cursor() as cur:
        cur.execute("select 1")
        # Deliberately skip fetching so the data is never imported.
        pending = cur._results._handle
        assert pending.is_valid

        op(cur)
        # The original handle (if it exists) should now have been released.
        if pending:
            assert not pending.is_valid
| |
| |
def test_driver_path():
    """A driver given as a nonexistent filesystem path fails to load."""
    missing_driver = pathlib.Path("/tmp/thisdriverdoesnotexist")
    # The error names the platform loader call: dlopen or LoadLibraryExW.
    expected_error = "(dlopen|LoadLibraryExW).*failed:"
    with pytest.raises(dbapi.ProgrammingError, match=expected_error):
        with dbapi.connect(driver=missing_driver):
            pass
| |
| |
@pytest.mark.sqlite
def test_dbapi_extensions(sqlite):
    """Connection.execute returns a usable cursor; Cursor.execute is chainable."""
    with sqlite.execute("SELECT ?", (1,)) as cur:
        assert cur.fetchone() == (1,)
        assert cur.fetchone() is None
        # The cursor from Connection.execute can run further queries.
        assert cur.execute("SELECT 2").fetchall() == [(2,)]

    with sqlite.cursor() as cur:
        for query, expected in (("SELECT 1", [(1,)]), ("SELECT 42", [(42,)])):
            assert cur.execute(query).fetchall() == expected
| |
| |
@pytest.mark.sqlite
def test_connect(tmp_path: pathlib.Path, monkeypatch) -> None:
    """Exercise the accepted ways of calling ``dbapi.connect``.

    Covers the ``driver=`` keyword, a positional driver name, a positional
    URI, the ``uri=`` keyword, and driver lookup through a manifest file
    placed on ``ADBC_DRIVER_PATH``.
    """
    with dbapi.connect(driver="adbc_driver_sqlite") as conn:
        with conn.cursor() as cur:
            cur.execute("SELECT 1")
            assert cur.fetchone() == (1,)

    # https://github.com/apache/arrow-adbc/issues/3517: allow positional
    # argument
    with dbapi.connect("adbc_driver_sqlite") as conn:
        with conn.cursor() as cur:
            cur.execute("SELECT 1")
            assert cur.fetchone() == (1,)

    # https://github.com/apache/arrow-adbc/issues/3517: allow URI argument
    db = tmp_path / "test.db"
    with dbapi.connect("adbc_driver_sqlite", db.as_uri()) as conn:
        with conn.cursor() as cur:
            cur.execute("CREATE TABLE foo (a)")
            cur.execute("INSERT INTO foo VALUES (1)")
        conn.commit()

    # Reopen the same file via keyword arguments; the committed row persists.
    with dbapi.connect(driver="adbc_driver_sqlite", uri=db.as_uri()) as conn:
        with conn.cursor() as cur:
            cur.execute("SELECT * FROM foo")
            assert cur.fetchone() == (1,)

    # Environment values must be str. monkeypatch.setenv emits a
    # PytestWarning (fatal under -W error) and converts implicitly if
    # handed a Path.
    monkeypatch.setenv("ADBC_DRIVER_PATH", str(tmp_path))
    (tmp_path / "foobar.toml").write_text(
        '[Driver]\nshared = "adbc_driver_foobar"\n', encoding="utf-8"
    )
    # Just check that the manifest is detected and a load is attempted. The
    # shared library it names does not exist, so this should fail.
    with pytest.raises(dbapi.ProgrammingError, match="NOT_FOUND"):
        with dbapi.connect("foobar://localhost:5439"):
            pass