fix(rust/sedona-geoparquet): Don't use ProjectionExec to create GeoParquet 1.1 bounding box columns (#398)

diff --git a/rust/sedona-geoparquet/src/writer.rs b/rust/sedona-geoparquet/src/writer.rs
index 2b2ec9a..8ea1a26 100644
--- a/rust/sedona-geoparquet/src/writer.rs
+++ b/rust/sedona-geoparquet/src/writer.rs
@@ -15,28 +15,35 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::{collections::HashMap, sync::Arc};
+use std::{any::Any, collections::HashMap, fmt, sync::Arc};
 
 use arrow_array::{
     builder::{Float32Builder, NullBufferBuilder},
-    ArrayRef, StructArray,
+    ArrayRef, RecordBatch, StructArray,
 };
-use arrow_schema::{DataType, Field, Fields};
+use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef};
+use async_trait::async_trait;
 use datafusion::{
     config::TableParquetOptions,
     datasource::{
-        file_format::parquet::ParquetSink, physical_plan::FileSinkConfig, sink::DataSinkExec,
+        file_format::parquet::ParquetSink,
+        physical_plan::FileSinkConfig,
+        sink::{DataSink, DataSinkExec},
     },
 };
 use datafusion_common::{
     config::ConfigOptions, exec_datafusion_err, exec_err, not_impl_err, DataFusionError, Result,
 };
+use datafusion_execution::{SendableRecordBatchStream, TaskContext};
 use datafusion_expr::{dml::InsertOp, ColumnarValue, ScalarUDF, Volatility};
 use datafusion_physical_expr::{
     expressions::Column, LexRequirement, PhysicalExpr, ScalarFunctionExpr,
 };
-use datafusion_physical_plan::{projection::ProjectionExec, ExecutionPlan};
+use datafusion_physical_plan::{
+    stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan,
+};
 use float_next_after::NextAfter;
+use futures::StreamExt;
 use geo_traits::GeometryTrait;
 use sedona_common::sedona_internal_err;
 use sedona_expr::scalar_udf::{SedonaScalarKernel, SedonaScalarUDF};
@@ -58,7 +65,7 @@
 };
 
 pub fn create_geoparquet_writer_physical_plan(
-    mut input: Arc<dyn ExecutionPlan>,
+    input: Arc<dyn ExecutionPlan>, // bbox columns are added inside the sink, not by wrapping input
     mut conf: FileSinkConfig,
     order_requirements: Option<LexRequirement>,
     options: &TableGeoParquetOptions,
@@ -76,6 +83,8 @@
     // We have geometry and/or geography! Collect the GeoParquetMetadata we'll need to write
     let mut metadata = GeoParquetMetadata::default();
     let mut bbox_columns = HashMap::new();
+    let mut bbox_projection = None; // set for GeoParquet 1.1 when bbox columns are needed
+    let mut parquet_output_schema = conf.output_schema().clone(); // schema written to Parquet
 
     // Check the version
     match options.geoparquet_version {
@@ -84,9 +93,10 @@
         }
         GeoParquetVersion::V1_1 => {
             metadata.version = "1.1.0".to_string();
-            (input, bbox_columns) = project_bboxes(input, options.overwrite_bbox_columns)?;
-            conf.output_schema = input.schema();
-            output_geometry_column_indices = input.schema().geometry_column_indices()?;
+            (bbox_projection, bbox_columns) =
+                project_bboxes(&input, options.overwrite_bbox_columns)?; // used by GeoParquetSink
+            parquet_output_schema = compute_final_schema(&bbox_projection, &input.schema())?;
+            output_geometry_column_indices = conf.output_schema.geometry_column_indices()?;
         }
         _ => {
             return not_impl_err!(
@@ -168,10 +178,84 @@
     );
 
     // Create the sink
-    let sink = Arc::new(ParquetSink::new(conf, parquet_options));
+    // The inner ParquetSink is configured with the schema it actually writes
+    // (including any bbox columns); GeoParquetSink reports the input's schema
+    let sink_input_schema = conf.output_schema;
+    conf.output_schema = parquet_output_schema.clone();
+    let sink = Arc::new(GeoParquetSink {
+        inner: ParquetSink::new(conf, parquet_options),
+        projection: bbox_projection,
+        sink_input_schema,
+        parquet_output_schema,
+    });
     Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _)
 }
 
