blob: 4cf0f8016ae28a69d26327e27b7eb69195a7bced [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use arrow_array::{Array, ArrayRef, StructArray};
use datafusion_common::error::Result;
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_expr::ColumnarValue;
use sedona_common::sedona_internal_err;
use sedona_raster::array::{RasterRefImpl, RasterStructArray};
use sedona_schema::datatypes::SedonaType;
use sedona_schema::datatypes::RASTER;
/// Helper for writing raster kernel implementations
///
/// The [RasterExecutor] provides a simplified interface for executing functions
/// on raster arrays, handling the common pattern of downcasting to StructArray,
/// creating raster iterators, and handling null values.
///
/// The executor only borrows the argument types and values it is given; the
/// iteration count is computed once at construction time and cached.
pub struct RasterExecutor<'a, 'b> {
/// Types of the arguments. `execute_raster_void` requires the first entry
/// to be the raster type.
pub arg_types: &'a [SedonaType],
/// Argument values (arrays or scalars). The first entry is the raster input
/// that `execute_raster_void` iterates over.
pub args: &'b [ColumnarValue],
/// Cached iteration count: the length of the first array argument, or 1
/// when every argument is a scalar.
num_iterations: usize,
}
impl<'a, 'b> RasterExecutor<'a, 'b> {
    /// Create a new [RasterExecutor]
    ///
    /// The iteration count is computed once here from `args`: the length of the
    /// first array argument, or 1 if every argument is a scalar.
    pub fn new(arg_types: &'a [SedonaType], args: &'b [ColumnarValue]) -> Self {
        Self {
            arg_types,
            args,
            num_iterations: Self::calc_num_iterations(args),
        }
    }

    /// Return the number of iterations that will be performed
    pub fn num_iterations(&self) -> usize {
        self.num_iterations
    }

    /// Execute a function by iterating over rasters in the first argument
    ///
    /// This handles the common pattern of:
    /// 1. Downcasting array to StructArray
    /// 2. Creating raster array
    /// 3. Iterating with null checks
    /// 4. Calling the provided function with each raster
    ///
    /// `func` receives the element index and `Some(raster)` for a non-null
    /// element or `None` for a null one. For an array input it is called once
    /// per index in `0..num_iterations()`; for a scalar input it is called
    /// exactly once with index 0.
    ///
    /// # Errors
    ///
    /// Returns an internal error if there is no first argument, if its declared
    /// type is not the raster type, or if its value is not a struct array /
    /// struct scalar / null scalar. Errors from reading a raster or from `func`
    /// are propagated and stop the iteration.
    pub fn execute_raster_void<F>(&self, mut func: F) -> Result<()>
    where
        F: FnMut(usize, Option<RasterRefImpl<'_>>) -> Result<()>,
    {
        // Checked access: an empty argument list becomes an internal error
        // instead of an index-out-of-bounds panic.
        match self.arg_types.first() {
            Some(first_type) if *first_type == RASTER => {}
            _ => return sedona_internal_err!("First argument must be a raster type"),
        }
        let first_arg = match self.args.first() {
            Some(arg) => arg,
            None => return sedona_internal_err!("Expected at least one argument"),
        };

        match first_arg {
            ColumnarValue::Array(array) => {
                // Downcast to StructArray (rasters are stored as structs)
                let raster_struct = array
                    .as_any()
                    .downcast_ref::<StructArray>()
                    .ok_or_else(|| {
                        DataFusionError::Internal(
                            "Expected StructArray for raster data".to_string(),
                        )
                    })?;
                let raster_array = RasterStructArray::new(raster_struct);

                // Iterate through each raster in the array, mapping nulls to None
                for i in 0..self.num_iterations {
                    let raster = if raster_array.is_null(i) {
                        None
                    } else {
                        Some(raster_array.get(i)?)
                    };
                    func(i, raster)?;
                }
                Ok(())
            }
            ColumnarValue::Scalar(ScalarValue::Struct(arc_struct)) => {
                let raster_array = RasterStructArray::new(arc_struct.as_ref());
                // A struct scalar can itself hold a null entry; apply the same
                // null check as the array branch so it is reported as None.
                if raster_array.is_null(0) {
                    func(0, None)
                } else {
                    func(0, Some(raster_array.get(0)?))
                }
            }
            ColumnarValue::Scalar(ScalarValue::Null) => func(0, None),
            ColumnarValue::Scalar(_) => Err(DataFusionError::Internal(
                "Expected Struct scalar for raster".to_string(),
            )),
        }
    }

    /// Finish an [ArrayRef] output as the appropriate [ColumnarValue]
    ///
    /// Converts the output into a [ColumnarValue::Scalar] if all arguments were scalars,
    /// or a [ColumnarValue::Array] otherwise.
    ///
    /// # Errors
    ///
    /// On the all-scalar path, returns an internal error if `out` is empty, and
    /// propagates any error from converting element 0 of `out` into a scalar.
    pub fn finish(&self, out: ArrayRef) -> Result<ColumnarValue> {
        // If any argument was an array, we return an array
        let any_array_arg = self
            .args
            .iter()
            .any(|arg| matches!(arg, ColumnarValue::Array(_)));
        if any_array_arg {
            return Ok(ColumnarValue::Array(out));
        }

        // All arguments are scalars, return a scalar built from element 0.
        // Guard first: there is no element 0 to read from an empty output.
        if out.is_empty() {
            return sedona_internal_err!("Expected at least one output value for scalar result");
        }
        Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(&out, 0)?))
    }

    /// Calculates the number of iterations that should happen based on the
    /// argument ColumnarValue types
    ///
    /// Returns the length of the first array argument, or 1 when all arguments
    /// are scalars (including when there are no arguments).
    fn calc_num_iterations(args: &[ColumnarValue]) -> usize {
        args.iter()
            .find_map(|arg| match arg {
                // If any argument is an array, iterate array.len() times
                ColumnarValue::Array(array) => Some(array.len()),
                ColumnarValue::Scalar(_) => None,
            })
            // All arguments are scalars, iterate once
            .unwrap_or(1)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::builder::UInt64Builder;
    use arrow_array::UInt64Array;
    use sedona_raster::traits::RasterRef;
    use sedona_schema::datatypes::RASTER;
    use sedona_testing::rasters::generate_test_rasters;
    use std::sync::Arc;

    /// Runs `execute_raster_void` on the executor, recording each raster's
    /// width (or a null for a missing raster), and returns the finished output.
    fn collect_widths(executor: &RasterExecutor<'_, '_>) -> ColumnarValue {
        let mut widths = UInt64Builder::with_capacity(executor.num_iterations());
        executor
            .execute_raster_void(|_, maybe_raster| {
                if let Some(raster) = maybe_raster {
                    widths.append_value(raster.metadata().width());
                } else {
                    widths.append_null();
                }
                Ok(())
            })
            .unwrap();
        executor.finish(Arc::new(widths.finish())).unwrap()
    }

    #[test]
    fn test_raster_executor_execute_raster_void() {
        // Three generated rasters; the one at index 1 is null
        let rasters = generate_test_rasters(3, Some(1)).unwrap();
        let args = [ColumnarValue::Array(Arc::new(rasters))];
        let arg_types = vec![RASTER];
        let executor = RasterExecutor::new(&arg_types, &args);
        assert_eq!(executor.num_iterations(), 3);

        // Array input must produce array output
        let result = collect_widths(&executor);
        let width_array = match &result {
            ColumnarValue::Scalar(_) => panic!("Expected array, got scalar"),
            ColumnarValue::Array(array) => array
                .as_any()
                .downcast_ref::<UInt64Array>()
                .expect("Expected UInt64Array"),
        };

        assert_eq!(width_array.len(), 3);
        assert_eq!(width_array.value(0), 1);
        assert!(width_array.is_null(1));
        assert_eq!(width_array.value(2), 3);
    }

    #[test]
    fn test_raster_executor_scalar_input() {
        // Wrap a single generated raster as a struct scalar
        let rasters = generate_test_rasters(1, None).unwrap();
        let raster_struct = rasters.as_any().downcast_ref::<StructArray>().unwrap();
        let scalar_raster = ScalarValue::Struct(Arc::new(raster_struct.clone()));
        let args = [ColumnarValue::Scalar(scalar_raster)];
        let arg_types = vec![RASTER];
        let executor = RasterExecutor::new(&arg_types, &args);
        assert_eq!(executor.num_iterations(), 1);

        // With scalar input, result should be a scalar
        let result = collect_widths(&executor);
        match &result {
            ColumnarValue::Scalar(ScalarValue::UInt64(Some(width))) => assert_eq!(*width, 1),
            ColumnarValue::Scalar(_) => panic!("Expected UInt64 scalar"),
            ColumnarValue::Array(_) => panic!("Expected scalar, got array"),
        }
    }

    #[test]
    fn test_raster_executor_null_scalar() {
        // Test with a null scalar
        let args = [ColumnarValue::Scalar(ScalarValue::Null)];
        let arg_types = vec![RASTER];
        let executor = RasterExecutor::new(&arg_types, &args);
        assert_eq!(executor.num_iterations(), 1);

        // With null scalar input, result should be null scalar
        let result = collect_widths(&executor);
        let width_scalar = match &result {
            ColumnarValue::Array(_) => panic!("Expected scalar, got array"),
            ColumnarValue::Scalar(scalar) => scalar,
        };
        assert_eq!(width_scalar, &ScalarValue::UInt64(None));
    }
}