blob: 1f9ecbfc398a8f341af7dd0d5bd8e64695ab7c90 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import datafusion as dfn
import pyarrow as pa
import pyarrow.dataset as ds
import pytest
from datafusion import SessionContext, Table
# Note: `database` is accepted as a parameter even though the test body
# never uses it. Requesting that fixture is what sets up the context with
# the tables we need.
def test_basic(ctx, database):
    """Check the default catalog, its single schema, and the ``csv`` table."""
    # Looking up a catalog that was never registered must fail.
    with pytest.raises(KeyError):
        ctx.catalog("non-existent")

    default = ctx.catalog()
    assert default.names() == {"public"}

    # The schema is fetched both by explicit name and via the no-argument form.
    expected_tables = {"csv1", "csv", "csv2"}
    for db in [default.schema("public"), default.schema()]:
        assert db.names() == expected_tables

    expected_schema = pa.schema(
        [
            pa.field("int", pa.int64(), nullable=True),
            pa.field("str", pa.string(), nullable=True),
            pa.field("float", pa.float64(), nullable=True),
        ]
    )
    table = db.table("csv")
    assert table.kind == "physical"
    assert table.schema == expected_schema
def create_dataset() -> Table:
    """Build a small table backed by an in-memory pyarrow dataset.

    The data is a single record batch with column ``a`` holding ``[1, 2, 3]``
    and column ``b`` holding ``[4, 5, 6]``.
    """
    columns = [pa.array([1, 2, 3]), pa.array([4, 5, 6])]
    record_batch = pa.RecordBatch.from_arrays(columns, names=["a", "b"])
    return Table.from_dataset(ds.dataset([record_batch]))
class CustomSchemaProvider(dfn.catalog.SchemaProvider):
    """Python schema provider that keeps its tables in a plain dict.

    Every new instance starts with one table, ``table1``, built by
    ``create_dataset()``.
    """

    def __init__(self):
        # Maps table name -> Table.
        self.tables = {"table1": create_dataset()}

    def table_names(self) -> set[str]:
        """Return the names of all registered tables as a new set."""
        return set(self.tables)

    def register_table(self, name: str, table: Table):
        """Store ``table`` under ``name``, replacing any existing entry."""
        self.tables[name] = table

    def deregister_table(self, name, cascade: bool = True):
        """Remove the table registered under ``name``.

        ``cascade`` is accepted for interface compatibility and is not used.

        Raises:
            KeyError: if no table is registered under ``name``.
        """
        del self.tables[name]

    def table(self, name: str) -> Table | None:
        """Return the table registered under ``name``, or ``None`` if absent."""
        # ``.get`` honours the declared ``Table | None`` return type; plain
        # indexing would raise KeyError for an unknown table name instead.
        return self.tables.get(name)

    def table_exist(self, name: str) -> bool:
        """Return whether a table is registered under ``name``."""
        return name in self.tables
class CustomCatalogProvider(dfn.catalog.CatalogProvider):
    """Python catalog provider that keeps its schemas in a plain dict.

    Every new instance starts with one schema, ``my_schema``, which is a
    fresh ``CustomSchemaProvider``.
    """

    def __init__(self):
        # Maps schema name -> schema provider.
        self.schemas = {"my_schema": CustomSchemaProvider()}

    def schema_names(self) -> set[str]:
        """Return the names of all registered schemas as a new set."""
        return set(self.schemas)

    def schema(self, name: str):
        """Return the schema registered under ``name``, or ``None`` if absent."""
        # Report an unknown schema as ``None`` rather than raising KeyError.
        return self.schemas.get(name)

    def register_schema(self, name: str, schema: dfn.catalog.Schema):
        """Store ``schema`` under ``name``, replacing any existing entry."""
        self.schemas[name] = schema

    def deregister_schema(self, name, cascade: bool = True):
        """Remove the schema registered under ``name``.

        ``cascade`` is accepted for interface compatibility and is not used.
        It defaults to ``True`` so the signature mirrors
        ``CustomSchemaProvider.deregister_table``.

        Raises:
            KeyError: if no schema is registered under ``name``.
        """
        del self.schemas[name]
def test_python_catalog_provider(ctx: SessionContext):
    """Register a Python catalog provider, then add and remove schemas on it."""
    ctx.register_catalog_provider("my_catalog", CustomCatalogProvider())

    # The built-in catalog must still expose only its default schema.
    builtin = ctx.catalog("datafusion")
    assert builtin.names() == {"public"}

    custom = ctx.catalog("my_catalog")
    assert custom.names() == {"my_schema"}

    # Adding a second schema makes both visible.
    custom.register_schema("second_schema", CustomSchemaProvider())
    assert custom.schema_names() == {"my_schema", "second_schema"}

    # Dropping the original leaves only the new one.
    custom.deregister_schema("my_schema")
    assert custom.schema_names() == {"second_schema"}