+/// Implementation of [DataSink] that computes GeoParquet 1.1 bbox columns
+/// if needed. This is used instead of a ProjectionExec because DataFusion's
+/// optimizer rules seem to rearrange the projection in ways that cause
+/// the plan to fail <https://github.com/apache/sedona-db/issues/379>.
+#[derive(Debug)]
+struct GeoParquetSink {
+    /// Sink that writes the Parquet files; its configured output schema is
+    /// `parquet_output_schema`.
+    inner: ParquetSink,
+    /// Expressions (with output column names) evaluated against every
+    /// incoming batch, or `None` to forward batches to `inner` unchanged.
+    projection: Option<Vec<(Arc<dyn PhysicalExpr>, String)>>,
+    /// Schema reported by [DataSink::schema]: the schema of the input plan
+    /// before any bbox columns are added.
+    sink_input_schema: SchemaRef,
+    /// Schema of the batches handed to `inner` after applying `projection`.
+    parquet_output_schema: SchemaRef,
+}
+
+impl DisplayAs for GeoParquetSink {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // Display exactly like the wrapped ParquetSink
+        self.inner.fmt_as(t, f)
+    }
+}
+
+#[async_trait]
+impl DataSink for GeoParquetSink {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> &SchemaRef {
+        &self.sink_input_schema
+    }
+
+    async fn write_all(
+        &self,
+        data: SendableRecordBatchStream,
+        context: &Arc<TaskContext>,
+    ) -> Result<u64> {
+        // Without bbox columns to compute, forward the stream untouched
+        let Some(projection) = self.projection.clone() else {
+            return self.inner.write_all(data, context).await;
+        };
+
+        // Evaluate the projection on each batch as it streams through
+        let schema = Arc::clone(&self.parquet_output_schema);
+        let projected = RecordBatchStreamAdapter::new(
+            Arc::clone(&schema),
+            data.map(move |batch_result| {
+                batch_result.and_then(|batch| {
+                    let columns = projection
+                        .iter()
+                        .map(|(expr, _)| expr.evaluate(&batch)?.into_array(batch.num_rows()))
+                        .collect::<Result<Vec<_>>>()?;
+                    Ok(RecordBatch::try_new(Arc::clone(&schema), columns)?)
+                })
+            }),
+        );
+
+        self.inner.write_all(Box::pin(projected), context).await
+    }
+}
+
 /// Create a regular Parquet writer like DataFusion would otherwise do.
 fn create_inner_writer(
     input: Arc<dyn ExecutionPlan>,
@@ -184,6 +262,11 @@
     Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _)
 }
 
+type ProjectBboxesResult = (
+    Option<Vec<(Arc<dyn PhysicalExpr>, String)>>, // projection exprs; None if no bbox columns
+    HashMap<String, String>, // bbox_column_names_by_field
+);
+
 /// Create a projection that inserts a bbox column for every geometry column
 ///
 /// This implements creating the GeoParquet 1.1 bounding box columns,
@@ -206,9 +289,9 @@
 /// "some_col_bbox", it is unlikely that replacing it would have unintended
 /// consequences.
 fn project_bboxes(
-    input: Arc<dyn ExecutionPlan>,
+    input: &Arc<dyn ExecutionPlan>, // borrowed: the plan is no longer wrapped
     overwrite_bbox_columns: bool,
-) -> Result<(Arc<dyn ExecutionPlan>, HashMap<String, String>)> {
+) -> Result<ProjectBboxesResult> {
     let input_schema = input.schema();
     let matcher = ArgMatcher::is_geometry();
     let bbox_udf: Arc<ScalarUDF> = Arc::new(geoparquet_bbox_udf().into());
@@ -245,7 +328,7 @@
-    // If we don't need to create any bbox columns, don't add an additional
-    // projection at the end of the input plan
+    // If we don't need to create any bbox columns, return no projection so
+    // that the sink passes incoming batches through unchanged
     if bbox_exprs.is_empty() {
-        return Ok((input, HashMap::new()));
+        return Ok((None, HashMap::new()));
     }
 
     // Create the projection expressions
@@ -275,13 +358,43 @@
         exprs.push((column, f.name().clone()));
     }
 
