blob: 94cd9f82d06d01db143ddd5712431d211cb3c7de [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import string
from pathlib import Path
from typing import Generator
import numpy
import pyarrow
import pyarrow.dataset
import pytest
from adbc_driver_postgresql import StatementOptions, dbapi
@pytest.fixture
def postgres(postgres_uri: str) -> Generator[dbapi.Connection, None, None]:
with dbapi.connect(postgres_uri) as conn:
yield conn
def test_conn_current_catalog(postgres: dbapi.Connection) -> None:
assert postgres.adbc_current_catalog != ""
def test_conn_current_db_schema(postgres: dbapi.Connection) -> None:
assert postgres.adbc_current_db_schema == "public"
def test_conn_change_db_schema(postgres: dbapi.Connection) -> None:
assert postgres.adbc_current_db_schema == "public"
with postgres.cursor() as cur:
cur.execute("CREATE SCHEMA IF NOT EXISTS dbapischema")
assert postgres.adbc_current_db_schema == "public"
postgres.adbc_current_db_schema = "dbapischema"
assert postgres.adbc_current_db_schema == "dbapischema"
def test_conn_get_info(postgres: dbapi.Connection) -> None:
info = postgres.adbc_get_info()
assert info["driver_name"] == "ADBC PostgreSQL Driver"
assert info["driver_adbc_version"] == 1_001_000
assert info["vendor_name"] == "PostgreSQL"
def test_query_batch_size(postgres: dbapi.Connection):
with postgres.cursor() as cur:
cur.execute("DROP TABLE IF EXISTS test_batch_size")
cur.execute("CREATE TABLE test_batch_size (ints INT)")
cur.execute(
"""
INSERT INTO test_batch_size (ints)
SELECT generated :: INT
FROM GENERATE_SERIES(1, 65536) temp(generated)
"""
)
cur.execute("SELECT * FROM test_batch_size")
table = cur.fetch_arrow_table()
assert len(table.to_batches()) == 1
cur.adbc_statement.set_options(
**{StatementOptions.BATCH_SIZE_HINT_BYTES.value: "1"}
)
assert (
cur.adbc_statement.get_option_int(
StatementOptions.BATCH_SIZE_HINT_BYTES.value
)
== 1
)
cur.execute("SELECT * FROM test_batch_size")
table = cur.fetch_arrow_table()
assert len(table.to_batches()) == 65536
cur.adbc_statement.set_options(
**{StatementOptions.BATCH_SIZE_HINT_BYTES.value: "4096"}
)
assert (
cur.adbc_statement.get_option_int(
StatementOptions.BATCH_SIZE_HINT_BYTES.value
)
== 4096
)
cur.execute("SELECT * FROM test_batch_size")
table = cur.fetch_arrow_table()
assert 64 <= len(table.to_batches()) <= 256
def test_query_cancel(postgres: dbapi.Connection) -> None:
with postgres.cursor() as cur:
cur.execute("DROP TABLE IF EXISTS test_batch_size")
cur.execute("CREATE TABLE test_batch_size (ints INT)")
cur.execute(
"""
INSERT INTO test_batch_size (ints)
SELECT generated :: INT
FROM GENERATE_SERIES(1, 1048576) temp(generated)
"""
)
postgres.commit()
# Ensure different ways of reading all raise the desired error
with postgres.cursor() as cur:
cur.execute("SELECT * FROM test_batch_size")
cur.adbc_cancel()
with pytest.raises(postgres.OperationalError, match="canceling statement"):
cur.fetchone()
postgres.rollback()
with postgres.cursor() as cur:
cur.execute("SELECT * FROM test_batch_size")
cur.adbc_cancel()
with pytest.raises(postgres.OperationalError, match="canceling statement"):
cur.fetch_arrow_table()
postgres.rollback()
with postgres.cursor() as cur:
cur.execute("SELECT * FROM test_batch_size")
cur.adbc_cancel()
with pytest.raises(postgres.OperationalError, match="canceling statement"):
cur.fetch_df()
def test_query_execute_schema(postgres: dbapi.Connection) -> None:
with postgres.cursor() as cur:
schema = cur.adbc_execute_schema("SELECT 1 AS foo")
assert schema == pyarrow.schema([("foo", "int32")])
def test_query_invalid(postgres: dbapi.Connection) -> None:
with postgres.cursor() as cur:
with pytest.raises(
postgres.ProgrammingError, match="failed to prepare query"
) as excinfo:
cur.execute("SELECT * FROM tabledoesnotexist")
assert excinfo.value.sqlstate == "42P01"
assert len(excinfo.value.details) > 0
def test_query_trivial(postgres: dbapi.Connection):
with postgres.cursor() as cur:
cur.execute("SELECT 1")
assert cur.fetchone() == (1,)
def test_stmt_ingest(postgres: dbapi.Connection) -> None:
table = pyarrow.table(
[
[1, 2, 3],
["a", None, "b"],
],
names=["ints", "strs"],
)
double_table = pyarrow.table(
[
[1, 1, 2, 2, 3, 3],
["a", "a", None, None, "b", "b"],
],
names=["ints", "strs"],
)
with postgres.cursor() as cur:
cur.execute("DROP TABLE IF EXISTS test_ingest")
with pytest.raises(
postgres.ProgrammingError, match='"public.test_ingest" does not exist'
):
cur.adbc_ingest("test_ingest", table, mode="append")
postgres.rollback()
cur.adbc_ingest("test_ingest", table, mode="replace")
cur.execute("SELECT * FROM test_ingest ORDER BY ints")
assert cur.fetch_arrow_table() == table
with pytest.raises(
postgres.ProgrammingError, match='"test_ingest" already exists'
):
cur.adbc_ingest("test_ingest", table, mode="create")
cur.adbc_ingest("test_ingest", table, mode="create_append")
cur.execute("SELECT * FROM test_ingest ORDER BY ints")
assert cur.fetch_arrow_table() == double_table
cur.adbc_ingest("test_ingest", table, mode="replace")
cur.execute("SELECT * FROM test_ingest ORDER BY ints")
assert cur.fetch_arrow_table() == table
cur.execute("DROP TABLE IF EXISTS test_ingest")
cur.adbc_ingest("test_ingest", table, mode="create_append")
cur.execute("SELECT * FROM test_ingest ORDER BY ints")
assert cur.fetch_arrow_table() == table
cur.execute("DROP TABLE IF EXISTS test_ingest")
cur.adbc_ingest("test_ingest", table, mode="create")
cur.execute("SELECT * FROM test_ingest ORDER BY ints")
assert cur.fetch_arrow_table() == table
def test_stmt_ingest_dataset(postgres: dbapi.Connection, tmp_path: Path) -> None:
# Regression test for https://github.com/apache/arrow-adbc/issues/1310
table = pyarrow.table(
[
[1, 1, 2, 2, 3, 3],
["a", "a", None, None, "b", "b"],
],
schema=pyarrow.schema([("ints", "int32"), ("strs", "string")]),
)
pyarrow.dataset.write_dataset(
table, tmp_path, format="parquet", partitioning=["ints"]
)
ds = pyarrow.dataset.dataset(tmp_path, format="parquet", partitioning=["ints"])
with postgres.cursor() as cur:
for item in (
lambda: ds,
lambda: ds.scanner(),
lambda: ds.scanner().to_reader(),
lambda: ds.scanner().to_table(),
):
cur.execute("DROP TABLE IF EXISTS test_ingest")
cur.adbc_ingest(
"test_ingest",
item(),
mode="create_append",
)
cur.execute("SELECT ints, strs FROM test_ingest ORDER BY ints")
assert cur.fetch_arrow_table() == table
def test_stmt_ingest_multi(postgres: dbapi.Connection) -> None:
# Regression test for https://github.com/apache/arrow-adbc/issues/1310
table = pyarrow.table(
[
[1, 1, 2, 2, 3, 3],
["a", "a", None, None, "b", "b"],
],
names=["ints", "strs"],
)
with postgres.cursor() as cur:
cur.execute("DROP TABLE IF EXISTS test_ingest")
cur.adbc_ingest(
"test_ingest",
table.to_batches(max_chunksize=2),
mode="create_append",
)
cur.execute("SELECT * FROM test_ingest ORDER BY ints")
assert cur.fetch_arrow_table() == table
def test_ddl(postgres: dbapi.Connection):
with postgres.cursor() as cur:
cur.execute("DROP TABLE IF EXISTS test_ddl")
assert cur.fetchone() is None
cur.execute("CREATE TABLE test_ddl (ints INT)")
assert cur.fetchone() is None
cur.execute("INSERT INTO test_ddl VALUES (1) RETURNING ints")
assert cur.fetchone() == (1,)
cur.execute("SELECT * FROM test_ddl")
assert cur.fetchone() == (1,)
def test_crash(postgres: dbapi.Connection) -> None:
with postgres.cursor() as cur:
cur.execute("SELECT 1")
assert cur.fetchone() == (1,)
def test_reuse(postgres: dbapi.Connection) -> None:
with postgres.cursor() as cur:
cur.execute("DROP TABLE IF EXISTS test_batch_size")
cur.execute("CREATE TABLE test_batch_size (ints INT)")
cur.execute(
"""
INSERT INTO test_batch_size (ints)
SELECT generated :: INT
FROM GENERATE_SERIES(1, 65536) temp(generated)
"""
)
cur.execute("SELECT * FROM test_batch_size ORDER BY ints ASC")
assert cur.fetchone() == (1,)
cur.execute("SELECT 1")
assert cur.fetchone() == (1,)
cur.execute("SELECT 2")
assert cur.fetchone() == (2,)
def test_ingest(postgres: dbapi.Connection) -> None:
table = pyarrow.Table.from_pydict({"numbers": [1, 2], "letters": ["a", "b"]})
with postgres.cursor() as cur:
cur.adbc_ingest("foo", table, mode="replace", db_schema_name="public")
cur.execute("SELECT * FROM public.foo")
assert cur.fetch_arrow_table() == table
with pytest.raises(dbapi.NotSupportedError):
cur.adbc_ingest("foo", table, catalog_name="main")
def test_ingest_schema(postgres: dbapi.Connection) -> None:
table = pyarrow.Table.from_pydict({"numbers": [1, 2], "letters": ["a", "b"]})
with postgres.cursor() as cur:
cur.execute("CREATE SCHEMA IF NOT EXISTS testschema")
cur.execute("DROP TABLE IF EXISTS testschema.foo")
postgres.commit()
cur.adbc_ingest("foo", table, mode="create", db_schema_name="testschema")
cur.execute("SELECT * FROM testschema.foo ORDER BY numbers")
assert cur.fetch_arrow_table() == table
def test_ingest_temporary(postgres: dbapi.Connection) -> None:
table = pyarrow.Table.from_pydict(
{
"numbers": [1, 2],
"letters": ["a", "b"],
}
)
temp = pyarrow.Table.from_pydict(
{
"ints": [3, 4],
"strs": ["c", "d"],
}
)
table2 = pyarrow.Table.from_pydict(
{
"numbers": [1, 2, 1, 2],
"letters": ["a", "b", "a", "b"],
}
)
temp2 = pyarrow.Table.from_pydict(
{
"ints": [3, 4, 3, 4],
"strs": ["c", "d", "c", "d"],
}
)
with postgres.cursor() as cur:
cur.execute("DROP TABLE IF EXISTS public.temporary")
cur.execute("DROP TABLE IF EXISTS pg_temp.temporary")
cur.adbc_ingest("temporary", table, mode="create")
cur.adbc_ingest("temporary", temp, mode="create", temporary=True)
cur.execute("SELECT * FROM public.temporary")
assert cur.fetch_arrow_table() == table
cur.execute("SELECT * FROM pg_temp.temporary")
assert cur.fetch_arrow_table() == temp
cur.execute("SELECT * FROM temporary")
assert cur.fetch_arrow_table() == temp
cur.adbc_ingest("temporary", table, mode="append")
cur.adbc_ingest("temporary", temp, mode="append", temporary=True)
cur.execute("SELECT * FROM public.temporary")
assert cur.fetch_arrow_table() == table2
cur.execute("SELECT * FROM pg_temp.temporary")
assert cur.fetch_arrow_table() == temp2
cur.execute("SELECT * FROM temporary")
assert cur.fetch_arrow_table() == temp2
cur.adbc_ingest("temporary", table, mode="replace")
cur.adbc_ingest("temporary", temp, mode="replace", temporary=True)
cur.execute("SELECT * FROM public.temporary")
assert cur.fetch_arrow_table() == table
cur.execute("SELECT * FROM pg_temp.temporary")
assert cur.fetch_arrow_table() == temp
cur.execute("SELECT * FROM temporary")
assert cur.fetch_arrow_table() == temp
cur.adbc_ingest("temporary", table, mode="create_append")
cur.adbc_ingest("temporary", temp, mode="create_append", temporary=True)
cur.execute("SELECT * FROM public.temporary")
assert cur.fetch_arrow_table() == table2
cur.execute("SELECT * FROM pg_temp.temporary")
assert cur.fetch_arrow_table() == temp2
cur.execute("SELECT * FROM temporary")
assert cur.fetch_arrow_table() == temp2
def test_ingest_large(postgres: dbapi.Connection) -> None:
"""Regression test for #1921."""
# More than 1 GiB of data in one batch
arr = pyarrow.array(numpy.random.randint(-100, 100, size=4_000_000))
batch = pyarrow.RecordBatch.from_pydict(
{char: arr for char in string.ascii_lowercase}
)
table = pyarrow.Table.from_batches([batch] * 4)
with postgres.cursor() as cur:
cur.adbc_ingest("test_ingest_large", table, mode="replace")