blob: 6f949f8cae2b201f24b722eb705ea725fb9f0a87 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::any::Any;
use std::collections::HashSet;
use std::sync::Arc;
use async_trait::async_trait;
use datafusion::catalog::{
CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider,
};
use datafusion::common::DataFusionError;
use datafusion::datasource::TableProvider;
use datafusion_ffi::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider};
use pyo3::exceptions::PyKeyError;
use pyo3::prelude::*;
use pyo3::types::PyCapsule;
use pyo3::IntoPyObjectExt;
use crate::dataset::Dataset;
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
use crate::table::PyTable;
use crate::utils::{validate_pycapsule, wait_for_future};
/// Python-facing handle to a DataFusion catalog.
///
/// Exposed to Python as `RawCatalog` in the `datafusion.catalog` module. The
/// class is declared `frozen` and `subclass`, and is cheap to clone because it
/// only holds a reference-counted pointer to the provider.
#[pyclass(frozen, name = "RawCatalog", module = "datafusion.catalog", subclass)]
#[derive(Clone)]
pub struct PyCatalog {
/// The catalog provider that every method of this class delegates to.
pub catalog: Arc<dyn CatalogProvider>,
}
/// Python-facing handle to a DataFusion schema.
///
/// Exposed to Python as `RawSchema` in the `datafusion.catalog` module. The
/// class is declared `frozen` and `subclass`, and is cheap to clone because it
/// only holds a reference-counted pointer to the provider.
#[pyclass(frozen, name = "RawSchema", module = "datafusion.catalog", subclass)]
#[derive(Clone)]
pub struct PySchema {
/// The schema provider that every method of this class delegates to.
pub schema: Arc<dyn SchemaProvider>,
}
impl From<Arc<dyn CatalogProvider>> for PyCatalog {
fn from(catalog: Arc<dyn CatalogProvider>) -> Self {
Self { catalog }
}
}
impl From<Arc<dyn SchemaProvider>> for PySchema {
fn from(schema: Arc<dyn SchemaProvider>) -> Self {
Self { schema }
}
}
#[pymethods]
impl PyCatalog {
    /// Create a catalog backed by a Python object.
    ///
    /// The object is wrapped in a `RustWrappedPyCatalogProvider`, which forwards
    /// catalog operations to the object's attributes and methods.
    #[new]
    fn new(catalog: PyObject) -> Self {
        let catalog_provider =
            Arc::new(RustWrappedPyCatalogProvider::new(catalog)) as Arc<dyn CatalogProvider>;
        catalog_provider.into()
    }

    /// Create a catalog backed by a default `MemoryCatalogProvider`.
    #[staticmethod]
    fn memory_catalog() -> Self {
        let catalog_provider =
            Arc::new(MemoryCatalogProvider::default()) as Arc<dyn CatalogProvider>;
        catalog_provider.into()
    }

    /// Return the names of all schemas in this catalog as a set.
    fn schema_names(&self) -> HashSet<String> {
        self.catalog.schema_names().into_iter().collect()
    }

    /// Look up a schema by name (defaults to `"public"`).
    ///
    /// If the schema is a wrapped Python object, that same Python object is
    /// returned; otherwise the provider is returned wrapped in a `RawSchema`.
    ///
    /// Raises `KeyError` when the catalog has no schema with that name.
    #[pyo3(signature = (name="public"))]
    fn schema(&self, name: &str) -> PyResult<PyObject> {
        // `ok_or_else` keeps the error (and its formatted message) from being
        // built on the successful path.
        let schema = self.catalog.schema(name).ok_or_else(|| {
            PyKeyError::new_err(format!("Schema with name {name} doesn't exist."))
        })?;

        Python::with_gil(|py| {
            match schema
                .as_any()
                .downcast_ref::<RustWrappedPySchemaProvider>()
            {
                // Hand back the original Python object instead of re-wrapping it.
                Some(wrapped_schema) => Ok(wrapped_schema.schema_provider.clone_ref(py)),
                None => PySchema::from(schema).into_py_any(py),
            }
        })
    }