-    // Create the projection
-    let exec = ProjectionExec::try_new(exprs, input)?;
-
     // Flip the bbox_column_names into the form our caller needs it
     let bbox_column_names_by_field = bbox_column_names.drain().map(|(k, v)| (v, k)).collect();
 
-    Ok((Arc::new(exec), bbox_column_names_by_field))
+    Ok((Some(exprs), bbox_column_names_by_field))
+}
+
+/// Compute the schema of the batches produced by applying `bbox_projection`
+/// to batches whose schema is `initial_schema`.
+///
+/// Each projected field is named after its projection entry and takes its
+/// data type, nullability, and field metadata from the expression's return
+/// field. The schema-level metadata of `initial_schema` is kept whether or not
+/// there is a projection; without one, `initial_schema` is returned as-is.
+fn compute_final_schema(
+    bbox_projection: &Option<Vec<(Arc<dyn PhysicalExpr>, String)>>,
+    initial_schema: &SchemaRef,
+) -> Result<SchemaRef> {
+    let Some(bbox_projection) = bbox_projection else {
+        return Ok(initial_schema.clone());
+    };
+
+    let new_fields = bbox_projection
+        .iter()
+        .map(|(expr, name)| -> Result<Field> {
+            let return_field_ref = expr.return_field(initial_schema)?;
+            Ok(Field::new(
+                name,
+                return_field_ref.data_type().clone(),
+                return_field_ref.is_nullable(),
+            )
+            .with_metadata(return_field_ref.metadata().clone()))
+        })
+        .collect::<Result<Vec<_>>>()?;
+
+    // Keep the input's schema-level metadata (as the ProjectionExec this
+    // replaces did) rather than silently dropping it
+    Ok(Arc::new(Schema::new_with_metadata(new_fields, initial_schema.metadata().clone())))
 }
 
 fn geoparquet_bbox_udf() -> SedonaScalarUDF {
@@ -419,7 +523,7 @@
     };
     use datafusion_common::cast::{as_float32_array, as_struct_array};
     use datafusion_common::ScalarValue;
-    use datafusion_expr::{Expr, LogicalPlanBuilder};
+    use datafusion_expr::{Cast, Expr, LogicalPlanBuilder};
     use sedona_schema::datatypes::WKB_GEOMETRY;
     use sedona_testing::create::create_array;
     use sedona_testing::data::test_geoparquet;
@@ -745,6 +849,60 @@
             .unwrap();
     }
 
+    #[tokio::test]
+    async fn geoparquet_1_1_with_sort_by_expr() {
+        let example = test_geoparquet("ns-water", "water-point");
+
+        // Requires submodules/download-assets.py which not all contributors need
+        let example = match example {
+            Ok(path) => path,
+            Err(err) => {
+                println!("ns-water/water-point is not available: {err}");
+                return;
+            }
+        };
+
+        let ctx = setup_context();
+        let fns = sedona_functions::register::default_function_set();
+
+        let geometry_udf: ScalarUDF = fns.scalar_udf("sd_format").unwrap().clone().into();
+        let bbox_udf: ScalarUDF = geoparquet_bbox_udf().into();
+
+        let df = ctx // sort by sd_format(geometry), then select a cast of it
+            .table(&example)
+            .await
+            .unwrap()
+            .sort_by(vec![geometry_udf.call(vec![col("geometry")])]) // sort on an expression
+            .unwrap()
+            .select(vec![
+                Expr::Cast(Cast::new(
+                    geometry_udf.call(vec![col("geometry")]).alias("txt").into(),
+                    DataType::Utf8View,
+                )),
+                col("geometry"),
+            ])
+            .unwrap();
+
+        let mut options = TableGeoParquetOptions::new();
+        options.geoparquet_version = GeoParquetVersion::V1_1;
+
+        let df_batches_with_bbox = df // expected output: bbox column before geometry
+            .clone()
+            .select(vec![
+                col("txt"),
+                bbox_udf.call(vec![col("geometry")]).alias("bbox"),
+                col("geometry"),
+            ])
+            .unwrap()
+            .collect()
+            .await
+            .unwrap();
+
+        test_write_dataframe(ctx, df, df_batches_with_bbox, options, vec![])
+            .await
+            .unwrap();
+    }
+
     #[test]
     fn float_bbox() {
         let tester = ScalarUdfTester::new(geoparquet_bbox_udf().into(), vec![WKB_GEOMETRY]);