feat(dataframe): add executeStream(allocator) for incremental batch iteration (#51)
diff --git a/core/src/main/java/org/apache/datafusion/DataFrame.java b/core/src/main/java/org/apache/datafusion/DataFrame.java
index c85692c..c19ce39 100644
--- a/core/src/main/java/org/apache/datafusion/DataFrame.java
+++ b/core/src/main/java/org/apache/datafusion/DataFrame.java
@@ -26,12 +26,14 @@
/**
* A lazy representation of a query plan, mirroring the Rust DataFusion {@code DataFrame}. Created
- * by {@link SessionContext#sql(String)} or other planning entry points and executed by {@link
- * #collect}.
+ * by {@link SessionContext#sql(String)} or other planning entry points and executed by either
+ * {@link #collect} (materializes every batch on the native heap before returning) or {@link
+ * #executeStream} (yields one batch at a time as Java drains the reader).
*
- * <p>Instances are <strong>not thread-safe</strong> and must be closed. {@link #collect} consumes
- * the DataFrame: a successfully collected DataFrame cannot be collected again, and {@link #close()}
- * on an already-collected instance is a no-op.
+ * <p>Instances are <strong>not thread-safe</strong> and must be closed. Both {@link #collect} and
+ * {@link #executeStream} consume the DataFrame: a successfully consumed DataFrame cannot be
+ * consumed again by either method (or by other executors such as {@link #count}), and {@link
+ * #close()} on an already-consumed instance is a no-op.
*/
public final class DataFrame implements AutoCloseable {
static {
@@ -53,6 +55,10 @@
* <p>Consumes this DataFrame: the native plan is released as soon as the stream is established.
* The caller is responsible for closing the returned reader, and the supplied allocator must
* outlive it.
+ *
+ * <p>This method materializes every batch on the native heap before the first batch crosses the
+ * FFI boundary, which can exhaust native memory for unbounded or very large result sets. Prefer
+ * {@link #executeStream(BufferAllocator)} when the result may not fit in native memory.
*/
public ArrowReader collect(BufferAllocator allocator) {
if (nativeHandle == 0) {
@@ -70,6 +76,36 @@
}
}
+  /**
+   * Run the plan and expose its output through a streaming {@link ArrowReader}. Batches are pulled
+   * lazily: every {@link ArrowReader#loadNextBatch} call drives one async {@code stream.next()} on
+   * the native side, so memory pressure is bounded by the executor pipeline plus one in-flight
+   * batch rather than by the full result set.
+   *
+   * <p>Like {@link #collect(BufferAllocator)}, this method consumes the DataFrame: the native plan
+   * is released as soon as the stream is established, so this instance cannot be executed again
+   * afterwards. The caller must close the returned reader, and {@code allocator} must outlive it.
+   *
+   * <p>Use this method for TB-scale or unbounded result sets. {@link #collect(BufferAllocator)}
+   * remains a reasonable choice when the whole result fits comfortably in native memory and is
+   * read in its entirety.
+   *
+   * @param allocator allocator used for the FFI stream struct and for importing the stream
+   * @return a reader that yields the query result one record batch at a time
+   * @throws IllegalStateException if this DataFrame is closed or was already consumed
+   */
+  public ArrowReader executeStream(BufferAllocator allocator) {
+    final long planHandle = nativeHandle;
+    if (planHandle == 0) {
+      throw new IllegalStateException("DataFrame is closed or already collected");
+    }
+    final ArrowArrayStream ffiStream = ArrowArrayStream.allocateNew(allocator);
+    // Only give up the handle once the FFI struct exists; from here on the native call owns the
+    // plan, so this instance must never hand the same handle out again.
+    nativeHandle = 0;
+    try {
+      executeStreamDataFrame(planHandle, ffiStream.memoryAddress());
+      return Data.importArrayStream(allocator, ffiStream);
+    } catch (Throwable failure) {
+      // The reader never took over the struct, so release it here before propagating.
+      ffiStream.close();
+      throw failure;
+    }
+  }
+
+
/** Execute the plan and return the number of rows. */
public long count() {
if (nativeHandle == 0) {
@@ -292,6 +328,8 @@
private static native void collectDataFrame(long handle, long ffiStreamAddr);
+ /**
+ * Native entry point behind {@link #executeStream(BufferAllocator)}. Takes ownership of the
+ * native DataFrame at {@code handle} and writes an Arrow C stream that produces its batches into
+ * the {@code ArrowArrayStream} struct located at {@code ffiStreamAddr}.
+ */
+ private static native void executeStreamDataFrame(long handle, long ffiStreamAddr);
+
private static native void closeDataFrame(long handle);
private static native long countRows(long handle);
diff --git a/core/src/test/java/org/apache/datafusion/DataFrameExecuteStreamTest.java b/core/src/test/java/org/apache/datafusion/DataFrameExecuteStreamTest.java
new file mode 100644
index 0000000..8a74257
--- /dev/null
+++ b/core/src/test/java/org/apache/datafusion/DataFrameExecuteStreamTest.java
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datafusion;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.ArrowReader;
+import org.junit.jupiter.api.Assumptions;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+/** End-to-end tests for {@link DataFrame#executeStream(BufferAllocator)}. */
+class DataFrameExecuteStreamTest {
+
+  /**
+   * Write a CSV with a header {@code x} followed by the integers {@code 1..rows}, one per line.
+   * Tests that need more than one batch scan this file instead of an in-memory {@code VALUES}
+   * plan, because {@code VALUES} plans get coalesced into a single batch in some DataFusion
+   * versions and would make those tests brittle.
+   *
+   * @param dir directory in which {@code rows.csv} is created
+   * @param rows number of data rows to write
+   * @return path of the written file
+   */
+  private static Path writeRowsCsv(Path dir, int rows) throws IOException {
+    StringBuilder sb = new StringBuilder("x\n");
+    for (int i = 1; i <= rows; i++) {
+      sb.append(i).append('\n');
+    }
+    Path file = dir.resolve("rows.csv");
+    Files.writeString(file, sb.toString());
+    return file;
+  }
+
+  @Test
+  void executeStreamYieldsTheSameRowsAsCollect() throws Exception {
+    String sql = "SELECT * FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)";
+
+    long collected = 0;
+    try (BufferAllocator allocator = new RootAllocator();
+        SessionContext ctx = new SessionContext();
+        DataFrame df = ctx.sql(sql);
+        ArrowReader reader = df.collect(allocator)) {
+      while (reader.loadNextBatch()) {
+        collected += reader.getVectorSchemaRoot().getRowCount();
+      }
+    }
+
+    long streamed = 0;
+    try (BufferAllocator allocator = new RootAllocator();
+        SessionContext ctx = new SessionContext();
+        DataFrame df = ctx.sql(sql);
+        ArrowReader reader = df.executeStream(allocator)) {
+      while (reader.loadNextBatch()) {
+        streamed += reader.getVectorSchemaRoot().getRowCount();
+      }
+    }
+
+    assertEquals(5L, collected);
+    assertEquals(collected, streamed);
+  }
+
+  @Test
+  void executeStreamConsumesTheDataFrame() throws Exception {
+    try (BufferAllocator allocator = new RootAllocator();
+        SessionContext ctx = new SessionContext()) {
+      DataFrame df = ctx.sql("SELECT 1");
+      try (ArrowReader reader = df.executeStream(allocator)) {
+        assertTrue(reader.loadNextBatch());
+      }
+      // After a successful executeStream, the DataFrame's native handle is
+      // released. A second collect/executeStream/count must throw.
+      assertThrows(IllegalStateException.class, () -> df.executeStream(allocator));
+      assertThrows(IllegalStateException.class, () -> df.collect(allocator));
+      assertThrows(IllegalStateException.class, df::count);
+      // close() on an already-streamed DataFrame is a no-op (no double-free).
+      df.close();
+    }
+  }
+
+  @Test
+  void executeStreamReadsBatchByBatch(@TempDir Path tempDir) throws Exception {
+    // CSV with 5 rows scanned at batch_size=2 reliably yields multiple batches
+    // across DataFusion versions, where an in-memory VALUES plan can be
+    // coalesced into a single batch by the planner. The point of this test is
+    // to pin "executeStream actually streams" without coupling to planner
+    // batching behavior that may shift in upstream releases.
+    Path csv = writeRowsCsv(tempDir, 5);
+    try (BufferAllocator allocator = new RootAllocator();
+        SessionContext ctx = SessionContext.builder().batchSize(2).build()) {
+      ctx.registerCsv("rows", csv.toAbsolutePath().toString());
+      try (DataFrame df = ctx.sql("SELECT x FROM rows");
+          ArrowReader reader = df.executeStream(allocator)) {
+        int batches = 0;
+        long total = 0;
+        int maxBatchSize = 0;
+        while (reader.loadNextBatch()) {
+          batches++;
+          VectorSchemaRoot root = reader.getVectorSchemaRoot();
+          total += root.getRowCount();
+          maxBatchSize = Math.max(maxBatchSize, root.getRowCount());
+        }
+        assertEquals(5L, total);
+        assertTrue(batches >= 2, "expected multiple batches with batchSize=2, got " + batches);
+        assertTrue(maxBatchSize <= 2, "expected each batch <= 2 rows, got " + maxBatchSize);
+      }
+    }
+  }
+
+  @Test
+  void executeStreamSurvivesEarlyClose(@TempDir Path tempDir) throws Exception {
+    // Close the reader after the first batch while later batches are still
+    // pending, and confirm there is no native panic / resource leak. A CSV scan
+    // at batch_size=1 is used so that unread batches really exist: an in-memory
+    // VALUES plan can be coalesced into a single batch, in which case the
+    // reader would already be drained and the close would not be "early".
+    // The DataFrame is already consumed; closing it must remain a no-op.
+    Path csv = writeRowsCsv(tempDir, 3);
+    try (BufferAllocator allocator = new RootAllocator();
+        SessionContext ctx = SessionContext.builder().batchSize(1).build()) {
+      ctx.registerCsv("rows", csv.toAbsolutePath().toString());
+      try (DataFrame df = ctx.sql("SELECT x FROM rows");
+          ArrowReader reader = df.executeStream(allocator)) {
+        assertTrue(reader.loadNextBatch());
+        int firstBatchRows = reader.getVectorSchemaRoot().getRowCount();
+        // Guard the premise of the test: some rows must still be unread.
+        assertTrue(firstBatchRows < 3, "expected unread rows at close, got " + firstBatchRows);
+      }
+    }
+  }
+
+  @Test
+  void executeStreamOverParquetMatchesCollectRowCount() throws Exception {
+    Path lineitem = Path.of("tpch-data/sf1/lineitem.parquet");
+    Assumptions.assumeTrue(
+        Files.exists(lineitem), "TPC-H SF1 data not found; run `make tpch-data` first");
+
+    try (BufferAllocator allocator = new RootAllocator();
+        SessionContext ctx = new SessionContext()) {
+      ctx.registerParquet("lineitem", lineitem.toAbsolutePath().toString());
+
+      long collected;
+      try (DataFrame df = ctx.sql("SELECT COUNT(*) FROM lineitem");
+          ArrowReader reader = df.collect(allocator)) {
+        assertTrue(reader.loadNextBatch());
+        BigIntVector v = (BigIntVector) reader.getVectorSchemaRoot().getVector(0);
+        collected = v.get(0);
+      }
+
+      long streamed = 0;
+      try (DataFrame df = ctx.sql("SELECT l_orderkey FROM lineitem");
+          ArrowReader reader = df.executeStream(allocator)) {
+        while (reader.loadNextBatch()) {
+          streamed += reader.getVectorSchemaRoot().getRowCount();
+        }
+      }
+      assertEquals(collected, streamed);
+    }
+  }
+
+  @Test
+  void executeStreamColumnValuesAreCorrect() throws Exception {
+    // Pin actual cell values, not just row counts: a regression that
+    // shipped wrong values per batch must be caught.
+    try (BufferAllocator allocator = new RootAllocator();
+        SessionContext ctx = SessionContext.builder().batchSize(2).build();
+        DataFrame df = ctx.sql("SELECT * FROM (VALUES (10), (20), (30), (40)) AS t(x) ORDER BY x");
+        ArrowReader reader = df.executeStream(allocator)) {
+      java.util.List<Long> seen = new java.util.ArrayList<>();
+      while (reader.loadNextBatch()) {
+        BigIntVector v = (BigIntVector) reader.getVectorSchemaRoot().getVector(0);
+        for (int i = 0; i < v.getValueCount(); i++) {
+          seen.add(v.get(i));
+        }
+      }
+      assertEquals(java.util.List.of(10L, 20L, 30L, 40L), seen);
+    }
+  }
+}
diff --git a/native/Cargo.lock b/native/Cargo.lock
index 495cc60..bb9578f 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -1141,6 +1141,7 @@
"arrow",
"datafusion",
"datafusion-proto",
+ "futures",
"jni",
"prost",
"prost-build",
diff --git a/native/Cargo.toml b/native/Cargo.toml
index 01dd002..b9fca20 100644
--- a/native/Cargo.toml
+++ b/native/Cargo.toml
@@ -28,6 +28,7 @@
arrow = { version = "58", features = ["ffi"] }
datafusion = "53.1.0"
datafusion-proto = "53.1.0"
+futures = "0.3"
jni = "0.21"
prost = "0.14"
tokio = { version = "1", features = ["rt-multi-thread"] }
diff --git a/native/src/lib.rs b/native/src/lib.rs
index dba3c08..1472628 100644
--- a/native/src/lib.rs
+++ b/native/src/lib.rs
@@ -27,20 +27,25 @@
include!(concat!(env!("OUT_DIR"), "/datafusion_java.rs"));
}
+use std::panic::{catch_unwind, AssertUnwindSafe};
use std::path::PathBuf;
use std::sync::{Arc, OnceLock};
+use datafusion::arrow::array::RecordBatch;
use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::arrow::error::ArrowError;
use datafusion::arrow::ffi_stream::FFI_ArrowArrayStream;
-use datafusion::arrow::record_batch::RecordBatchIterator;
+use datafusion::arrow::record_batch::{RecordBatchIterator, RecordBatchReader};
use datafusion::common::UnnestOptions;
use datafusion::config::TableParquetOptions;
use datafusion::dataframe::DataFrame;
use datafusion::dataframe::DataFrameWriteOptions;
use datafusion::error::DataFusionError;
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
+use datafusion::execution::SendableRecordBatchStream;
use datafusion::logical_expr::{ScalarUDF, Signature};
use datafusion::prelude::{ParquetReadOptions, SessionConfig, SessionContext};
+use futures::StreamExt;
use jni::objects::{JByteArray, JClass, JObject, JObjectArray, JString};
use jni::sys::{jboolean, jint, jlong};
use jni::JNIEnv;
@@ -203,6 +208,80 @@
})
}
+/// Adapter that exposes DataFusion's async [`SendableRecordBatchStream`] through
+/// the synchronous [`RecordBatchReader`] interface, so that it can be wrapped in
+/// an `FFI_ArrowArrayStream` and consumed by the Java `ArrowReader`.
+///
+/// Batches are pulled on demand: each call to `next()` drives one
+/// `runtime().block_on(stream.next())`, so memory pressure stays bounded by the
+/// executor pipeline plus a single in-flight batch.
+struct StreamingReader {
+ /// Schema handed out by [`RecordBatchReader::schema`].
+ schema: SchemaRef,
+ /// Async source of record batches; polled once per `next()` call.
+ stream: SendableRecordBatchStream,
+}
+
+impl Iterator for StreamingReader {
+    type Item = Result<RecordBatch, ArrowError>;
+
+    /// Block on the underlying async stream until it yields its next item.
+    ///
+    /// Returns `None` when the stream reports completion, `Some(Ok(batch))`
+    /// for a batch, and `Some(Err(..))` for either a DataFusion error or a
+    /// panic raised while polling (both wrapped in `ArrowError::ExternalError`).
+    fn next(&mut self) -> Option<Self::Item> {
+        // This method is reached through FFI_ArrowArrayStream's C vtable, i.e.
+        // outside the JNI handler's try_unwrap_or_throw guard. A panic while
+        // polling (buggy UDF, panicking arrow cast, runtime poison) must not
+        // unwind across the C/FFI boundary, so it is trapped here and turned
+        // into an ordinary error value that the Java side sees as an exception.
+        let polled = catch_unwind(AssertUnwindSafe(|| {
+            runtime().block_on(self.stream.next())
+        }));
+
+        let payload = match polled {
+            Ok(None) => return None,
+            Ok(Some(Ok(batch))) => return Some(Ok(batch)),
+            Ok(Some(Err(df_err))) => {
+                return Some(Err(ArrowError::ExternalError(Box::new(df_err))));
+            }
+            Err(payload) => payload,
+        };
+
+        // Panic payloads are usually a `String` (formatted panic!) or a
+        // `&'static str` (literal panic!); anything else gets a generic text.
+        let detail = payload
+            .downcast_ref::<String>()
+            .cloned()
+            .or_else(|| payload.downcast_ref::<&str>().map(|s| (*s).to_string()))
+            .unwrap_or_else(|| "rust panic with non-string payload".to_string());
+
+        Some(Err(ArrowError::ExternalError(
+            format!("panic in DataFrame stream: {detail}").into(),
+        )))
+    }
+}
+
+impl RecordBatchReader for StreamingReader {
+    /// Return a shared handle to the schema this reader was constructed with.
+    fn schema(&self) -> SchemaRef {
+        // SchemaRef is reference-counted, so this only bumps the refcount.
+        Arc::clone(&self.schema)
+    }
+}
+
+/// JNI entry point for `DataFrame.executeStreamDataFrame(long, long)`.
+///
+/// Takes ownership of the boxed [`DataFrame`] behind `handle`, starts executing
+/// it as a stream, and writes an [`FFI_ArrowArrayStream`] wrapping that stream
+/// into the struct located at `ffi_stream_addr`. Batches are then produced
+/// lazily by [`StreamingReader`] as the consumer pulls them.
+///
+/// Failures (a zero `handle` or `ffi_stream_addr`, or an error while setting
+/// up execution) are returned to `try_unwrap_or_throw`, and nothing is written
+/// to `ffi_stream_addr` in that case.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_datafusion_DataFrame_executeStreamDataFrame<'local>(
+    mut env: JNIEnv<'local>,
+    _class: JClass<'local>,
+    handle: jlong,
+    ffi_stream_addr: jlong,
+) {
+    try_unwrap_or_throw(&mut env, (), |_env| -> JniResult<()> {
+        if handle == 0 {
+            return Err("DataFrame handle is null".into());
+        }
+        if ffi_stream_addr == 0 {
+            return Err("ffi stream address is null".into());
+        }
+        // SAFETY: `handle` is non-zero (checked above). The Java wrapper zeroes
+        // its copy of the handle before making this call, so the pointer is
+        // reclaimed at most once.
+        // NOTE(review): assumes the handle came from `Box::into_raw` on a
+        // `Box<DataFrame>` -- confirm against the JNI call that creates it.
+        let df = unsafe { *Box::from_raw(handle as *mut DataFrame) };
+
+        let stream = runtime().block_on(df.execute_stream())?;
+        // Advertise the schema of the stream that actually produces the
+        // batches instead of re-deriving one from the logical plan: the two
+        // can never disagree this way, and no deep copy of the schema is made.
+        let schema = stream.schema();
+        let ffi = FFI_ArrowArrayStream::new(Box::new(StreamingReader { schema, stream }));
+
+        // SAFETY: `ffi_stream_addr` is non-zero (checked above) and is the
+        // address of the ArrowArrayStream struct the Java caller allocated for
+        // this call. `ptr::write` moves `ffi` into it without reading or
+        // dropping whatever bytes were there before.
+        unsafe {
+            std::ptr::write(ffi_stream_addr as *mut FFI_ArrowArrayStream, ffi);
+        }
+        Ok(())
+    })
+}
+
+
#[no_mangle]
pub extern "system" fn Java_org_apache_datafusion_DataFrame_countRows<'local>(
mut env: JNIEnv<'local>,