    /// Register a schema under `name`.
    ///
    /// `schema_provider` may be, in order of precedence:
    /// 1. an object exposing `__datafusion_schema_provider__()` that returns a
    ///    `PyCapsule` named `datafusion_schema_provider`,
    /// 2. a `RawSchema` instance, whose inner provider is reused directly, or
    /// 3. any other Python object, which is wrapped and called back into.
    fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> {
        let provider = if schema_provider.hasattr("__datafusion_schema_provider__")? {
            let capsule = schema_provider
                .getattr("__datafusion_schema_provider__")?
                .call0()?;
            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
            validate_pycapsule(capsule, "datafusion_schema_provider")?;

            // SAFETY: relies on `validate_pycapsule` having accepted the capsule
            // under the name "datafusion_schema_provider"; the pointee is assumed
            // to be an `FFI_SchemaProvider` laid out by a compatible producer.
            let provider = unsafe { capsule.reference::<FFI_SchemaProvider>() };
            let provider: ForeignSchemaProvider = provider.into();
            Arc::new(provider) as Arc<dyn SchemaProvider>
        } else {
            match schema_provider.extract::<PySchema>() {
                Ok(py_schema) => py_schema.schema,
                Err(_) => Arc::new(RustWrappedPySchemaProvider::new(schema_provider.into()))
                    as Arc<dyn SchemaProvider>,
            }
        };

        // Whatever the provider returns on success is intentionally discarded.
        let _ = self
            .catalog
            .register_schema(name, provider)
            .map_err(py_datafusion_err)?;
        Ok(())
    }

    /// Remove the schema called `name`, passing `cascade` through to the provider.
    ///
    /// The removed schema, if any, is not returned to Python.
    fn deregister_schema(&self, name: &str, cascade: bool) -> PyResult<()> {
        let _ = self
            .catalog
            .deregister_schema(name, cascade)
            .map_err(py_datafusion_err)?;
        Ok(())
    }

    /// Render as `Catalog(schema_names=[a, b])` with the names sorted.
    fn __repr__(&self) -> PyResult<String> {
        let mut names: Vec<String> = self.schema_names().into_iter().collect();
        names.sort();
        Ok(format!("Catalog(schema_names=[{}])", names.join(", ")))
    }
}
#[pymethods]
impl PySchema {
    /// Create a schema backed by a Python object.
    ///
    /// The object is wrapped in a `RustWrappedPySchemaProvider`, which forwards
    /// schema operations to the object's attributes and methods.
    #[new]
    fn new(schema_provider: PyObject) -> Self {
        let schema_provider =
            Arc::new(RustWrappedPySchemaProvider::new(schema_provider)) as Arc<dyn SchemaProvider>;
        schema_provider.into()
    }

    /// Create a schema backed by a default `MemorySchemaProvider`.
    #[staticmethod]
    fn memory_schema() -> Self {
        let schema_provider = Arc::new(MemorySchemaProvider::default()) as Arc<dyn SchemaProvider>;
        schema_provider.into()
    }

    /// The names of all tables in this schema, as a set.
    #[getter]
    fn table_names(&self) -> HashSet<String> {
        self.schema.table_names().into_iter().collect()
    }

    /// Look up a table by name.
    ///
    /// The provider's asynchronous lookup is driven to completion with
    /// `wait_for_future`. Fails with a "Table not found" error when the
    /// provider reports no table under `name`.
    fn table(&self, name: &str, py: Python) -> PyDataFusionResult<PyTable> {
        // Outer `?` covers the wait itself, inner `?` the provider's own result.
        match wait_for_future(py, self.schema.table(name))?? {
            Some(table) => Ok(PyTable::from(table)),
            None => Err(PyDataFusionError::Common(format!(
                "Table not found: {name}"
            ))),
        }
    }

    /// Render as `Schema(table_names=[a, b])` with the names sorted.
    fn __repr__(&self) -> PyResult<String> {
        let mut names: Vec<String> = self.table_names().into_iter().collect();
        names.sort();
        // Use the same ", " separator as the catalog's repr.
        Ok(format!("Schema(table_names=[{}])", names.join(", ")))
    }

    /// Register `table_provider` under `name`.
    ///
    /// The Python object is converted with `PyTable::new`; a conversion failure
    /// is propagated. Whatever the provider returns on success is discarded.
    fn register_table(&self, name: &str, table_provider: &Bound<'_, PyAny>) -> PyResult<()> {
        let table = PyTable::new(table_provider)?;
        let _ = self
            .schema
            .register_table(name.to_string(), table.table)
            .map_err(py_datafusion_err)?;
        Ok(())
    }

    /// Remove the table called `name`; the removed table is not returned to Python.
    fn deregister_table(&self, name: &str) -> PyResult<()> {
        let _ = self
            .schema
            .deregister_table(name)
            .map_err(py_datafusion_err)?;
        Ok(())
    }
}
/// Adapter that lets a Python object act as a DataFusion `SchemaProvider`.
#[derive(Debug)]
pub(crate) struct RustWrappedPySchemaProvider {
// The Python object that schema operations are forwarded to.
schema_provider: PyObject,
// Owner name read from the Python object's `owner_name` attribute when the
// adapter is constructed; it is not refreshed afterwards.
owner_name: Option<String>,
}
impl RustWrappedPySchemaProvider {
    /// Wraps `schema_provider`, capturing its `owner_name` attribute once.
    ///
    /// The owner is `None` when the attribute is missing (or reading it raises)
    /// or when it holds Python `None`; any other value is converted with its
    /// string representation.
    pub fn new(schema_provider: PyObject) -> Self {
        let owner_name = Python::with_gil(|py| {
            schema_provider
                .bind(py)
                .getattr("owner_name")
                .ok()
                // A Python `None` means "no owner"; without this filter it would
                // be stringified into the literal owner name "None".
                .filter(|name| !name.is_none())
                .map(|name| name.to_string())
        });

        Self {
            schema_provider,
            owner_name,
        }
    }

