fix: column indices in FFI partition evaluator (#16480) (#16657)
* Column indices were not computed correctly, causing a panic
* Add unit tests
Co-authored-by: Tim Saucer <timsaucer@gmail.com>
diff --git a/datafusion/ffi/src/udwf/mod.rs b/datafusion/ffi/src/udwf/mod.rs
index aaa3f5c..504bf7a 100644
--- a/datafusion/ffi/src/udwf/mod.rs
+++ b/datafusion/ffi/src/udwf/mod.rs
@@ -363,4 +363,70 @@
}
#[cfg(test)]
-mod tests {}
+#[cfg(feature = "integration-tests")]
+mod tests {
+ use crate::tests::create_record_batch;
+ use crate::udwf::{FFI_WindowUDF, ForeignWindowUDF};
+ use arrow::array::{create_array, ArrayRef};
+ use datafusion::functions_window::lead_lag::{lag_udwf, WindowShift};
+ use datafusion::logical_expr::expr::Sort;
+ use datafusion::logical_expr::{col, ExprFunctionExt, WindowUDF, WindowUDFImpl};
+ use datafusion::prelude::SessionContext;
+ use std::sync::Arc;
+
+ /// Builds a foreign-backed [`WindowUDF`] for tests by exporting the given
+ /// implementation as an [`FFI_WindowUDF`] and importing it back through
+ /// [`ForeignWindowUDF`]; any conversion failure is returned to the caller.
+ fn create_test_foreign_udwf(
+     original_udwf: impl WindowUDFImpl + 'static,
+ ) -> datafusion::common::Result<WindowUDF> {
+     let exported: FFI_WindowUDF = Arc::new(WindowUDF::from(original_udwf)).into();
+     let imported: ForeignWindowUDF = (&exported).try_into()?;
+     Ok(imported.into())
+ }
+
+ /// Checks that the built-in `lag` window UDF keeps its name after being
+ /// exported to an [`FFI_WindowUDF`] and imported back as a [`WindowUDF`].
+ #[test]
+ fn test_round_trip_udwf() -> datafusion::common::Result<()> {
+     let native = lag_udwf();
+     // Capture the name up front; `native` is moved into the FFI value below.
+     let expected_name = native.name().to_owned();
+     // Native -> FFI representation.
+     let exported: FFI_WindowUDF = native.into();
+     // FFI -> foreign wrapper -> native `WindowUDF`.
+     let imported: ForeignWindowUDF = (&exported).try_into()?;
+     let round_tripped: WindowUDF = imported.into();
+     assert_eq!(expected_name, round_tripped.name());
+     Ok(())
+ }
+
+ /// Runs a foreign (FFI round-tripped) `lag` UDWF in a query over column `a`,
+ /// ordered by `a`, and checks the values of the resulting `lag_a` column.
+ #[tokio::test]
+ async fn test_lag_udwf() -> datafusion::common::Result<()> {
+     let udwf = create_test_foreign_udwf(WindowShift::lag())?;
+
+     let ctx = SessionContext::default();
+     let df = ctx.read_batch(create_record_batch(-5, 5))?;
+
+     // Propagate a window-expression build failure with `?` so it surfaces
+     // as this test's `Err` result instead of a panic.
+     let lag_a = udwf
+         .call(vec![col("a")])
+         .order_by(vec![Sort::new(col("a"), true, true)])
+         .build()?
+         .alias("lag_a");
+     let df = df.select(vec![col("a"), lag_a])?;
+
+     // Collect the plan's output; the assertions below expect one batch.
+     let result = df.collect().await?;
+     let expected =
+         create_array!(Int32, [None, Some(-5), Some(-4), Some(-3), Some(-2)])
+             as ArrayRef;
+
+     assert_eq!(result.len(), 1);
+     assert_eq!(result[0].column(1), &expected);
+     Ok(())
+ }
+}
diff --git a/datafusion/ffi/src/udwf/partition_evaluator_args.rs b/datafusion/ffi/src/udwf/partition_evaluator_args.rs
index e74d47a..dffeb23 100644
--- a/datafusion/ffi/src/udwf/partition_evaluator_args.rs
+++ b/datafusion/ffi/src/udwf/partition_evaluator_args.rs
@@ -75,17 +75,24 @@
})
.collect();
- let max_column = required_columns.keys().max().unwrap_or(&0).to_owned();
- let fields: Vec<_> = (0..max_column)
- .map(|idx| match required_columns.get(&idx) {
- Some((name, data_type)) => Field::new(*name, (*data_type).clone(), true),
- None => Field::new(
- format!("ffi_partition_evaluator_col_{idx}"),
- DataType::Null,
- true,
- ),
+ let max_column = required_columns.keys().max();
+ let fields: Vec<_> = max_column
+ .map(|max_column| {
+ (0..(max_column + 1))
+ .map(|idx| match required_columns.get(&idx) {
+ Some((name, data_type)) => {
+ Field::new(*name, (*data_type).clone(), true)
+ }
+ None => Field::new(
+ format!("ffi_partition_evaluator_col_{idx}"),
+ DataType::Null,
+ true,
+ ),
+ })
+ .collect()
})
- .collect();
+ .unwrap_or_default();
+
let schema = Arc::new(Schema::new(fields));
let codec = DefaultPhysicalExtensionCodec {};