// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use arrow::{
array::{as_primitive_array, Capacities, MutableArrayData},
buffer::{NullBuffer, OffsetBuffer},
datatypes::ArrowNativeType,
record_batch::RecordBatch,
};
use arrow_array::{make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait};
use arrow_schema::{DataType, Field, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{
cast::{as_large_list_array, as_list_array},
internal_err, DataFusionError, Result as DataFusionResult,
};
use datafusion_physical_expr::PhysicalExpr;
use std::hash::Hash;
use std::{
any::Any,
fmt::{Debug, Display, Formatter},
sync::Arc,
};
// 2147483632 == java.lang.Integer.MAX_VALUE - 15
// It is the value of ByteArrayUtils.MAX_ROUNDED_ARRAY_LENGTH
// https://github.com/apache/spark/blob/master/common/utils/src/main/java/org/apache/spark/unsafe/array/ByteArrayUtils.java
const MAX_ROUNDED_ARRAY_LENGTH: usize = 2147483632;
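
// A compile-time sanity check (illustrative addition, not part of the upstream
// file): it only restates the relationship documented in the comment above.
const _: () = assert!(MAX_ROUNDED_ARRAY_LENGTH == (i32::MAX as usize) - 15);

/// Spark-compatible `array_insert` expression: inserts the value produced by
/// `item_expr` into the array produced by `src_array_expr` at the 1-based
/// position produced by `pos_expr`. `legacy_negative_index` reproduces Spark's
/// legacy handling of negative positions.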
#[derive(Debug, Eq)]
pub struct ArrayInsert {
src_array_expr: Arc<dyn PhysicalExpr>,
pos_expr: Arc<dyn PhysicalExpr>,
item_expr: Arc<dyn PhysicalExpr>,
legacy_negative_index: bool,
}
impl Hash for ArrayInsert {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.src_array_expr.hash(state);
self.pos_expr.hash(state);
self.item_expr.hash(state);
self.legacy_negative_index.hash(state);
}
}
impl PartialEq for ArrayInsert {
fn eq(&self, other: &Self) -> bool {
self.src_array_expr.eq(&other.src_array_expr)
&& self.pos_expr.eq(&other.pos_expr)
&& self.item_expr.eq(&other.item_expr)
&& self.legacy_negative_index.eq(&other.legacy_negative_index)
}
}
impl ArrayInsert {
pub fn new(
src_array_expr: Arc<dyn PhysicalExpr>,
pos_expr: Arc<dyn PhysicalExpr>,
item_expr: Arc<dyn PhysicalExpr>,
legacy_negative_index: bool,
) -> Self {
Self {
src_array_expr,
pos_expr,
item_expr,
legacy_negative_index,
}
}
pub fn array_type(&self, data_type: &DataType) -> DataFusionResult<DataType> {
match data_type {
DataType::List(field) => Ok(DataType::List(Arc::clone(field))),
DataType::LargeList(field) => Ok(DataType::LargeList(Arc::clone(field))),
data_type => Err(DataFusionError::Internal(format!(
"Unexpected src array type in ArrayInsert: {:?}",
data_type
))),
}
}
}
impl PhysicalExpr for ArrayInsert {
fn as_any(&self) -> &dyn Any {
self
}
fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
self.array_type(&self.src_array_expr.data_type(input_schema)?)
}
fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
self.src_array_expr.nullable(input_schema)
}
fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
let pos_value = self
.pos_expr
.evaluate(batch)?
.into_array(batch.num_rows())?;
// Spark supports only IntegerType (Int32):
// https://github.com/apache/spark/blob/branch-3.5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L4737
if !matches!(pos_value.data_type(), DataType::Int32) {
return Err(DataFusionError::Internal(format!(
"Unexpected index data type in ArrayInsert: {:?}, expected type is Int32",
pos_value.data_type()
)));
}
        // Check that the src array is actually an array and get its element type
let src_value = self
.src_array_expr
.evaluate(batch)?
.into_array(batch.num_rows())?;
        let src_element_type = match self.array_type(src_value.data_type())? {
            DataType::List(field) | DataType::LargeList(field) => field.data_type().clone(),
            _ => unreachable!(),
        };
        // Check that the inserted value has the same type as the array elements
let item_value = self
.item_expr
.evaluate(batch)?
.into_array(batch.num_rows())?;
        if item_value.data_type() != &src_element_type {
return Err(DataFusionError::Internal(format!(
"Type mismatch in ArrayInsert: array type is {:?} but item type is {:?}",
src_element_type,
item_value.data_type()
)));
}
match src_value.data_type() {
DataType::List(_) => {
let list_array = as_list_array(&src_value)?;
array_insert(
list_array,
&item_value,
&pos_value,
self.legacy_negative_index,
)
}
DataType::LargeList(_) => {
let list_array = as_large_list_array(&src_value)?;
array_insert(
list_array,
&item_value,
&pos_value,
self.legacy_negative_index,
)
}
_ => unreachable!(), // This case is checked already
}
}
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
vec![&self.src_array_expr, &self.pos_expr, &self.item_expr]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
match children.len() {
3 => Ok(Arc::new(ArrayInsert::new(
Arc::clone(&children[0]),
Arc::clone(&children[1]),
Arc::clone(&children[2]),
self.legacy_negative_index,
))),
            _ => internal_err!("ArrayInsert should have exactly three children"),
}
}
}
fn array_insert<O: OffsetSizeTrait>(
list_array: &GenericListArray<O>,
items_array: &ArrayRef,
pos_array: &ArrayRef,
legacy_mode: bool,
) -> DataFusionResult<ColumnarValue> {
    // The code is based on the implementation of array_append from Apache DataFusion:
    // https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/concat.rs#L513
    //
    // It is also based on the implementation of array_insert from Apache Spark:
    // https://github.com/apache/spark/blob/branch-3.5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L4713
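    //
    // Illustrative examples, mirroring the unit tests below (positions are 1-based):
    //   array_insert([1, 2, 3], pos = 2,  item = 10) -> [1, 10, 2, 3]
    //   array_insert([1, 2, 3], pos = -1, item = 10) -> [1, 2, 3, 10] (legacy_mode = false)
    //   array_insert([1, 2, 3], pos = -1, item = 10) -> [1, 2, 10, 3] (legacy_mode = true)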
let values = list_array.values();
let offsets = list_array.offsets();
let values_data = values.to_data();
let item_data = items_array.to_data();
let new_capacity = Capacities::Array(values_data.len() + item_data.len());
let mut mutable_values =
MutableArrayData::with_capacities(vec![&values_data, &item_data], true, new_capacity);
let mut new_offsets = vec![O::usize_as(0)];
let mut new_nulls = Vec::<bool>::with_capacity(list_array.len());
let pos_data: &Int32Array = as_primitive_array(&pos_array); // Spark supports only i32 for positions
for (row_index, offset_window) in offsets.windows(2).enumerate() {
let pos = pos_data.values()[row_index];
let start = offset_window[0].as_usize();
let end = offset_window[1].as_usize();
let is_item_null = items_array.is_null(row_index);
if list_array.is_null(row_index) {
            // In Spark, if the source array value is NULL then the result is also NULL
mutable_values.extend_nulls(1);
new_offsets.push(new_offsets[row_index] + O::one());
new_nulls.push(false);
continue;
}
if pos == 0 {
return Err(DataFusionError::Internal(
"Position for array_insert should be greter or less than zero".to_string(),
));
}
if (pos > 0) || ((-pos).as_usize() < (end - start + 1)) {
let corrected_pos = if pos > 0 {
(pos - 1).as_usize()
} else {
end - start - (-pos).as_usize() + if legacy_mode { 0 } else { 1 }
};
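            // Worked example (illustrative): for pos = -2 on a 3-element row with
            // legacy_mode = false, corrected_pos = 3 - 2 + 1 = 2, so the item lands
            // just before the last element.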
let new_array_len = std::cmp::max(end - start + 1, corrected_pos);
if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
return Err(DataFusionError::Internal(format!(
"Max array length in Spark is {:?}, but got {:?}",
MAX_ROUNDED_ARRAY_LENGTH, new_array_len
)));
}
if (start + corrected_pos) <= end {
mutable_values.extend(0, start, start + corrected_pos);
mutable_values.extend(1, row_index, row_index + 1);
mutable_values.extend(0, start + corrected_pos, end);
new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
} else {
mutable_values.extend(0, start, end);
mutable_values.extend_nulls(new_array_len - (end - start));
mutable_values.extend(1, row_index, row_index + 1);
                // In this case Spark actually makes the array longer than expected;
                // for example, if pos is 5 and len is 3, the resulting len will be 5
new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len) + O::one());
}
} else {
            // This comment is taken from the Apache Spark source code as is:
// special case- if the new position is negative but larger than the current array size
// place the new item at start of array, place the current array contents at the end
// and fill the newly created array elements inbetween with a null
let base_offset = if legacy_mode { 1 } else { 0 };
let new_array_len = (-pos + base_offset).as_usize();
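            // Worked example (illustrative): for pos = -5 on a 3-element row with
            // legacy_mode = false, new_array_len is 5 and the result is
            // [item, null, e1, e2, e3].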
if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
return Err(DataFusionError::Internal(format!(
"Max array length in Spark is {:?}, but got {:?}",
MAX_ROUNDED_ARRAY_LENGTH, new_array_len
)));
}
mutable_values.extend(1, row_index, row_index + 1);
mutable_values.extend_nulls(new_array_len - (end - start + 1));
mutable_values.extend(0, start, end);
new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
}
if is_item_null {
if (start == end) || (values.is_null(row_index)) {
new_nulls.push(false)
} else {
new_nulls.push(true)
}
} else {
new_nulls.push(true)
}
}
let data = make_array(mutable_values.freeze());
let data_type = match list_array.data_type() {
DataType::List(field) => field.data_type(),
DataType::LargeList(field) => field.data_type(),
_ => unreachable!(),
};
let new_array = GenericListArray::<O>::try_new(
Arc::new(Field::new("item", data_type.clone(), true)),
OffsetBuffer::new(new_offsets.into()),
data,
Some(NullBuffer::new(new_nulls.into())),
)?;
Ok(ColumnarValue::Array(Arc::new(new_array)))
}
impl Display for ArrayInsert {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
f,
"ArrayInsert [array: {:?}, pos: {:?}, item: {:?}]",
self.src_array_expr, self.pos_expr, self.item_expr
)
}
}
#[cfg(test)]
mod test {
use super::*;
use arrow::datatypes::Int32Type;
use arrow_array::{Array, ArrayRef, Int32Array, ListArray};
use datafusion_common::Result;
use datafusion_expr::ColumnarValue;
use std::sync::Arc;
#[test]
fn test_array_insert() -> Result<()> {
// Test inserting an item into a list array
        // Inputs and expected values are taken from Spark's results
let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(2), Some(3)]),
Some(vec![Some(4), Some(5)]),
Some(vec![None]),
Some(vec![Some(1), Some(2), Some(3)]),
Some(vec![Some(1), Some(2), Some(3)]),
None,
]);
let positions = Int32Array::from(vec![2, 1, 1, 5, 6, 1]);
let items = Int32Array::from(vec![
Some(10),
Some(20),
Some(30),
Some(100),
Some(100),
Some(40),
]);
let ColumnarValue::Array(result) = array_insert(
&list,
&(Arc::new(items) as ArrayRef),
&(Arc::new(positions) as ArrayRef),
false,
)?
else {
unreachable!()
};
let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(10), Some(2), Some(3)]),
Some(vec![Some(20), Some(4), Some(5)]),
Some(vec![Some(30), None]),
Some(vec![Some(1), Some(2), Some(3), None, Some(100)]),
Some(vec![Some(1), Some(2), Some(3), None, None, Some(100)]),
None,
]);
assert_eq!(&result.to_data(), &expected.to_data());
Ok(())
}
#[test]
fn test_array_insert_negative_index() -> Result<()> {
// Test insert with negative index
        // Inputs and expected values are taken from Spark's results
let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(2), Some(3)]),
Some(vec![Some(4), Some(5)]),
Some(vec![Some(1)]),
None,
]);
let positions = Int32Array::from(vec![-2, -1, -3, -1]);
let items = Int32Array::from(vec![Some(10), Some(20), Some(100), Some(30)]);
let ColumnarValue::Array(result) = array_insert(
&list,
&(Arc::new(items) as ArrayRef),
&(Arc::new(positions) as ArrayRef),
false,
)?
else {
unreachable!()
};
let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(2), Some(10), Some(3)]),
Some(vec![Some(4), Some(5), Some(20)]),
Some(vec![Some(100), None, Some(1)]),
None,
]);
assert_eq!(&result.to_data(), &expected.to_data());
Ok(())
}
#[test]
fn test_array_insert_legacy_mode() -> Result<()> {
        // Test the so-called "legacy" mode that exists in Spark
let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(2), Some(3)]),
Some(vec![Some(4), Some(5)]),
None,
]);
let positions = Int32Array::from(vec![-1, -1, -1]);
let items = Int32Array::from(vec![Some(10), Some(20), Some(30)]);
let ColumnarValue::Array(result) = array_insert(
&list,
&(Arc::new(items) as ArrayRef),
&(Arc::new(positions) as ArrayRef),
true,
)?
else {
unreachable!()
};
let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(2), Some(10), Some(3)]),
Some(vec![Some(4), Some(20), Some(5)]),
None,
]);
assert_eq!(&result.to_data(), &expected.to_data());
Ok(())
}
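
    #[test]
    fn test_array_insert_zero_position_errors() {
        // A minimal additional check, not part of the original suite: position 0
        // is invalid in Spark, so array_insert is expected to return an error.
        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(vec![Some(1)])]);
        let positions = Int32Array::from(vec![0]);
        let items = Int32Array::from(vec![Some(10)]);
        let result = array_insert(
            &list,
            &(Arc::new(items) as ArrayRef),
            &(Arc::new(positions) as ArrayRef),
            false,
        );
        assert!(result.is_err());
    }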
}