    /// Calls the Python object's `table(name)` method and converts the result.
    ///
    /// Returns `Ok(None)` when Python returns `None`. Errors raised by the call,
    /// or by `PyTable::new` while converting the returned object, are propagated.
    fn table_inner(&self, name: &str) -> PyResult<Option<Arc<dyn TableProvider>>> {
        Python::with_gil(|py| {
            let py_table = self
                .schema_provider
                .bind(py)
                .call_method1("table", (name,))?;
            if py_table.is_none() {
                return Ok(None);
            }

            let table = PyTable::new(&py_table)?;
            Ok(Some(table.table))
        })
    }
}
#[async_trait]
impl SchemaProvider for RustWrappedPySchemaProvider {
fn owner_name(&self) -> Option<&str> {
self.owner_name.as_deref()
}
fn as_any(&self) -> &dyn Any {
self
}
fn table_names(&self) -> Vec<String> {
Python::with_gil(|py| {
let provider = self.schema_provider.bind(py);
provider
.getattr("table_names")
.and_then(|names| names.extract::<Vec<String>>())
.unwrap_or_else(|err| {
log::error!("Unable to get table_names: {err}");
Vec::default()
})
})
}
async fn table(
&self,
name: &str,
) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
self.table_inner(name).map_err(to_datafusion_err)
}
fn register_table(
&self,
name: String,
table: Arc<dyn TableProvider>,
) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
let py_table = PyTable::from(table);
Python::with_gil(|py| {
let provider = self.schema_provider.bind(py);
let _ = provider
.call_method1("register_table", (name, py_table))
.map_err(to_datafusion_err)?;
// Since the definition of `register_table` says that an error
// will be returned if the table already exists, there is no
// case where we want to return a table provider as output.
Ok(None)
})
}
fn deregister_table(
&self,
name: &str,
) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
Python::with_gil(|py| {
let provider = self.schema_provider.bind(py);
let table = provider
.call_method1("deregister_table", (name,))
.map_err(to_datafusion_err)?;
if table.is_none() {
return Ok(None);
}
// If we can turn this table provider into a `Dataset`, return it.
// Otherwise, return None.
let dataset = match Dataset::new(&table, py) {
Ok(dataset) => Some(Arc::new(dataset) as Arc<dyn TableProvider>),
Err(_) => None,
};
Ok(dataset)
})
}
fn table_exist(&self, name: &str) -> bool {
Python::with_gil(|py| {
let provider = self.schema_provider.bind(py);
provider
.call_method1("table_exist", (name,))
.and_then(|pyobj| pyobj.extract())
.unwrap_or(false)
})
}
}
/// Adapter that lets a Python object act as a DataFusion `CatalogProvider`.
#[derive(Debug)]
pub(crate) struct RustWrappedPyCatalogProvider {
// The Python object that catalog operations are forwarded to.
pub(crate) catalog_provider: PyObject,
}
impl RustWrappedPyCatalogProvider {
    /// Wraps `catalog_provider` without inspecting it.
    pub fn new(catalog_provider: PyObject) -> Self {
        Self { catalog_provider }
    }