def test_in_memory_providers(ctx: SessionContext):
    """Query a table registered through in-memory catalog and schema providers."""
    mem_catalog = dfn.catalog.Catalog.memory_catalog()
    ctx.register_catalog_provider("in_mem_catalog", mem_catalog)
    assert ctx.catalog_names() == {"datafusion", "in_mem_catalog"}

    # Attach the schema to the catalog first, then put a table into it.
    mem_schema = dfn.catalog.Schema.memory_schema()
    mem_catalog.register_schema("in_mem_schema", mem_schema)
    mem_schema.register_table("my_table", create_dataset())

    query = ctx.sql("select * from in_mem_catalog.in_mem_schema.my_table")
    batches = query.collect()

    assert len(batches) == 1
    only_batch = batches[0]
    assert only_batch.column(0) == pa.array([1, 2, 3])
    assert only_batch.column(1) == pa.array([4, 5, 6])
def test_python_schema_provider(ctx: SessionContext):
    """Swap Python schema providers in and out of the default catalog."""
    default_catalog = ctx.catalog()

    # Remove the built-in schema so only the custom ones remain.
    default_catalog.deregister_schema("public")

    default_catalog.register_schema("test_schema1", CustomSchemaProvider())
    assert default_catalog.names() == {"test_schema1"}

    # Register a second provider, then drop the first.
    default_catalog.register_schema("test_schema2", CustomSchemaProvider())
    default_catalog.deregister_schema("test_schema1")
    assert default_catalog.names() == {"test_schema2"}
def test_python_table_provider(ctx: SessionContext):
    """Register and deregister tables on a custom schema and the default one."""
    default_catalog = ctx.catalog()

    # Part 1: a Python-defined schema, which starts out holding "table1".
    default_catalog.register_schema("custom_schema", CustomSchemaProvider())
    custom = default_catalog.schema("custom_schema")
    assert custom.table_names() == {"table1"}

    custom.deregister_table("table1")
    custom.register_table("table2", create_dataset())
    assert custom.table_names() == {"table2"}

    # Part 2: the same round trip against the catalog's default schema.
    builtin = default_catalog.schema()
    builtin.register_table("table3", create_dataset())
    assert builtin.table_names() == {"table3"}

    builtin.deregister_table("table3")
    builtin.register_table("table4", create_dataset())
    assert builtin.table_names() == {"table4"}
def test_in_end_to_end_python_providers(ctx: SessionContext):
    """Test registering all python providers and running a query against them."""
    all_catalog_names = ["datafusion", "custom_catalog", "in_mem_catalog"]
    all_schema_names = ["custom_schema", "in_mem_schema"]

    # "datafusion" is the built-in catalog; the other two are registered here.
    _, custom_catalog_name, mem_catalog_name = all_catalog_names
    custom_schema_name, mem_schema_name = all_schema_names

    ctx.register_catalog_provider(custom_catalog_name, CustomCatalogProvider())
    ctx.register_catalog_provider(
        mem_catalog_name, dfn.catalog.Catalog.memory_catalog()
    )

    # Setup pass: give every catalog the same two schemas, each holding
    # exactly one table named "test_table".
    for name in all_catalog_names:
        catalog = ctx.catalog(name)

        # Clean out previous schemas if they exist so we can start clean
        for stale_schema in catalog.schema_names():
            catalog.deregister_schema(stale_schema, cascade=False)

        catalog.register_schema(custom_schema_name, CustomSchemaProvider())
        catalog.register_schema(mem_schema_name, dfn.catalog.Schema.memory_schema())

        for target in all_schema_names:
            schema = catalog.schema(target)

            # Drop whatever tables the schema already holds.
            for stale_table in schema.table_names():
                schema.deregister_table(stale_table)

            schema.register_table("test_table", create_dataset())

    # Query pass: every catalog/schema combination must return the dataset.
    for catalog_name in all_catalog_names:
        for schema_name in all_schema_names:
            table_full_name = f"{catalog_name}.{schema_name}.test_table"
            batches = ctx.sql(f"select * from {table_full_name}").collect()

            assert len(batches) == 1
            only_batch = batches[0]
            assert only_batch.column(0) == pa.array([1, 2, 3])
            assert only_batch.column(1) == pa.array([4, 5, 6])