allow config options to be passed to context (#76)
* allow config options to be passed to context
* fmt
* fix?
* remove println
* Update src/context.rs
Co-authored-by: Batuhan Taskaya <isidentical@gmail.com>
* address feedback
Co-authored-by: Batuhan Taskaya <isidentical@gmail.com>
diff --git a/Cargo.lock b/Cargo.lock
index db3a43e..75059dc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -664,6 +664,7 @@
"futures",
"mimalloc",
"object_store",
+ "parking_lot",
"pyo3",
"rand 0.7.3",
"tokio",
diff --git a/Cargo.toml b/Cargo.toml
index 0e8db9c..3f3eeb4 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -43,6 +43,7 @@
async-trait = "0.1"
futures = "0.3"
object_store = { version = "0.5.1", features = ["aws", "gcp", "azure"] }
+parking_lot = "0.12"
[lib]
name = "datafusion_python"
diff --git a/datafusion/tests/test_context.py b/datafusion/tests/test_context.py
index 50bdf43..55849ed 100644
--- a/datafusion/tests/test_context.py
+++ b/datafusion/tests/test_context.py
@@ -39,6 +39,7 @@
repartition_aggregations=False,
repartition_windows=False,
parquet_pruning=False,
+ config_options=None,
)
# verify that at least some of the arguments worked
diff --git a/src/context.rs b/src/context.rs
index 93ee1c7..9f6ef30 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use std::collections::HashSet;
+use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::sync::Arc;
@@ -25,13 +25,7 @@
use pyo3::exceptions::{PyKeyError, PyValueError};
use pyo3::prelude::*;
-use datafusion::arrow::datatypes::Schema;
-use datafusion::arrow::pyarrow::PyArrowType;
-use datafusion::arrow::record_batch::RecordBatch;
-use datafusion::datasource::datasource::TableProvider;
-use datafusion::datasource::MemTable;
-use datafusion::execution::context::{SessionConfig, SessionContext};
-use datafusion::prelude::{AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions};
+use parking_lot::RwLock;
use crate::catalog::{PyCatalog, PyTable};
use crate::dataframe::PyDataFrame;
@@ -41,6 +35,14 @@
use crate::udaf::PyAggregateUDF;
use crate::udf::PyScalarUDF;
use crate::utils::wait_for_future;
+use datafusion::arrow::datatypes::Schema;
+use datafusion::arrow::pyarrow::PyArrowType;
+use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::config::ConfigOptions;
+use datafusion::datasource::datasource::TableProvider;
+use datafusion::datasource::MemTable;
+use datafusion::execution::context::{SessionConfig, SessionContext};
+use datafusion::prelude::{AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions};
/// `PySessionContext` is able to plan and execute DataFusion plans.
/// It has a powerful optimizer, a physical planner for local execution, and a
@@ -62,7 +64,8 @@
repartition_aggregations = "true",
repartition_windows = "true",
parquet_pruning = "true",
- target_partitions = "None"
+ target_partitions = "None",
+ config_options = "None"
)]
#[new]
fn new(
@@ -75,9 +78,23 @@
repartition_windows: bool,
parquet_pruning: bool,
target_partitions: Option<usize>,
- // TODO: config_options
+ config_options: Option<HashMap<String, String>>,
) -> Self {
- let cfg = SessionConfig::new()
+ let mut options = ConfigOptions::from_env();
+ if let Some(hash_map) = config_options {
+ for (k, v) in &hash_map {
+ if let Ok(v) = v.parse::<bool>() {
+ options.set_bool(k, v);
+ } else if let Ok(v) = v.parse::<u64>() {
+ options.set_u64(k, v);
+ } else {
+ options.set_string(k, v);
+ }
+ }
+ }
+ let config_options = Arc::new(RwLock::new(options));
+
+ let mut cfg = SessionConfig::new()
.create_default_catalog_and_schema(create_default_catalog_and_schema)
.with_default_catalog_and_schema(default_catalog, default_schema)
.with_information_schema(information_schema)
@@ -86,6 +103,9 @@
.with_repartition_windows(repartition_windows)
.with_parquet_pruning(parquet_pruning);
+ // TODO we should add a `with_config_options` to `SessionConfig`
+ cfg.config_options = config_options;
+
let cfg_full = match target_partitions {
None => cfg,
Some(x) => cfg.with_target_partitions(x),