    /// Calls the Python object's `schema(name)` method and converts the result
    /// into a `SchemaProvider`.
    ///
    /// Returns `Ok(None)` when Python returns `None`. Otherwise the returned
    /// object is interpreted, in order of precedence, as:
    /// 1. an object exposing `__datafusion_schema_provider__()` that returns a
    ///    `PyCapsule` named `datafusion_schema_provider`,
    /// 2. an object whose `schema` attribute is a `PySchema`,
    /// 3. a `PySchema` itself, or
    /// 4. an arbitrary Python object, wrapped in `RustWrappedPySchemaProvider`.
    ///
    /// Errors raised by Python or by capsule validation are propagated.
    fn schema_inner(&self, name: &str) -> PyResult<Option<Arc<dyn SchemaProvider>>> {
        Python::with_gil(|py| {
            let provider = self.catalog_provider.bind(py);

            let py_schema = provider.call_method1("schema", (name,))?;
            if py_schema.is_none() {
                return Ok(None);
            }

            if py_schema.hasattr("__datafusion_schema_provider__")? {
                // The capsule must come from the schema object that was just
                // checked for the attribute, not from the catalog object.
                let capsule = py_schema
                    .getattr("__datafusion_schema_provider__")?
                    .call0()?;
                let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
                validate_pycapsule(capsule, "datafusion_schema_provider")?;

                // SAFETY: relies on `validate_pycapsule` having accepted the capsule
                // under the name "datafusion_schema_provider"; the pointee is assumed
                // to be an `FFI_SchemaProvider` laid out by a compatible producer.
                let ffi_provider = unsafe { capsule.reference::<FFI_SchemaProvider>() };
                let foreign_provider: ForeignSchemaProvider = ffi_provider.into();
                return Ok(Some(Arc::new(foreign_provider) as Arc<dyn SchemaProvider>));
            }

            // A Python-side wrapper may hold the raw schema in a `schema` attribute.
            if let Ok(inner_schema) = py_schema.getattr("schema") {
                if let Ok(inner_schema) = inner_schema.extract::<PySchema>() {
                    return Ok(Some(inner_schema.schema));
                }
            }

            match py_schema.extract::<PySchema>() {
                Ok(inner_schema) => Ok(Some(inner_schema.schema)),
                Err(_) => {
                    let py_schema = RustWrappedPySchemaProvider::new(py_schema.into());
                    Ok(Some(Arc::new(py_schema) as Arc<dyn SchemaProvider>))
                }
            }
        })
    }
}
/// Forwards each `CatalogProvider` operation to the wrapped Python object,
/// acquiring the GIL for the duration of every call.
#[async_trait]
impl CatalogProvider for RustWrappedPyCatalogProvider {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Reads the Python object's `schema_names` attribute as a list of strings.
    ///
    /// Failures are logged and yield an empty list.
    fn schema_names(&self) -> Vec<String> {
        Python::with_gil(|py| {
            let extracted = self
                .catalog_provider
                .bind(py)
                .getattr("schema_names")
                .and_then(|attr| attr.extract::<Vec<String>>());

            match extracted {
                Ok(names) => names,
                Err(err) => {
                    log::error!("Unable to get schema_names: {err}");
                    Vec::default()
                }
            }
        })
    }

    /// Looks up a schema by name; lookup errors are logged and reported as `None`.
    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
        match self.schema_inner(name) {
            Ok(found) => found,
            Err(err) => {
                log::error!("CatalogProvider schema returned error: {err}");
                None
            }
        }
    }

    /// Calls the Python object's `register_schema(name, schema)` method.
    ///
    /// A schema that already wraps a Python object is passed back as that
    /// object; any other provider is handed over wrapped in a `PySchema`.
    /// A non-`None` return value from Python is wrapped and returned.
    fn register_schema(
        &self,
        name: &str,
        schema: Arc<dyn SchemaProvider>,
    ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
        Python::with_gil(|py| {
            let schema_arg = match schema
                .as_any()
                .downcast_ref::<RustWrappedPySchemaProvider>()
            {
                Some(wrapped) => wrapped.schema_provider.clone_ref(py),
                None => PySchema::from(schema)
                    .into_py_any(py)
                    .map_err(to_datafusion_err)?,
            };

            let returned = self
                .catalog_provider
                .bind(py)
                .call_method1("register_schema", (name, schema_arg))
                .map_err(to_datafusion_err)?;

            if returned.is_none() {
                Ok(None)
            } else {
                let wrapped = RustWrappedPySchemaProvider::new(returned.into());
                Ok(Some(Arc::new(wrapped) as Arc<dyn SchemaProvider>))
            }
        })
    }

    /// Calls the Python object's `deregister_schema(name, cascade)` method.
    ///
    /// A non-`None` return value from Python is wrapped and returned.
    fn deregister_schema(
        &self,
        name: &str,
        cascade: bool,
    ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
        Python::with_gil(|py| {
            let returned = self
                .catalog_provider
                .bind(py)
                .call_method1("deregister_schema", (name, cascade))
                .map_err(to_datafusion_err)?;

            if returned.is_none() {
                Ok(None)
            } else {
                let wrapped = RustWrappedPySchemaProvider::new(returned.into());
                Ok(Some(Arc::new(wrapped) as Arc<dyn SchemaProvider>))
            }
        })
    }
}
/// Adds the catalog, schema and table classes to the Python module `m`.
///
/// Stops at, and returns, the first registration error.
pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyCatalog>()?;
    m.add_class::<PySchema>()?;
    // The final registration's result doubles as the function's result.
    m.add_class::<PyTable>()
}