blob: 9f46917a87eb3bf6b8aa403dacc86e0dc8d4b105 [file]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! array function utils
use std::sync::Arc;
use arrow::datatypes::{DataType, Field, Fields};
use arrow::array::{
Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar,
};
use arrow::buffer::OffsetBuffer;
use datafusion_common::cast::{
as_fixed_size_list_array, as_large_list_array, as_list_array,
};
use datafusion_common::{Result, ScalarValue, exec_err, internal_err, plan_err};
use datafusion_expr::ColumnarValue;
use itertools::Itertools as _;
pub(crate) fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> {
let data_type = args[0].data_type();
if !args.iter().all(|arg| {
arg.data_type().equals_datatype(data_type)
|| arg.data_type().equals_datatype(&DataType::Null)
}) {
let types = args.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
return plan_err!(
"{name} received incompatible types: {}",
types.iter().join(", ")
);
}
Ok(())
}
/// array function wrapper that differentiates between scalar (length 1) and array.
pub(crate) fn make_scalar_function<F>(
inner: F,
) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue>
where
F: Fn(&[ArrayRef]) -> Result<ArrayRef>,
{
move |args: &[ColumnarValue]| {
// first, identify if any of the arguments is an Array. If yes, store its `len`,
// as any scalar will need to be converted to an array of len `len`.
let len = args
.iter()
.fold(Option::<usize>::None, |acc, arg| match arg {
ColumnarValue::Scalar(_) => acc,
ColumnarValue::Array(a) => Some(a.len()),
});
let is_scalar = len.is_none();
let args = ColumnarValue::values_to_arrays(args)?;
let result = (inner)(&args);
if is_scalar {
// If all inputs are scalar, keeps output as scalar
let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
result.map(ColumnarValue::Scalar)
} else {
result.map(ColumnarValue::Array)
}
}
}
pub(crate) fn align_array_dimensions<O: OffsetSizeTrait>(
args: Vec<ArrayRef>,
) -> Result<Vec<ArrayRef>> {
let args_ndim = args
.iter()
.map(|arg| datafusion_common::utils::list_ndims(arg.data_type()))
.collect::<Vec<_>>();
let max_ndim = args_ndim.iter().max().unwrap_or(&0);
// Align the dimensions of the arrays
let aligned_args: Result<Vec<ArrayRef>> = args
.into_iter()
.zip(args_ndim.iter())
.map(|(array, ndim)| {
if ndim < max_ndim {
let mut aligned_array = Arc::clone(&array);
for _ in 0..(max_ndim - ndim) {
let data_type = aligned_array.data_type().to_owned();
let array_lengths = vec![1; aligned_array.len()];
let offsets = OffsetBuffer::<O>::from_lengths(array_lengths);
aligned_array = Arc::new(GenericListArray::<O>::try_new(
Arc::new(Field::new_list_field(data_type, true)),
offsets,
aligned_array,
None,
)?)
}
Ok(aligned_array)
} else {
Ok(Arc::clone(&array))
}
})
.collect();
aligned_args
}
/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array.
///
/// # Arguments
///
/// * `list_array_row` - A reference to a trait object implementing the Arrow `Array` trait. It represents the list array for which the equality or inequality will be compared.
///
/// * `element_array` - A reference to a trait object implementing the Arrow `Array` trait. It represents the array with which each element in the `list_array_row` will be compared.
///
/// * `row_index` - The index of the row in the `element_array` and `list_array` to use for the comparison.
///
/// * `eq` - A boolean flag. If `true`, the function computes equality; if `false`, it computes inequality.
///
/// # Returns
///
/// Returns a `Result<BooleanArray>` representing the comparison results. The result may contain an error if there are issues with the computation.
///
/// # Example
///
/// ```text
/// compare_element_to_list(
/// [1, 2, 3], [1, 2, 3], 0, true => [true, false, false]
/// [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, true, false]
///
/// [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 0, true => [true, false, false]
/// [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 1, false => [true, false, false]
/// )
/// ```
pub(crate) fn compare_element_to_list(
list_array_row: &dyn Array,
element_array: &dyn Array,
row_index: usize,
eq: bool,
) -> Result<BooleanArray> {
if list_array_row.data_type() != element_array.data_type() {
return exec_err!(
"compare_element_to_list received incompatible types: '{:?}' and '{:?}'.",
list_array_row.data_type(),
element_array.data_type()
);
}
let element_array_row = element_array.slice(row_index, 1);
// Compute all positions in list_row_array (that is itself an
// array) that are equal to `from_array_row`
let res = match element_array_row.data_type() {
// arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop
DataType::List(_) => {
// compare each element of the from array
let element_array_row_inner = as_list_array(&element_array_row)?.value(0);
let list_array_row_inner = as_list_array(list_array_row)?;
list_array_row_inner
.iter()
// compare element by element the current row of list_array
.map(|row| {
row.map(|row| {
if eq {
row.eq(&element_array_row_inner)
} else {
row.ne(&element_array_row_inner)
}
})
})
.collect::<BooleanArray>()
}
DataType::LargeList(_) => {
// compare each element of the from array
let element_array_row_inner =
as_large_list_array(&element_array_row)?.value(0);
let list_array_row_inner = as_large_list_array(list_array_row)?;
list_array_row_inner
.iter()
// compare element by element the current row of list_array
.map(|row| {
row.map(|row| {
if eq {
row.eq(&element_array_row_inner)
} else {
row.ne(&element_array_row_inner)
}
})
})
.collect::<BooleanArray>()
}
_ => {
let element_arr = Scalar::new(element_array_row);
// use not_distinct so we can compare NULL
if eq {
arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)?
} else {
arrow_ord::cmp::distinct(&list_array_row, &element_arr)?
}
}
};
Ok(res)
}
/// Returns the length of each array dimension
pub(crate) fn compute_array_dims(
arr: Option<ArrayRef>,
) -> Result<Option<Vec<Option<u64>>>> {
let mut value = match arr {
Some(arr) => arr,
None => return Ok(None),
};
if value.is_empty() {
return Ok(None);
}
let mut res = vec![Some(value.len() as u64)];
loop {
match value.data_type() {
DataType::List(_) => {
value = as_list_array(&value)?.value(0);
res.push(Some(value.len() as u64));
}
DataType::LargeList(_) => {
value = as_large_list_array(&value)?.value(0);
res.push(Some(value.len() as u64));
}
DataType::FixedSizeList(..) => {
value = as_fixed_size_list_array(&value)?.value(0);
res.push(Some(value.len() as u64));
}
_ => return Ok(Some(res)),
}
}
}
pub(crate) fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> {
match data_type {
DataType::Map(field, _) => {
let field_data_type = field.data_type();
match field_data_type {
DataType::Struct(fields) => Ok(fields),
_ => {
internal_err!("Expected a Struct type, got {}", field_data_type)
}
}
}
_ => internal_err!("Expected a Map type, got {data_type}"),
}
}
#[cfg(test)]
mod tests {
use super::*;
use arrow::array::ListArray;
use arrow::datatypes::Int64Type;
use datafusion_common::utils::SingleRowListArrayBuilder;
/// Only test internal functions, array-related sql functions will be tested in sqllogictest `array.slt`
#[test]
fn test_align_array_dimensions() {
let array1d_1: ArrayRef =
Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
Some(vec![Some(1), Some(2), Some(3)]),
Some(vec![Some(4), Some(5)]),
]));
let array1d_2: ArrayRef =
Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
Some(vec![Some(6), Some(7), Some(8)]),
]));
let array2d_1: ArrayRef = Arc::new(
SingleRowListArrayBuilder::new(Arc::clone(&array1d_1)).build_list_array(),
);
let array2d_2 = Arc::new(
SingleRowListArrayBuilder::new(Arc::clone(&array1d_2)).build_list_array(),
);
let res = align_array_dimensions::<i32>(vec![
array1d_1.to_owned(),
array2d_2.to_owned(),
])
.unwrap();
let expected = as_list_array(&array2d_1).unwrap();
let expected_dim = datafusion_common::utils::list_ndims(array2d_1.data_type());
assert_ne!(as_list_array(&res[0]).unwrap(), expected);
assert_eq!(
datafusion_common::utils::list_ndims(res[0].data_type()),
expected_dim
);
let array3d_1: ArrayRef =
Arc::new(SingleRowListArrayBuilder::new(array2d_1).build_list_array());
let array3d_2: ArrayRef =
Arc::new(SingleRowListArrayBuilder::new(array2d_2).build_list_array());
let res = align_array_dimensions::<i32>(vec![array1d_1, array3d_2]).unwrap();
let expected = as_list_array(&array3d_1).unwrap();
let expected_dim = datafusion_common::utils::list_ndims(array3d_1.data_type());
assert_ne!(as_list_array(&res[0]).unwrap(), expected);
assert_eq!(
datafusion_common::utils::list_ndims(res[0].data_type()),
expected_dim
);
}
}