blob: 730b4b39ac1caf9d8b47c3df98b8b4f084f916c6 [file]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! [`ScalarUDFImpl`] definitions for array_remove, array_remove_n, array_remove_all functions.
use crate::utils;
use crate::utils::make_scalar_function;
use arrow::array::{
Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, NullBufferBuilder,
OffsetSizeTrait, cast::AsArray, make_array,
};
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::{DataType, FieldRef};
use datafusion_common::cast::as_int64_array;
use datafusion_common::utils::ListCoercion;
use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args};
use datafusion_expr::{
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use datafusion_macros::user_doc;
use std::sync::Arc;
make_udf_expr_and_func!(
ArrayRemove,
array_remove,
array element,
"removes the first element from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
array_remove_udf
);
#[user_doc(
doc_section(label = "Array Functions"),
description = "Removes the first element from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
syntax_example = "array_remove(array, element)",
sql_example = r#"```sql
> select array_remove([1, 2, 2, 3, 2, 1, 4], 2);
+----------------------------------------------+
| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) |
+----------------------------------------------+
| [1, 2, 3, 2, 1, 4] |
+----------------------------------------------+
> select array_remove([1, 2, NULL, 2, 4], 2);
+---------------------------------------------------+
| array_remove(List([1,2,NULL,2,4]),Int64(2)) |
+---------------------------------------------------+
| [1, NULL, 2, 4] |
+---------------------------------------------------+
```"#,
argument(
name = "array",
description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
),
argument(
name = "element",
description = "Element to be removed from the array."
)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayRemove {
signature: Signature,
aliases: Vec<String>,
}
impl Default for ArrayRemove {
fn default() -> Self {
Self::new()
}
}
impl ArrayRemove {
pub fn new() -> Self {
Self {
signature: Signature::array_and_element(Volatility::Immutable),
aliases: vec!["list_remove".to_string()],
}
}
}
impl ScalarUDFImpl for ArrayRemove {
fn name(&self) -> &str {
"array_remove"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
internal_err!("return_field_from_args should be used instead")
}
fn return_field_from_args(
&self,
args: datafusion_expr::ReturnFieldArgs,
) -> Result<FieldRef> {
Ok(Arc::clone(&args.arg_fields[0]))
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(array_remove_inner)(&args.args)
}
fn aliases(&self) -> &[String] {
&self.aliases
}
fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}
make_udf_expr_and_func!(
ArrayRemoveN,
array_remove_n,
array element max,
"removes the first `max` elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
array_remove_n_udf
);
#[user_doc(
doc_section(label = "Array Functions"),
description = "Removes the first `max` elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
syntax_example = "array_remove_n(array, element, max)",
sql_example = r#"```sql
> select array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2);
+---------------------------------------------------------+
| array_remove_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(2)) |
+---------------------------------------------------------+
| [1, 3, 2, 1, 4] |
+---------------------------------------------------------+
> select array_remove_n([1, 2, NULL, 2, 4], 2, 2);
+----------------------------------------------------------+
| array_remove_n(List([1,2,NULL,2,4]),Int64(2),Int64(2)) |
+----------------------------------------------------------+
| [1, NULL, 4] |
+----------------------------------------------------------+
```"#,
argument(
name = "array",
description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
),
argument(
name = "element",
description = "Element to be removed from the array."
),
argument(name = "max", description = "Number of first occurrences to remove.")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub(super) struct ArrayRemoveN {
signature: Signature,
aliases: Vec<String>,
}
impl ArrayRemoveN {
pub fn new() -> Self {
Self {
signature: Signature::new(
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
arguments: vec![
ArrayFunctionArgument::Array,
ArrayFunctionArgument::Element,
ArrayFunctionArgument::Index,
],
array_coercion: Some(ListCoercion::FixedSizedListToList),
}),
Volatility::Immutable,
),
aliases: vec!["list_remove_n".to_string()],
}
}
}
impl ScalarUDFImpl for ArrayRemoveN {
fn name(&self) -> &str {
"array_remove_n"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
internal_err!("return_field_from_args should be used instead")
}
fn return_field_from_args(
&self,
args: datafusion_expr::ReturnFieldArgs,
) -> Result<FieldRef> {
Ok(Arc::clone(&args.arg_fields[0]))
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(array_remove_n_inner)(&args.args)
}
fn aliases(&self) -> &[String] {
&self.aliases
}
fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}
make_udf_expr_and_func!(
ArrayRemoveAll,
array_remove_all,
array element,
"removes all elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
array_remove_all_udf
);
#[user_doc(
doc_section(label = "Array Functions"),
description = "Removes all elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
syntax_example = "array_remove_all(array, element)",
sql_example = r#"```sql
> select array_remove_all([1, 2, 2, 3, 2, 1, 4], 2);
+--------------------------------------------------+
| array_remove_all(List([1,2,2,3,2,1,4]),Int64(2)) |
+--------------------------------------------------+
| [1, 3, 1, 4] |
+--------------------------------------------------+
> select array_remove_all([1, 2, NULL, 2, 4], 2);
+-----------------------------------------------------+
| array_remove_all(List([1,2,NULL,2,4]),Int64(2)) |
+-----------------------------------------------------+
| [1, NULL, 4] |
+-----------------------------------------------------+
```"#,
argument(
name = "array",
description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
),
argument(
name = "element",
description = "Element to be removed from the array."
)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub(super) struct ArrayRemoveAll {
signature: Signature,
aliases: Vec<String>,
}
impl ArrayRemoveAll {
pub fn new() -> Self {
Self {
signature: Signature::array_and_element(Volatility::Immutable),
aliases: vec!["list_remove_all".to_string()],
}
}
}
impl ScalarUDFImpl for ArrayRemoveAll {
fn name(&self) -> &str {
"array_remove_all"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
internal_err!("return_field_from_args should be used instead")
}
fn return_field_from_args(
&self,
args: datafusion_expr::ReturnFieldArgs,
) -> Result<FieldRef> {
Ok(Arc::clone(&args.arg_fields[0]))
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(array_remove_all_inner)(&args.args)
}
fn aliases(&self) -> &[String] {
&self.aliases
}
fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}
fn array_remove_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, element] = take_function_args("array_remove", args)?;
let arr_n = vec![1; array.len()];
array_remove_internal(array, element, &arr_n)
}
fn array_remove_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, element, max] = take_function_args("array_remove_n", args)?;
let arr_n = as_int64_array(max)?.values().to_vec();
array_remove_internal(array, element, &arr_n)
}
fn array_remove_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, element] = take_function_args("array_remove_all", args)?;
let arr_n = vec![i64::MAX; array.len()];
array_remove_internal(array, element, &arr_n)
}
fn array_remove_internal(
array: &ArrayRef,
element_array: &ArrayRef,
arr_n: &[i64],
) -> Result<ArrayRef> {
match array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
general_remove::<i32>(list_array, element_array, arr_n)
}
DataType::LargeList(_) => {
let list_array = array.as_list::<i64>();
general_remove::<i64>(list_array, element_array, arr_n)
}
array_type => {
exec_err!("array_remove_all does not support type '{array_type}'.")
}
}
}
/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurrences
/// of `element_array[i]`.
///
/// The type of each **element** in `list_array` must be the same as the type of
/// `element_array`. This function also handles nested arrays
/// ([`arrow::array::ListArray`] of [`arrow::array::ListArray`]s)
///
/// For example, when called to remove a list array (where each element is a
/// list of int32s, the second argument are int32 arrays, and the
/// third argument is the number of occurrences to remove
///
/// ```text
/// general_remove(
/// [1, 2, 3, 2], 2, 1 ==> [1, 3, 2] (only the first 2 is removed)
/// [4, 5, 6, 5], 5, 2 ==> [4, 6] (both 5s are removed)
/// )
/// ```
fn general_remove<OffsetSize: OffsetSizeTrait>(
list_array: &GenericListArray<OffsetSize>,
element_array: &ArrayRef,
arr_n: &[i64],
) -> Result<ArrayRef> {
let list_field = match list_array.data_type() {
DataType::List(field) | DataType::LargeList(field) => field,
_ => {
return exec_err!(
"Expected List or LargeList data type, got {:?}",
list_array.data_type()
);
}
};
let original_data = list_array.values().to_data();
// Build up the offsets for the final output array
let mut offsets = Vec::<OffsetSize>::with_capacity(arr_n.len() + 1);
offsets.push(OffsetSize::zero());
let mut mutable = MutableArrayData::with_capacities(
vec![&original_data],
false,
Capacities::Array(original_data.len()),
);
let mut valid = NullBufferBuilder::new(list_array.len());
for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
if list_array.is_null(row_index) || element_array.is_null(row_index) {
offsets.push(offsets[row_index]);
valid.append_null();
continue;
}
let start = offset_window[0].to_usize().unwrap();
let end = offset_window[1].to_usize().unwrap();
// n is the number of elements to remove in this row
let n = arr_n[row_index];
// compare each element in the list, `false` means the element matches and should be removed
let eq_array = utils::compare_element_to_list(
&list_array.value(row_index),
element_array,
row_index,
false,
)?;
let num_to_remove = eq_array.false_count();
// Fast path: no elements to remove, copy entire row
if num_to_remove == 0 {
mutable.extend(0, start, end);
offsets.push(offsets[row_index] + OffsetSize::usize_as(end - start));
valid.append_non_null();
continue;
}
// Remove at most `n` matching elements
let max_removals = n.min(num_to_remove as i64);
let mut removed = 0i64;
let mut copied = 0usize;
// marks the beginning of a range of elements pending to be copied.
let mut pending_batch_to_retain: Option<usize> = None;
for (i, keep) in eq_array.iter().enumerate() {
if keep == Some(false) && removed < max_removals {
// Flush pending batch before skipping this element
if let Some(bs) = pending_batch_to_retain {
mutable.extend(0, start + bs, start + i);
copied += i - bs;
pending_batch_to_retain = None;
}
removed += 1;
} else if pending_batch_to_retain.is_none() {
pending_batch_to_retain = Some(i);
}
}
// Flush remaining batch
if let Some(bs) = pending_batch_to_retain {
mutable.extend(0, start + bs, start + eq_array.len());
copied += eq_array.len() - bs;
}
offsets.push(offsets[row_index] + OffsetSize::usize_as(copied));
valid.append_non_null();
}
let new_values = make_array(mutable.freeze());
Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
Arc::clone(list_field),
OffsetBuffer::new(offsets.into()),
new_values,
valid.finish(),
)?))
}
#[cfg(test)]
mod tests {
use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN};
use arrow::array::{
Array, ArrayRef, AsArray, GenericListArray, ListArray, OffsetSizeTrait,
};
use arrow::datatypes::{DataType, Field, Int32Type};
use datafusion_common::ScalarValue;
use datafusion_expr::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl};
use datafusion_expr_common::columnar_value::ColumnarValue;
use std::ops::Deref;
use std::sync::Arc;
#[test]
fn test_array_remove_nullability() {
for nullability in [true, false] {
for item_nullability in [true, false] {
let input_field = Arc::new(Field::new(
"num",
DataType::new_list(DataType::Int32, item_nullability),
nullability,
));
let args_fields = vec![
Arc::clone(&input_field),
Arc::new(Field::new("a", DataType::Int32, false)),
];
let scalar_args = vec![None, Some(&ScalarValue::Int32(Some(1)))];
let result = ArrayRemove::new()
.return_field_from_args(ReturnFieldArgs {
arg_fields: &args_fields,
scalar_arguments: &scalar_args,
})
.unwrap();
assert_eq!(result, input_field);
}
}
}
#[test]
fn test_array_remove_n_nullability() {
for nullability in [true, false] {
for item_nullability in [true, false] {
let input_field = Arc::new(Field::new(
"num",
DataType::new_list(DataType::Int32, item_nullability),
nullability,
));
let args_fields = vec![
Arc::clone(&input_field),
Arc::new(Field::new("a", DataType::Int32, false)),
Arc::new(Field::new("b", DataType::Int64, false)),
];
let scalar_args = vec![
None,
Some(&ScalarValue::Int32(Some(1))),
Some(&ScalarValue::Int64(Some(1))),
];
let result = ArrayRemoveN::new()
.return_field_from_args(ReturnFieldArgs {
arg_fields: &args_fields,
scalar_arguments: &scalar_args,
})
.unwrap();
assert_eq!(result, input_field);
}
}
}
#[test]
fn test_array_remove_all_nullability() {
for nullability in [true, false] {
for item_nullability in [true, false] {
let input_field = Arc::new(Field::new(
"num",
DataType::new_list(DataType::Int32, item_nullability),
nullability,
));
let result = ArrayRemoveAll::new()
.return_field_from_args(ReturnFieldArgs {
arg_fields: &[Arc::clone(&input_field)],
scalar_arguments: &[None],
})
.unwrap();
assert_eq!(result, input_field);
}
}
}
fn ensure_field_nullability<O: OffsetSizeTrait>(
field_nullable: bool,
list: GenericListArray<O>,
) -> GenericListArray<O> {
let (field, offsets, values, nulls) = list.into_parts();
if field.is_nullable() == field_nullable {
return GenericListArray::new(field, offsets, values, nulls);
}
if !field_nullable {
assert_eq!(nulls, None);
}
let field = Arc::new(field.deref().clone().with_nullable(field_nullable));
GenericListArray::new(field, offsets, values, nulls)
}
#[test]
fn test_array_remove_non_nullable() {
let input_list = Arc::new(ensure_field_nullability(
false,
ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
]),
));
let expected_list = ensure_field_nullability(
false,
ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(([1, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
Some(([42, 55, 63, 2]).iter().copied().map(Some)),
]),
);
let element_to_remove = ScalarValue::Int32(Some(2));
assert_array_remove(input_list, expected_list, element_to_remove);
}
#[test]
fn test_array_remove_nullable() {
let input_list = Arc::new(ensure_field_nullability(
true,
ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![
Some(1),
Some(2),
Some(2),
Some(3),
None,
Some(1),
Some(4),
]),
Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
]),
));
let expected_list = ensure_field_nullability(
true,
ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(2), Some(3), None, Some(1), Some(4)]),
Some(vec![Some(42), None, Some(63), Some(2)]),
]),
);
let element_to_remove = ScalarValue::Int32(Some(2));
assert_array_remove(input_list, expected_list, element_to_remove);
}
fn assert_array_remove(
input_list: ArrayRef,
expected_list: GenericListArray<i32>,
element_to_remove: ScalarValue,
) {
assert_eq!(input_list.data_type(), expected_list.data_type());
assert_eq!(expected_list.value_type(), element_to_remove.data_type());
let input_list_len = input_list.len();
let input_list_data_type = input_list.data_type().clone();
let udf = ArrayRemove::new();
let args_fields = vec![
Arc::new(Field::new("num", input_list.data_type().clone(), false)),
Arc::new(Field::new(
"el",
element_to_remove.data_type(),
element_to_remove.is_null(),
)),
];
let scalar_args = vec![None, Some(&element_to_remove)];
let return_field = udf
.return_field_from_args(ReturnFieldArgs {
arg_fields: &args_fields,
scalar_arguments: &scalar_args,
})
.unwrap();
let result = udf
.invoke_with_args(ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(input_list),
ColumnarValue::Scalar(element_to_remove),
],
arg_fields: args_fields,
number_rows: input_list_len,
return_field,
config_options: Arc::new(Default::default()),
})
.unwrap();
assert_eq!(result.data_type(), input_list_data_type);
match result {
ColumnarValue::Array(array) => {
let result_list = array.as_list::<i32>();
assert_eq!(result_list, &expected_list);
}
_ => panic!("Expected ColumnarValue::Array"),
}
}
#[test]
fn test_array_remove_n_non_nullable() {
let input_list = Arc::new(ensure_field_nullability(
false,
ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
]),
));
let expected_list = ensure_field_nullability(
false,
ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(([1, 3, 2, 1, 4]).iter().copied().map(Some)),
Some(([42, 55, 63]).iter().copied().map(Some)),
]),
);
let element_to_remove = ScalarValue::Int32(Some(2));
assert_array_remove_n(input_list, expected_list, element_to_remove, 2);
}
#[test]
fn test_array_remove_n_nullable() {
let input_list = Arc::new(ensure_field_nullability(
true,
ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![
Some(1),
Some(2),
Some(2),
Some(3),
None,
Some(1),
Some(4),
]),
Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
]),
));
let expected_list = ensure_field_nullability(
true,
ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(3), None, Some(1), Some(4)]),
Some(vec![Some(42), None, Some(63)]),
]),
);
let element_to_remove = ScalarValue::Int32(Some(2));
assert_array_remove_n(input_list, expected_list, element_to_remove, 2);
}
fn assert_array_remove_n(
input_list: ArrayRef,
expected_list: GenericListArray<i32>,
element_to_remove: ScalarValue,
n: i64,
) {
assert_eq!(input_list.data_type(), expected_list.data_type());
assert_eq!(expected_list.value_type(), element_to_remove.data_type());
let input_list_len = input_list.len();
let input_list_data_type = input_list.data_type().clone();
let count_scalar = ScalarValue::Int64(Some(n));
let udf = ArrayRemoveN::new();
let args_fields = vec![
Arc::new(Field::new("num", input_list.data_type().clone(), false)),
Arc::new(Field::new(
"el",
element_to_remove.data_type(),
element_to_remove.is_null(),
)),
Arc::new(Field::new("count", DataType::Int64, false)),
];
let scalar_args = vec![None, Some(&element_to_remove), Some(&count_scalar)];
let return_field = udf
.return_field_from_args(ReturnFieldArgs {
arg_fields: &args_fields,
scalar_arguments: &scalar_args,
})
.unwrap();
let result = udf
.invoke_with_args(ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(input_list),
ColumnarValue::Scalar(element_to_remove),
ColumnarValue::Scalar(count_scalar),
],
arg_fields: args_fields,
number_rows: input_list_len,
return_field,
config_options: Arc::new(Default::default()),
})
.unwrap();
assert_eq!(result.data_type(), input_list_data_type);
match result {
ColumnarValue::Array(array) => {
let result_list = array.as_list::<i32>();
assert_eq!(result_list, &expected_list);
}
_ => panic!("Expected ColumnarValue::Array"),
}
}
#[test]
fn test_array_remove_all_non_nullable() {
let input_list = Arc::new(ensure_field_nullability(
false,
ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
]),
));
let expected_list = ensure_field_nullability(
false,
ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(([1, 3, 1, 4]).iter().copied().map(Some)),
Some(([42, 55, 63]).iter().copied().map(Some)),
]),
);
let element_to_remove = ScalarValue::Int32(Some(2));
assert_array_remove_all(input_list, expected_list, element_to_remove);
}
#[test]
fn test_array_remove_all_nullable() {
let input_list = Arc::new(ensure_field_nullability(
true,
ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![
Some(1),
Some(2),
Some(2),
Some(3),
None,
Some(1),
Some(4),
]),
Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
]),
));
let expected_list = ensure_field_nullability(
true,
ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(3), None, Some(1), Some(4)]),
Some(vec![Some(42), None, Some(63)]),
]),
);
let element_to_remove = ScalarValue::Int32(Some(2));
assert_array_remove_all(input_list, expected_list, element_to_remove);
}
fn assert_array_remove_all(
input_list: ArrayRef,
expected_list: GenericListArray<i32>,
element_to_remove: ScalarValue,
) {
assert_eq!(input_list.data_type(), expected_list.data_type());
assert_eq!(expected_list.value_type(), element_to_remove.data_type());
let input_list_len = input_list.len();
let input_list_data_type = input_list.data_type().clone();
let udf = ArrayRemoveAll::new();
let args_fields = vec![
Arc::new(Field::new("num", input_list.data_type().clone(), false)),
Arc::new(Field::new(
"el",
element_to_remove.data_type(),
element_to_remove.is_null(),
)),
];
let scalar_args = vec![None, Some(&element_to_remove)];
let return_field = udf
.return_field_from_args(ReturnFieldArgs {
arg_fields: &args_fields,
scalar_arguments: &scalar_args,
})
.unwrap();
let result = udf
.invoke_with_args(ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(input_list),
ColumnarValue::Scalar(element_to_remove),
],
arg_fields: args_fields,
number_rows: input_list_len,
return_field,
config_options: Arc::new(Default::default()),
})
.unwrap();
assert_eq!(result.data_type(), input_list_data_type);
match result {
ColumnarValue::Array(array) => {
let result_list = array.as_list::<i32>();
assert_eq!(result_list, &expected_list);
}
_ => panic!("Expected ColumnarValue::Array"),
}
}
}