Add ability to execute ExecutionPlan and get a stream of RecordBatch (#186)
diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py
index dcab86a..1894688 100644
--- a/datafusion/tests/test_dataframe.py
+++ b/datafusion/tests/test_dataframe.py
@@ -388,6 +388,15 @@
assert "RepartitionExec:" in indent
assert "CsvExec:" in indent
+ ctx = SessionContext()
+ stream = ctx.execute(plan, 0)
+ # get the one and only batch
+ batch = stream.next()
+ assert batch is not None
+ # there should be no more batches
+ batch = stream.next()
+ assert batch is None
+
def test_repartition(df):
df.repartition(2)
diff --git a/src/context.rs b/src/context.rs
index 8dcd1d6..1acf5f2 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -28,7 +28,9 @@
use crate::catalog::{PyCatalog, PyTable};
use crate::dataframe::PyDataFrame;
use crate::dataset::Dataset;
-use crate::errors::DataFusionError;
+use crate::errors::{py_datafusion_err, DataFusionError};
+use crate::physical_plan::PyExecutionPlan;
+use crate::record_batch::PyRecordBatchStream;
use crate::sql::logical::PyLogicalPlan;
use crate::store::StorageContexts;
use crate::udaf::PyAggregateUDF;
@@ -39,14 +41,17 @@
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::datasource::TableProvider;
use datafusion::datasource::MemTable;
-use datafusion::execution::context::{SessionConfig, SessionContext};
+use datafusion::execution::context::{SessionConfig, SessionContext, TaskContext};
use datafusion::execution::disk_manager::DiskManagerConfig;
use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
+use datafusion::physical_plan::SendableRecordBatchStream;
use datafusion::prelude::{
AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
};
use datafusion_common::ScalarValue;
+use tokio::runtime::Runtime;
+use tokio::task::JoinHandle;
#[pyclass(name = "SessionConfig", module = "datafusion", subclass, unsendable)]
#[derive(Clone, Default)]
@@ -579,6 +584,30 @@
Err(err) => Ok(format!("Error: {:?}", err.to_string())),
}
}
+
+    /// Execute a single partition of a physical plan and return its output as a
+    /// stream of record batches.
+    ///
+    /// `plan` is the physical plan to run and `part` is the index of the
+    /// partition to execute.
+    ///
+    /// # Errors
+    ///
+    /// Returns a Python exception if the Tokio runtime cannot be created, if
+    /// the spawned task cannot be joined, or if the plan reports an error when
+    /// asked to execute the partition.
+    pub fn execute(
+        &self,
+        plan: PyExecutionPlan,
+        part: usize,
+        py: Python,
+    ) -> PyResult<PyRecordBatchStream> {
+        // NOTE(review): the task context is built from fixed identifiers and a
+        // default `RuntimeEnv`; nothing here is taken from this session context.
+        let ctx = Arc::new(TaskContext::new(
+            "task_id".to_string(),
+            "session_id".to_string(),
+            HashMap::new(),
+            HashMap::new(),
+            HashMap::new(),
+            Arc::new(RuntimeEnv::default()),
+        ));
+        // Create a Tokio runtime to run the async code. Creating the runtime can
+        // fail, so report that to Python as an error instead of panicking.
+        let rt = Runtime::new().map_err(|e| py_datafusion_err(e))?;
+        let plan = plan.plan.clone();
+        let fut: JoinHandle<datafusion_common::Result<SendableRecordBatchStream>> =
+            rt.spawn(async move { plan.execute(part, ctx) });
+        // The outer error is the join failure of the spawned task; the inner
+        // result (unwrapped below) is the plan's own execution result.
+        let stream = wait_for_future(py, fut).map_err(|e| py_datafusion_err(e))?;
+        Ok(PyRecordBatchStream::new(stream?))
+    }
}
impl PySessionContext {
diff --git a/src/lib.rs b/src/lib.rs
index f6d404e..d9898db 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -37,6 +37,7 @@
mod functions;
pub mod physical_plan;
mod pyarrow_filter_expression;
+mod record_batch;
pub mod sql;
pub mod store;
pub mod substrait;
diff --git a/src/record_batch.rs b/src/record_batch.rs
new file mode 100644
index 0000000..15b70e8
--- /dev/null
+++ b/src/record_batch.rs
@@ -0,0 +1,64 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::utils::wait_for_future;
+use datafusion::arrow::pyarrow::PyArrowConvert;
+use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::physical_plan::SendableRecordBatchStream;
+use futures::StreamExt;
+use pyo3::{pyclass, pymethods, PyObject, PyResult, Python};
+
+// Python-facing wrapper around an Arrow `RecordBatch`.
+#[pyclass(name = "RecordBatch", module = "datafusion", subclass)]
+pub struct PyRecordBatch {
+    /// The wrapped Arrow record batch.
+    batch: RecordBatch,
+}
+
+#[pymethods]
+impl PyRecordBatch {
+    // Delegates to `PyArrowConvert::to_pyarrow` on the wrapped batch and
+    // returns the resulting Python object.
+    fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
+        PyArrowConvert::to_pyarrow(&self.batch, py)
+    }
+}
+
+impl From<RecordBatch> for PyRecordBatch {
+    /// Wrap an Arrow record batch in its Python-facing type.
+    fn from(batch: RecordBatch) -> Self {
+        PyRecordBatch { batch }
+    }
+}
+
+// Python-facing handle to a stream of record batches.
+#[pyclass(name = "RecordBatchStream", module = "datafusion", subclass)]
+pub struct PyRecordBatchStream {
+    /// The underlying stream that batches are pulled from.
+    stream: SendableRecordBatchStream,
+}
+
+impl PyRecordBatchStream {
+    /// Wrap an existing stream in its Python-facing type.
+    pub fn new(stream: SendableRecordBatchStream) -> Self {
+        PyRecordBatchStream { stream }
+    }
+}
+
+#[pymethods]
+impl PyRecordBatchStream {
+    // Wait (via `wait_for_future`) for the next item of the stream.
+    //
+    // Yields `None` when the stream produces no further item, the wrapped
+    // batch when one is available, and an error when the stream reports one.
+    fn next(&mut self, py: Python) -> PyResult<Option<PyRecordBatch>> {
+        let pending = self.stream.next();
+        // Flip `Option<Result<..>>` into `Result<Option<..>>` so the end-of-stream
+        // and success cases share one arm.
+        match wait_for_future(py, pending).transpose() {
+            Ok(maybe_batch) => Ok(maybe_batch.map(PyRecordBatch::from)),
+            Err(err) => Err(err.into()),
+        }
+    }
+}