// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::sync::Arc;
use object_store::ObjectStore;
use uuid::Uuid;
use pyo3::exceptions::{PyKeyError, PyValueError};
use pyo3::prelude::*;
use parking_lot::RwLock;
use crate::catalog::{PyCatalog, PyTable};
use crate::dataframe::PyDataFrame;
use crate::dataset::Dataset;
use crate::errors::DataFusionError;
use crate::store::StorageContexts;
use crate::udaf::PyAggregateUDF;
use crate::udf::PyScalarUDF;
use crate::utils::wait_for_future;
use datafusion::arrow::datatypes::{DataType, Schema};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::config::ConfigOptions;
use datafusion::datasource::datasource::TableProvider;
use datafusion::datasource::MemTable;
use datafusion::execution::context::{SessionConfig, SessionContext};
use datafusion::prelude::{AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions};
/// `PySessionContext` is able to plan and execute DataFusion plans.
/// It has a powerful optimizer, a physical planner for local execution, and a
/// multi-threaded execution engine.
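///
/// A minimal usage sketch from Python (illustrative only; the class is exposed
/// as `datafusion.SessionContext`):
///
/// ```text
/// from datafusion import SessionContext
///
/// ctx = SessionContext()
/// df = ctx.sql("SELECT 1 AS one")
/// ```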
#[pyclass(name = "SessionContext", module = "datafusion", subclass, unsendable)]
pub(crate) struct PySessionContext {
ctx: SessionContext,
}
#[pymethods]
impl PySessionContext {
#[allow(clippy::too_many_arguments)]
#[args(
default_catalog = "\"datafusion\"",
default_schema = "\"public\"",
create_default_catalog_and_schema = "true",
information_schema = "false",
repartition_joins = "true",
repartition_aggregations = "true",
repartition_windows = "true",
parquet_pruning = "true",
target_partitions = "None",
config_options = "None"
)]
#[new]
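/// Create a new session context. Values in `config_options` are parsed as
/// booleans or unsigned integers where possible and stored as strings otherwise.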
fn new(
default_catalog: &str,
default_schema: &str,
create_default_catalog_and_schema: bool,
information_schema: bool,
repartition_joins: bool,
repartition_aggregations: bool,
repartition_windows: bool,
parquet_pruning: bool,
target_partitions: Option<usize>,
config_options: Option<HashMap<String, String>>,
) -> Self {
let mut options = ConfigOptions::from_env();
if let Some(hash_map) = config_options {
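// Values arrive from Python as strings; interpret each as a bool, then a u64,
// falling back to a plain string.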
for (k, v) in &hash_map {
if let Ok(v) = v.parse::<bool>() {
options.set_bool(k, v);
} else if let Ok(v) = v.parse::<u64>() {
options.set_u64(k, v);
} else {
options.set_string(k, v);
}
}
}
let config_options = Arc::new(RwLock::new(options));
let mut cfg = SessionConfig::new()
.create_default_catalog_and_schema(create_default_catalog_and_schema)
.with_default_catalog_and_schema(default_catalog, default_schema)
.with_information_schema(information_schema)
.with_repartition_joins(repartition_joins)
.with_repartition_aggregations(repartition_aggregations)
.with_repartition_windows(repartition_windows)
.with_parquet_pruning(parquet_pruning);
// TODO we should add a `with_config_options` to `SessionConfig`
cfg.config_options = config_options;
let cfg_full = match target_partitions {
None => cfg,
Some(x) => cfg.with_target_partitions(x),
};
PySessionContext {
ctx: SessionContext::with_config(cfg_full),
}
}
/// Register an object store with the given name
fn register_object_store(
&mut self,
scheme: &str,
store: &PyAny,
host: Option<&str>,
) -> PyResult<()> {
let res: Result<(Arc<dyn ObjectStore>, String), PyErr> =
match StorageContexts::extract(store) {
Ok(store) => match store {
StorageContexts::AmazonS3(s3) => Ok((s3.inner, s3.bucket_name)),
StorageContexts::GoogleCloudStorage(gcs) => Ok((gcs.inner, gcs.bucket_name)),
StorageContexts::MicrosoftAzure(azure) => {
Ok((azure.inner, azure.container_name))
}
StorageContexts::LocalFileSystem(local) => Ok((local.inner, "".to_string())),
},
Err(_e) => Err(PyValueError::new_err("Invalid object store")),
};
// for most stores the "host" is the bucket name and can be inferred from the store
let (store, upstream_host) = res?;
// let users override the host to match the API signature from upstream
let derived_host = if let Some(host) = host {
host
} else {
&upstream_host
};
self.ctx
.runtime_env()
.register_object_store(scheme, derived_host, store);
Ok(())
}
/// Returns a PyDataFrame whose plan corresponds to the SQL statement.
fn sql(&mut self, query: &str, py: Python) -> PyResult<PyDataFrame> {
let result = self.ctx.sql(query);
let df = wait_for_future(py, result).map_err(DataFusionError::from)?;
Ok(PyDataFrame::new(df))
}
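/// Create a DataFrame from a collection of Arrow `RecordBatch` partitions by
/// registering them as an in-memory table under a randomly generated name.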
fn create_dataframe(
&mut self,
partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
) -> PyResult<PyDataFrame> {
let schema = partitions.0[0][0].schema();
let table = MemTable::try_new(schema, partitions.0).map_err(DataFusionError::from)?;
// generate a random (unique) name for this table;
// the table name cannot start with a numeric digit
let name = "c".to_owned()
+ Uuid::new_v4()
.to_simple()
.encode_lower(&mut Uuid::encode_buffer());
self.ctx
.register_table(&*name, Arc::new(table))
.map_err(DataFusionError::from)?;
let table = self.ctx.table(&*name).map_err(DataFusionError::from)?;
let df = PyDataFrame::new(table);
Ok(df)
}
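/// Register a table so it can be referenced from SQL statements executed against this context.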
fn register_table(&mut self, name: &str, table: &PyTable) -> PyResult<()> {
self.ctx
.register_table(name, table.table())
.map_err(DataFusionError::from)?;
Ok(())
}
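/// Remove the table with the given name from this context.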
fn deregister_table(&mut self, name: &str) -> PyResult<()> {
self.ctx
.deregister_table(name)
.map_err(DataFusionError::from)?;
Ok(())
}
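/// Register a collection of Arrow `RecordBatch` partitions as an in-memory table with the given name.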
fn register_record_batches(
&mut self,
name: &str,
partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
) -> PyResult<()> {
let schema = partitions.0[0][0].schema();
let table = MemTable::try_new(schema, partitions.0)?;
self.ctx
.register_table(name, Arc::new(table))
.map_err(DataFusionError::from)?;
Ok(())
}
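/// Register a Parquet file (or a directory of Parquet files) as the table `name`.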
#[allow(clippy::too_many_arguments)]
#[args(
table_partition_cols = "vec![]",
parquet_pruning = "true",
file_extension = "\".parquet\""
)]
fn register_parquet(
&mut self,
name: &str,
path: &str,
table_partition_cols: Vec<(String, String)>,
parquet_pruning: bool,
file_extension: &str,
py: Python,
) -> PyResult<()> {
let mut options = ParquetReadOptions::default()
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
.parquet_pruning(parquet_pruning);
options.file_extension = file_extension;
let result = self.ctx.register_parquet(name, path, options);
wait_for_future(py, result).map_err(DataFusionError::from)?;
Ok(())
}
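/// Register a CSV file as the table `name`, optionally providing an explicit schema.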
#[allow(clippy::too_many_arguments)]
#[args(
schema = "None",
has_header = "true",
delimiter = "\",\"",
schema_infer_max_records = "1000",
file_extension = "\".csv\""
)]
fn register_csv(
&mut self,
name: &str,
path: PathBuf,
schema: Option<PyArrowType<Schema>>,
has_header: bool,
delimiter: &str,
schema_infer_max_records: usize,
file_extension: &str,
py: Python,
) -> PyResult<()> {
let path = path
.to_str()
.ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
let delimiter = delimiter.as_bytes();
if delimiter.len() != 1 {
return Err(PyValueError::new_err(
"Delimiter must be a single character",
));
}
let mut options = CsvReadOptions::new()
.has_header(has_header)
.delimiter(delimiter[0])
.schema_infer_max_records(schema_infer_max_records)
.file_extension(file_extension);
options.schema = schema.as_ref().map(|x| &x.0);
let result = self.ctx.register_csv(name, path, options);
wait_for_future(py, result).map_err(DataFusionError::from)?;
Ok(())
}
/// Register a PyArrow Dataset as the table `name`.
fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) -> PyResult<()> {
let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);
self.ctx
.register_table(name, table)
.map_err(DataFusionError::from)?;
Ok(())
}
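/// Register a user-defined scalar function (UDF) for use in SQL and DataFrame expressions.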
fn register_udf(&mut self, udf: PyScalarUDF) -> PyResult<()> {
self.ctx.register_udf(udf.function);
Ok(())
}
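/// Register a user-defined aggregate function (UDAF).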
fn register_udaf(&mut self, udaf: PyAggregateUDF) -> PyResult<()> {
self.ctx.register_udaf(udaf.function);
Ok(())
}
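/// Return the catalog with the given name, raising `KeyError` if it does not exist.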
#[args(name = "\"datafusion\"")]
fn catalog(&self, name: &str) -> PyResult<PyCatalog> {
match self.ctx.catalog(name) {
Some(catalog) => Ok(PyCatalog::new(catalog)),
None => Err(PyKeyError::new_err(format!(
"Catalog with name {} doesn't exist.",
&name
))),
}
}
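/// Return the names of the tables registered with this context.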
fn tables(&self) -> HashSet<String> {
#[allow(deprecated)]
self.ctx.tables().unwrap()
}
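/// Return a DataFrame that scans the table with the given name.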
fn table(&self, name: &str) -> PyResult<PyDataFrame> {
Ok(PyDataFrame::new(self.ctx.table(name)?))
}
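/// Return `true` if a table with the given name is registered.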
fn table_exist(&self, name: &str) -> PyResult<bool> {
Ok(self.ctx.table_exist(name)?)
}
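/// Return an empty DataFrame.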
fn empty_table(&self) -> PyResult<PyDataFrame> {
Ok(PyDataFrame::new(self.ctx.read_empty()?))
}
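/// Return the unique identifier of this session.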
fn session_id(&self) -> PyResult<String> {
Ok(self.ctx.session_id())
}
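/// Read a newline-delimited JSON file into a DataFrame.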
#[allow(clippy::too_many_arguments)]
#[args(
schema = "None",
schema_infer_max_records = "1000",
file_extension = "\".json\"",
table_partition_cols = "vec![]"
)]
fn read_json(
&mut self,
path: PathBuf,
schema: Option<PyArrowType<Schema>>,
schema_infer_max_records: usize,
file_extension: &str,
table_partition_cols: Vec<(String, String)>,
py: Python,
) -> PyResult<PyDataFrame> {
let path = path
.to_str()
.ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
let mut options = NdJsonReadOptions::default()
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
options.schema = schema.map(|s| Arc::new(s.0));
options.schema_infer_max_records = schema_infer_max_records;
options.file_extension = file_extension;
let result = self.ctx.read_json(path, options);
let df = wait_for_future(py, result).map_err(DataFusionError::from)?;
Ok(PyDataFrame::new(df))
}
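/// Read a CSV file into a DataFrame.
///
/// A usage sketch from Python (illustrative only; assumes a local file
/// `example.csv` exists):
///
/// ```text
/// df = ctx.read_csv("example.csv", has_header=True, delimiter=",")
/// ```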
#[allow(clippy::too_many_arguments)]
#[args(
schema = "None",
has_header = "true",
delimiter = "\",\"",
schema_infer_max_records = "1000",
file_extension = "\".csv\"",
table_partition_cols = "vec![]"
)]
fn read_csv(
&self,
path: PathBuf,
schema: Option<PyArrowType<Schema>>,
has_header: bool,
delimiter: &str,
schema_infer_max_records: usize,
file_extension: &str,
table_partition_cols: Vec<(String, String)>,
py: Python,
) -> PyResult<PyDataFrame> {
let path = path
.to_str()
.ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
let delimiter = delimiter.as_bytes();
if delimiter.len() != 1 {
return Err(PyValueError::new_err(
"Delimiter must be a single character",
));
}
let mut options = CsvReadOptions::new()
.has_header(has_header)
.delimiter(delimiter[0])
.schema_infer_max_records(schema_infer_max_records)
.file_extension(file_extension)
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
options.schema = schema.as_ref().map(|x| &x.0);
let result = self.ctx.read_csv(path, options);
let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
Ok(df)
}
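/// Read a Parquet file (or a directory of Parquet files) into a DataFrame.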
#[allow(clippy::too_many_arguments)]
#[args(
parquet_pruning = "true",
file_extension = "\".parquet\"",
table_partition_cols = "vec![]",
skip_metadata = "true"
)]
fn read_parquet(
&self,
path: &str,
table_partition_cols: Vec<(String, String)>,
parquet_pruning: bool,
file_extension: &str,
skip_metadata: bool,
py: Python,
) -> PyResult<PyDataFrame> {
let mut options = ParquetReadOptions::default()
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
.parquet_pruning(parquet_pruning)
.skip_metadata(skip_metadata);
options.file_extension = file_extension;
let result = self.ctx.read_parquet(path, options);
let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
Ok(df)
}
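/// Read an Avro file into a DataFrame.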
#[allow(clippy::too_many_arguments)]
#[args(
schema = "None",
file_extension = "\".avro\"",
table_partition_cols = "vec![]"
)]
fn read_avro(
&self,
path: &str,
schema: Option<PyArrowType<Schema>>,
table_partition_cols: Vec<(String, String)>,
file_extension: &str,
py: Python,
) -> PyResult<PyDataFrame> {
let mut options = AvroReadOptions::default()
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
options.file_extension = file_extension;
options.schema = schema.map(|s| Arc::new(s.0));
let result = self.ctx.read_avro(path, options);
let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
Ok(df)
}
}
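/// Convert partition columns supplied from Python as `(name, type)` string pairs
/// into `(name, DataType)` pairs. Only the "string" type is currently supported.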
fn convert_table_partition_cols(
table_partition_cols: Vec<(String, String)>,
) -> Result<Vec<(String, DataType)>, DataFusionError> {
table_partition_cols
.into_iter()
.map(|(name, ty)| match ty.as_str() {
"string" => Ok((name, DataType::Utf8)),
_ => Err(DataFusionError::Common(format!(
"Unsupported data type '{}' for partition column",
ty
))),
})
.collect::<Result<Vec<_>, _>>()
}