blob: ba8e1b758d86c758d4f4799a0f09fd4b2006a0ca [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! Defines miscellaneous array kernels.
use crate::array::*;
use crate::datatypes::*;
use crate::error::{ArrowError, Result};
use crate::record_batch::RecordBatch;
use crate::{
bitmap::Bitmap,
buffer::{Buffer, MutableBuffer},
util::bit_util,
};
use std::{mem, sync::Arc};
/// Strategy for carrying validity (null bitmap) bits over from the array being
/// filtered to the filtered output.
///
/// Implementations may be no-ops when the source array has no null bitmap.
trait CopyNullBit {
/// Copies the validity bit at `source_index` of the source bitmap to the next
/// free position of the output bitmap.
fn copy_null_bit(&mut self, source_index: usize);
/// Copies `count` consecutive validity bits, starting at `source_index`.
fn copy_null_bits(&mut self, source_index: usize, count: usize);
/// Number of null entries copied so far.
fn null_count(&self) -> usize;
/// Returns the output null bitmap as an immutable `Buffer`.
fn null_buffer(&mut self) -> Buffer;
}
/// Null bitmap copier that does nothing; chosen when the array being
/// filtered carries no null bitmap.
struct NullBitNoop {}

impl NullBitNoop {
    /// Creates the (stateless) no-op copier.
    fn new() -> Self {
        Self {}
    }
}
impl CopyNullBit for NullBitNoop {
    /// Intentionally empty: there is no source bitmap to read from.
    #[inline]
    fn copy_null_bit(&mut self, _index: usize) {}

    /// Intentionally empty: there is no source bitmap to read from.
    #[inline]
    fn copy_null_bits(&mut self, _start: usize, _len: usize) {}

    /// Always zero, since no nulls are ever recorded.
    fn null_count(&self) -> usize {
        0
    }

    /// Returns an empty buffer.
    fn null_buffer(&mut self) -> Buffer {
        let empty = [0u8; 0];
        Buffer::from(empty)
    }
}
/// null bitmap copy implementation,
/// used when the filtered data array has a null bitmap
struct NullBitSetter<'a> {
// output null bitmap; created with every bit set, bits are cleared as nulls are copied
target_buffer: MutableBuffer,
// bytes of the source array's null bitmap
source_bytes: &'a [u8],
// bit position in `target_buffer` that the next copied bit is written to
target_index: usize,
// number of null (cleared) bits copied so far
null_count: usize,
}
impl<'a> NullBitSetter<'a> {
    /// Builds a copier that reads validity bits from `null_bitmap`.
    ///
    /// The output bitmap gets the same byte length as the source bitmap and
    /// starts out with every bit set; bits are cleared as nulls are copied.
    fn new(null_bitmap: &'a Bitmap) -> Self {
        let source_bytes = null_bitmap.buffer_ref().data();
        let byte_len = source_bytes.len();
        let target_buffer = MutableBuffer::new(byte_len).with_bitset(byte_len, true);
        Self {
            target_buffer,
            source_bytes,
            target_index: 0,
            null_count: 0,
        }
    }
}
impl<'a> CopyNullBit for NullBitSetter<'a> {
    /// Copies the validity bit at `source_index` to the next output position.
    #[inline]
    fn copy_null_bit(&mut self, source_index: usize) {
        // the output bitmap starts with every bit set, so only nulls need writing
        if !bit_util::get_bit(self.source_bytes, source_index) {
            bit_util::unset_bit(self.target_buffer.data_mut(), self.target_index);
            self.null_count += 1;
        }
        self.target_index += 1;
    }

    /// Copies `count` consecutive validity bits starting at `source_index`.
    #[inline]
    fn copy_null_bits(&mut self, source_index: usize, count: usize) {
        for i in 0..count {
            self.copy_null_bit(source_index + i);
        }
    }

    /// Number of null entries copied so far.
    fn null_count(&self) -> usize {
        self.null_count
    }

    /// Detaches and returns the output bitmap, sized to hold exactly the bits
    /// copied so far. The copier is left with an empty buffer afterwards.
    fn null_buffer(&mut self) -> Buffer {
        // `target_index` counts bits while `resize` takes a length in bytes,
        // so round the bit count up to whole bytes
        let byte_len = (self.target_index + 7) / 8;
        self.target_buffer.resize(byte_len);
        // use mem::replace to detach self.target_buffer from self so that it can be returned
        let target_buffer = mem::replace(&mut self.target_buffer, MutableBuffer::new(0));
        target_buffer.freeze()
    }
}
/// Picks the null bitmap copier suited to `data_array`: a real copier when the
/// array has a null bitmap, a no-op one otherwise.
fn get_null_bit_setter<'a>(data_array: &'a impl Array) -> Box<dyn CopyNullBit + 'a> {
    if let Some(null_bitmap) = data_array.data_ref().null_bitmap() {
        // only return an actual null bit copy implementation if null_bitmap is set
        Box::new(NullBitSetter::new(null_bitmap))
    } else {
        // otherwise return a no-op copy null bit implementation
        // for improved performance when the filtered array doesn't contain NULLs
        Box::new(NullBitNoop::new())
    }
}
/// Copies the fixed-width values of `data_array` selected by `filter_context`
/// into a new values buffer and returns an `ArrayDataBuilder` describing it.
///
/// The filter is processed in words of 64 bits:
/// - a word of all zeros is skipped entirely (cheap for highly selective filters),
/// - a word of all ones is copied as one contiguous run of 64 values,
/// - any other word is walked bit by bit.
///
/// Validity bits are only copied when the data array has a null bitmap.
/// Returns a `ComputeError` when the filter is longer than the data array.
fn filter_array_impl(
    filter_context: &FilterContext,
    data_array: &impl Array,
    array_type: DataType,
    value_size: usize,
) -> Result<ArrayDataBuilder> {
    if filter_context.filter_len > data_array.len() {
        return Err(ArrowError::ComputeError(
            "Filter array cannot be larger than data array".to_string(),
        ));
    }

    let selected = filter_context.filtered_count;
    let source = data_array.data_ref().buffers()[0].data();

    // output values buffer, sized up front for every selected value
    let mut values = MutableBuffer::new(selected * value_size);
    values.resize(selected * value_size);
    let output = values.data_mut();
    let mut written: usize = 0;

    let mut null_bits_box = get_null_bit_setter(data_array);
    let null_bits = null_bits_box.as_mut();

    let base = data_array.offset();
    let run_bytes = value_size * 64;

    for (word_index, &word) in filter_context.filter_u64.iter().enumerate() {
        if word == 0 {
            // nothing selected in this word
            continue;
        }
        if word == u64::MAX {
            // everything selected: copy the 64 values as a single run
            let first_row = (word_index * 64) + base;
            null_bits.copy_null_bits(first_row, 64);
            let from = first_row * value_size;
            output[written..(written + run_bytes)]
                .copy_from_slice(&source[from..(from + run_bytes)]);
            written += run_bytes;
            continue;
        }
        // mixed word: test every bit
        for (bit_index, &bit) in filter_context.filter_mask.iter().enumerate() {
            if (word & bit) == 0 {
                continue;
            }
            let row = (word_index * 64) + bit_index + base;
            null_bits.copy_null_bit(row);
            let from = row * value_size;
            output[written..(written + value_size)]
                .copy_from_slice(&source[from..(from + value_size)]);
            written += value_size;
        }
    }

    let builder = ArrayDataBuilder::new(array_type)
        .len(selected)
        .add_buffer(values.freeze());
    // only attach a null bitmap when at least one null was selected
    if null_bits.null_count() > 0 {
        Ok(builder
            .null_count(null_bits.null_count())
            .null_bit_buffer(null_bits.null_buffer()))
    } else {
        Ok(builder)
    }
}
/// FilterContext can be used to improve performance when
/// filtering multiple data arrays with the same filter array.
///
/// It holds a pre-processed copy of the filter so that this work is done once
/// rather than once per filtered array.
#[derive(Debug)]
pub struct FilterContext {
// the filter's bitmap bytes re-packed as 64-bit words, zero-padded at the end
filter_u64: Vec<u64>,
// number of entries in the filter array
filter_len: usize,
// number of set bits within the first `filter_len` bits, i.e. the output length
filtered_count: usize,
// 64 single-bit masks; entry `j` is `1 << j`
filter_mask: Vec<u64>,
}
/// Downcasts `$array` to the primitive array type `$array_type`, filters it
/// through `$context` and wraps the result in an `Arc`.
/// Panics if `$array` is not of type `$array_type`.
macro_rules! filter_primitive_array {
    ($context:expr, $array:expr, $array_type:ident) => {{
        let typed_input = $array.as_any().downcast_ref::<$array_type>().unwrap();
        let filtered = $context.filter_primitive_array(typed_input)?;
        Ok(Arc::new(filtered))
    }};
}
/// Downcasts `$array` to the dictionary array type `$array_type`, filters it
/// through `$context` and wraps the result in an `Arc`.
/// Panics if `$array` is not of type `$array_type`.
macro_rules! filter_dictionary_array {
    ($context:expr, $array:expr, $array_type:ident) => {{
        let typed_input = $array.as_any().downcast_ref::<$array_type>().unwrap();
        let filtered = $context.filter_dictionary_array(typed_input)?;
        Ok(Arc::new(filtered))
    }};
}
/// Filters a list array (`$list_type`) whose items are primitives of type
/// `$item_type`, rebuilding the selected lists through `$list_builder_type`.
///
/// Null lists stay null; every item of a selected list (nulls included) is
/// appended to the values builder.
macro_rules! filter_primitive_item_list_array {
    ($context:expr, $array:expr, $item_type:ident, $list_type:ident, $list_builder_type:ident) => {{
        let list_array = $array.as_any().downcast_ref::<$list_type>().unwrap();
        let item_builder = PrimitiveBuilder::<$item_type>::new($context.filtered_count);
        let mut list_builder = $list_builder_type::new(item_builder);
        for (batch_index, &batch) in $context.filter_u64.iter().enumerate() {
            // a batch of 64 zero bits selects nothing
            if batch == 0 {
                continue;
            }
            for (bit_index, &bit) in $context.filter_mask.iter().enumerate() {
                // skip rows whose filter bit is not set
                if (batch & bit) == 0 {
                    continue;
                }
                let row = (batch_index * 64) + bit_index;
                if list_array.is_null(row) {
                    list_builder.append(false)?;
                    continue;
                }
                let item_ref = list_array.value(row);
                let items = item_ref
                    .as_any()
                    .downcast_ref::<PrimitiveArray<$item_type>>()
                    .unwrap();
                for k in 0..items.len() {
                    if items.is_null(k) {
                        list_builder.values().append_null()?;
                    } else {
                        list_builder.values().append_value(items.value(k))?;
                    }
                }
                list_builder.append(true)?;
            }
        }
        Ok(Arc::new(list_builder.finish()))
    }};
}
/// Filters a list array (`$list_type`) whose items are stored in the
/// non-primitive array type `$item_array_type`, rebuilding the selected lists
/// with `$item_builder` wrapped in `$list_builder_type`.
///
/// Null lists stay null; every item of a selected list (nulls included) is
/// appended to the values builder.
macro_rules! filter_non_primitive_item_list_array {
    ($context:expr, $array:expr, $item_array_type:ident, $item_builder:ident, $list_type:ident, $list_builder_type:ident) => {{
        let list_array = $array.as_any().downcast_ref::<$list_type>().unwrap();
        let item_builder = $item_builder::new($context.filtered_count);
        let mut list_builder = $list_builder_type::new(item_builder);
        for (batch_index, &batch) in $context.filter_u64.iter().enumerate() {
            // a batch of 64 zero bits selects nothing
            if batch == 0 {
                continue;
            }
            for (bit_index, &bit) in $context.filter_mask.iter().enumerate() {
                // skip rows whose filter bit is not set
                if (batch & bit) == 0 {
                    continue;
                }
                let row = (batch_index * 64) + bit_index;
                if list_array.is_null(row) {
                    list_builder.append(false)?;
                    continue;
                }
                let item_ref = list_array.value(row);
                let items = item_ref
                    .as_any()
                    .downcast_ref::<$item_array_type>()
                    .unwrap();
                for k in 0..items.len() {
                    if items.is_null(k) {
                        list_builder.values().append_null()?;
                    } else {
                        list_builder.values().append_value(items.value(k))?;
                    }
                }
                list_builder.append(true)?;
            }
        }
        Ok(Arc::new(list_builder.finish()))
    }};
}
impl FilterContext {
/// Returns a new instance of FilterContext
///
/// The filter bitmap is re-packed into 64-bit words, and every bit at or past
/// `filter_array.len()` is cleared so that stray bits in the underlying buffer
/// can never select a row.
///
/// # Errors
/// Returns `ArrowError::ComputeError` when `filter_array` has a non-zero offset.
pub fn new(filter_array: &BooleanArray) -> Result<Self> {
    if filter_array.offset() > 0 {
        return Err(ArrowError::ComputeError(
            "Filter array cannot have offset > 0".to_string(),
        ));
    }
    let filter_len = filter_array.len();
    let filter_mask: Vec<u64> = (0..64).map(|x| 1u64 << x).collect();
    let filter_buffer = &filter_array.data_ref().buffers()[0];
    let filtered_count = filter_buffer.count_set_bits_offset(0, filter_len);
    let filter_bytes = filter_buffer.data();
    // pad with 1..=8 zero bytes so the byte length is a multiple of the size of u64;
    // a fully spare word produced by this padding is dropped again below
    let pad_additional_len = 8 - filter_bytes.len() % 8;
    // transmute filter_bytes to &[u64]
    let mut u64_buffer = MutableBuffer::new(filter_bytes.len() + pad_additional_len);
    u64_buffer.extend_from_slice(filter_bytes);
    u64_buffer.extend_from_slice(&[0u8; 8][..pad_additional_len]);
    let mut filter_u64 = u64_buffer.typed_data_mut::<u64>().to_owned();
    // keep only the words the filter actually covers: the padding word and any
    // words of a buffer that is longer than the filter must not select rows
    filter_u64.truncate((filter_len + 63) / 64);
    // clear the bits past `filter_len` in the word that holds the filter's last bit;
    // this is word `filter_len / 64`, which is not necessarily the last word
    // of the padded buffer
    let tail_bits = filter_len % 64;
    if tail_bits != 0 {
        if let Some(last_word) = filter_u64.get_mut(filter_len / 64) {
            *last_word &= u64::MAX >> (64 - tail_bits);
        }
    }
    Ok(FilterContext {
        filter_u64,
        filter_len,
        filtered_count,
        filter_mask,
    })
}
/// Returns a new array, containing only the elements matching the filter
pub fn filter(&self, array: &Array) -> Result<ArrayRef> {
match array.data_type() {
DataType::UInt8 => filter_primitive_array!(self, array, UInt8Array),
DataType::UInt16 => filter_primitive_array!(self, array, UInt16Array),
DataType::UInt32 => filter_primitive_array!(self, array, UInt32Array),
DataType::UInt64 => filter_primitive_array!(self, array, UInt64Array),
DataType::Int8 => filter_primitive_array!(self, array, Int8Array),
DataType::Int16 => filter_primitive_array!(self, array, Int16Array),
DataType::Int32 => filter_primitive_array!(self, array, Int32Array),
DataType::Int64 => filter_primitive_array!(self, array, Int64Array),
DataType::Float32 => filter_primitive_array!(self, array, Float32Array),
DataType::Float64 => filter_primitive_array!(self, array, Float64Array),
DataType::Boolean => {
let input_array = array.as_any().downcast_ref::<BooleanArray>().unwrap();
let mut builder = BooleanArray::builder(self.filtered_count);
for i in 0..self.filter_u64.len() {
// foreach u64 batch
let filter_batch = self.filter_u64[i];
if filter_batch == 0 {
// if batch == 0, all items are filtered out, so skip entire batch
continue;
}
for j in 0..64 {
// foreach bit in batch:
if (filter_batch & self.filter_mask[j]) != 0 {
let data_index = (i * 64) + j;
if input_array.is_null(data_index) {
builder.append_null()?;
} else {
builder.append_value(input_array.value(data_index))?;
}
}
}
}
Ok(Arc::new(builder.finish()))
},
DataType::Date32(_) => filter_primitive_array!(self, array, Date32Array),
DataType::Date64(_) => filter_primitive_array!(self, array, Date64Array),
DataType::Time32(TimeUnit::Second) => {
filter_primitive_array!(self, array, Time32SecondArray)
}
DataType::Time32(TimeUnit::Millisecond) => {
filter_primitive_array!(self, array, Time32MillisecondArray)
}
DataType::Time64(TimeUnit::Microsecond) => {
filter_primitive_array!(self, array, Time64MicrosecondArray)
}
DataType::Time64(TimeUnit::Nanosecond) => {
filter_primitive_array!(self, array, Time64NanosecondArray)
}
DataType::Duration(TimeUnit::Second) => {
filter_primitive_array!(self, array, DurationSecondArray)
}
DataType::Duration(TimeUnit::Millisecond) => {
filter_primitive_array!(self, array, DurationMillisecondArray)
}
DataType::Duration(TimeUnit::Microsecond) => {
filter_primitive_array!(self, array, DurationMicrosecondArray)
}
DataType::Duration(TimeUnit::Nanosecond) => {
filter_primitive_array!(self, array, DurationNanosecondArray)
}
DataType::Timestamp(TimeUnit::Second, _) => {
filter_primitive_array!(self, array, TimestampSecondArray)
}
DataType::Timestamp(TimeUnit::Millisecond, _) => {
filter_primitive_array!(self, array, TimestampMillisecondArray)
}
DataType::Timestamp(TimeUnit::Microsecond, _) => {
filter_primitive_array!(self, array, TimestampMicrosecondArray)
}
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
filter_primitive_array!(self, array, TimestampNanosecondArray)
}
DataType::Binary => {
let input_array = array.as_any().downcast_ref::<BinaryArray>().unwrap();
let mut values: Vec<Option<&[u8]>> = Vec::with_capacity(self.filtered_count);
for i in 0..self.filter_u64.len() {
// foreach u64 batch
let filter_batch = self.filter_u64[i];
if filter_batch == 0 {
// if batch == 0, all items are filtered out, so skip entire batch
continue;
}
for j in 0..64 {
// foreach bit in batch:
if (filter_batch & self.filter_mask[j]) != 0 {
let data_index = (i * 64) + j;
if input_array.is_null(data_index) {
values.push(None)
} else {
values.push(Some(input_array.value(data_index)))
}
}
}
}
Ok(Arc::new(BinaryArray::from(values)))
}
DataType::Utf8 => {
let input_array = array.as_any().downcast_ref::<StringArray>().unwrap();
let mut values: Vec<Option<&str>> = Vec::with_capacity(self.filtered_count);
for i in 0..self.filter_u64.len() {
// foreach u64 batch
let filter_batch = self.filter_u64[i];
if filter_batch == 0 {
// if batch == 0, all items are filtered out, so skip entire batch
continue;
}
for j in 0..64 {
// foreach bit in batch:
if (filter_batch & self.filter_mask[j]) != 0 {
let data_index = (i * 64) + j;
if input_array.is_null(data_index) {
values.push(None)
} else {
values.push(Some(input_array.value(data_index)))
}
}
}
}
Ok(Arc::new(StringArray::from(values)))
}
DataType::Dictionary(ref key_type, ref value_type) => match (key_type.as_ref(), value_type.as_ref()) {
(key_type, DataType::Utf8) => match key_type {
DataType::UInt8 => filter_dictionary_array!(self, array, UInt8DictionaryArray),
DataType::UInt16 => filter_dictionary_array!(self, array, UInt16DictionaryArray),
DataType::UInt32 => filter_dictionary_array!(self, array, UInt32DictionaryArray),
DataType::UInt64 => filter_dictionary_array!(self, array, UInt64DictionaryArray),
DataType::Int8 => filter_dictionary_array!(self, array, Int8DictionaryArray),
DataType::Int16 => filter_dictionary_array!(self, array, Int16DictionaryArray),
DataType::Int32 => filter_dictionary_array!(self, array, Int32DictionaryArray),
DataType::Int64 => filter_dictionary_array!(self, array, Int64DictionaryArray),
other => Err(ArrowError::ComputeError(format!(
"filter not supported for string dictionary with key of type {:?}",
other
)))
}
(key_type, value_type) => Err(ArrowError::ComputeError(format!(
"filter not supported for Dictionary({:?}, {:?})",
key_type, value_type
)))
}
DataType::List(dt) => match dt.data_type() {
DataType::UInt8 => {
filter_primitive_item_list_array!(self, array, UInt8Type, ListArray, ListBuilder)
}
DataType::UInt16 => {
filter_primitive_item_list_array!(self, array, UInt16Type, ListArray, ListBuilder)
}
DataType::UInt32 => {
filter_primitive_item_list_array!(self, array, UInt32Type, ListArray, ListBuilder)
}
DataType::UInt64 => {
filter_primitive_item_list_array!(self, array, UInt64Type, ListArray, ListBuilder)
}
DataType::Int8 => filter_primitive_item_list_array!(self, array, Int8Type, ListArray, ListBuilder),
DataType::Int16 => {
filter_primitive_item_list_array!(self, array, Int16Type, ListArray, ListBuilder)
}
DataType::Int32 => {
filter_primitive_item_list_array!(self, array, Int32Type, ListArray, ListBuilder)
}
DataType::Int64 => {
filter_primitive_item_list_array!(self, array, Int64Type, ListArray, ListBuilder)
}
DataType::Float32 => {
filter_primitive_item_list_array!(self, array, Float32Type, ListArray, ListBuilder)
}
DataType::Float64 => {
filter_primitive_item_list_array!(self, array, Float64Type, ListArray, ListBuilder)
}
DataType::Boolean => {
filter_primitive_item_list_array!(self, array, BooleanType, ListArray, ListBuilder)
}
DataType::Date32(_) => {
filter_primitive_item_list_array!(self, array, Date32Type, ListArray, ListBuilder)
}
DataType::Date64(_) => {
filter_primitive_item_list_array!(self, array, Date64Type, ListArray, ListBuilder)
}
DataType::Time32(TimeUnit::Second) => {
filter_primitive_item_list_array!(self, array, Time32SecondType, ListArray, ListBuilder)
}
DataType::Time32(TimeUnit::Millisecond) => {
filter_primitive_item_list_array!(self, array, Time32MillisecondType, ListArray, ListBuilder)
}
DataType::Time64(TimeUnit::Microsecond) => {
filter_primitive_item_list_array!(self, array, Time64MicrosecondType, ListArray, ListBuilder)
}
DataType::Time64(TimeUnit::Nanosecond) => {
filter_primitive_item_list_array!(self, array, Time64NanosecondType, ListArray, ListBuilder)
}
DataType::Duration(TimeUnit::Second) => {
filter_primitive_item_list_array!(self, array, DurationSecondType, ListArray, ListBuilder)
}
DataType::Duration(TimeUnit::Millisecond) => {
filter_primitive_item_list_array!(self, array, DurationMillisecondType, ListArray, ListBuilder)
}
DataType::Duration(TimeUnit::Microsecond) => {
filter_primitive_item_list_array!(self, array, DurationMicrosecondType, ListArray, ListBuilder)
}
DataType::Duration(TimeUnit::Nanosecond) => {
filter_primitive_item_list_array!(self, array, DurationNanosecondType, ListArray, ListBuilder)
}
DataType::Timestamp(TimeUnit::Second, _) => {
filter_primitive_item_list_array!(self, array, TimestampSecondType, ListArray, ListBuilder)
}
DataType::Timestamp(TimeUnit::Millisecond, _) => {
filter_primitive_item_list_array!(self, array, TimestampMillisecondType, ListArray, ListBuilder)
}
DataType::Timestamp(TimeUnit::Microsecond, _) => {
filter_primitive_item_list_array!(self, array, TimestampMicrosecondType, ListArray, ListBuilder)
}
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
filter_primitive_item_list_array!(self, array, TimestampNanosecondType, ListArray, ListBuilder)
}
DataType::Binary => filter_non_primitive_item_list_array!(
self,
array,
BinaryArray,
BinaryBuilder,
ListArray,
ListBuilder
),
DataType::LargeBinary => filter_non_primitive_item_list_array!(
self,
array,
LargeBinaryArray,
LargeBinaryBuilder,
ListArray,
ListBuilder
),
DataType::Utf8 => filter_non_primitive_item_list_array!(
self,
array,
StringArray,
StringBuilder,
ListArray
,ListBuilder
),
DataType::LargeUtf8 => filter_non_primitive_item_list_array!(
self,
array,
LargeStringArray,
LargeStringBuilder,
ListArray,
ListBuilder
),
other => {
Err(ArrowError::ComputeError(format!(
"filter not supported for List({:?})",
other
)))
}
}
DataType::LargeList(dt) => match dt.data_type() {
DataType::UInt8 => {
filter_primitive_item_list_array!(self, array, UInt8Type, LargeListArray, LargeListBuilder)
}
DataType::UInt16 => {
filter_primitive_item_list_array!(self, array, UInt16Type, LargeListArray, LargeListBuilder)
}
DataType::UInt32 => {
filter_primitive_item_list_array!(self, array, UInt32Type, LargeListArray, LargeListBuilder)
}
DataType::UInt64 => {
filter_primitive_item_list_array!(self, array, UInt64Type, LargeListArray, LargeListBuilder)
}
DataType::Int8 => filter_primitive_item_list_array!(self, array, Int8Type, LargeListArray, LargeListBuilder),
DataType::Int16 => {
filter_primitive_item_list_array!(self, array, Int16Type, LargeListArray, LargeListBuilder)
}
DataType::Int32 => {
filter_primitive_item_list_array!(self, array, Int32Type, LargeListArray, LargeListBuilder)
}
DataType::Int64 => {
filter_primitive_item_list_array!(self, array, Int64Type, LargeListArray, LargeListBuilder)
}
DataType::Float32 => {
filter_primitive_item_list_array!(self, array, Float32Type, LargeListArray, LargeListBuilder)
}
DataType::Float64 => {
filter_primitive_item_list_array!(self, array, Float64Type, LargeListArray, LargeListBuilder)
}
DataType::Boolean => {
filter_primitive_item_list_array!(self, array, BooleanType, LargeListArray, LargeListBuilder)
}
DataType::Date32(_) => {
filter_primitive_item_list_array!(self, array, Date32Type, LargeListArray, LargeListBuilder)
}
DataType::Date64(_) => {
filter_primitive_item_list_array!(self, array, Date64Type, LargeListArray, LargeListBuilder)
}
DataType::Time32(TimeUnit::Second) => {
filter_primitive_item_list_array!(self, array, Time32SecondType, LargeListArray, LargeListBuilder)
}
DataType::Time32(TimeUnit::Millisecond) => {
filter_primitive_item_list_array!(self, array, Time32MillisecondType, LargeListArray, LargeListBuilder)
}
DataType::Time64(TimeUnit::Microsecond) => {
filter_primitive_item_list_array!(self, array, Time64MicrosecondType, LargeListArray, LargeListBuilder)
}
DataType::Time64(TimeUnit::Nanosecond) => {
filter_primitive_item_list_array!(self, array, Time64NanosecondType, LargeListArray, LargeListBuilder)
}
DataType::Duration(TimeUnit::Second) => {
filter_primitive_item_list_array!(self, array, DurationSecondType, LargeListArray, LargeListBuilder)
}
DataType::Duration(TimeUnit::Millisecond) => {
filter_primitive_item_list_array!(self, array, DurationMillisecondType, LargeListArray, LargeListBuilder)
}
DataType::Duration(TimeUnit::Microsecond) => {
filter_primitive_item_list_array!(self, array, DurationMicrosecondType, LargeListArray, LargeListBuilder)
}
DataType::Duration(TimeUnit::Nanosecond) => {
filter_primitive_item_list_array!(self, array, DurationNanosecondType, LargeListArray, LargeListBuilder)
}
DataType::Timestamp(TimeUnit::Second, _) => {
filter_primitive_item_list_array!(self, array, TimestampSecondType, LargeListArray, LargeListBuilder)
}
DataType::Timestamp(TimeUnit::Millisecond, _) => {
filter_primitive_item_list_array!(self, array, TimestampMillisecondType, LargeListArray, LargeListBuilder)
}
DataType::Timestamp(TimeUnit::Microsecond, _) => {
filter_primitive_item_list_array!(self, array, TimestampMicrosecondType, LargeListArray, LargeListBuilder)
}
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
filter_primitive_item_list_array!(self, array, TimestampNanosecondType, LargeListArray, LargeListBuilder)
}
DataType::Binary => filter_non_primitive_item_list_array!(
self,
array,
BinaryArray,
BinaryBuilder,
LargeListArray,
LargeListBuilder
),
DataType::LargeBinary => filter_non_primitive_item_list_array!(
self,
array,
LargeBinaryArray,
LargeBinaryBuilder,
LargeListArray,
LargeListBuilder
),
DataType::Utf8 => filter_non_primitive_item_list_array!(
self,
array,
StringArray,
StringBuilder,
LargeListArray,
LargeListBuilder
),
DataType::LargeUtf8 => filter_non_primitive_item_list_array!(
self,
array,
LargeStringArray,
LargeStringBuilder,
LargeListArray,
LargeListBuilder
),
other => {
Err(ArrowError::ComputeError(format!(
"filter not supported for LargeList({:?})",
other
)))
}
}
other => Err(ArrowError::ComputeError(format!(
"filter not supported for {:?}",
other
))),
}
}
/// Returns a new PrimitiveArray<T> holding only the values of `data_array`
/// that are selected by the filter this context was created from.
pub fn filter_primitive_array<T>(
    &self,
    data_array: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
    T: ArrowNumericType,
{
    // every value occupies the size of the native type
    let builder =
        filter_array_impl(self, data_array, T::DATA_TYPE, mem::size_of::<T::Native>())?;
    Ok(PrimitiveArray::<T>::from(builder.build()))
}
/// Returns a new DictionaryArray<T> holding only the keys of `data_array`
/// that are selected by the filter this context was created from.
/// The dictionary values of `data_array` are attached to the result unchanged.
pub fn filter_dictionary_array<T>(
    &self,
    data_array: &DictionaryArray<T>,
) -> Result<DictionaryArray<T>>
where
    T: ArrowNumericType,
{
    // filter the keys like any other fixed-width buffer
    let keys_builder = filter_array_impl(
        self,
        data_array,
        data_array.data_type().clone(),
        mem::size_of::<T::Native>(),
    )?;
    // copy dictionary values from input array
    let filtered_data = keys_builder
        .add_child_data(data_array.values().data())
        .build();
    Ok(DictionaryArray::<T>::from(filtered_data))
}
}
/// Returns a new array, containing only the elements matching the filter.
pub fn filter(array: &Array, filter: &BooleanArray) -> Result<ArrayRef> {
FilterContext::new(filter)?.filter(array)
}
/// Returns a new PrimitiveArray<T> holding only the values of `data_array`
/// that are selected by `filter_array`.
pub fn filter_primitive_array<T>(
    data_array: &PrimitiveArray<T>,
    filter_array: &BooleanArray,
) -> Result<PrimitiveArray<T>>
where
    T: ArrowNumericType,
{
    let context = FilterContext::new(filter_array)?;
    context.filter_primitive_array(data_array)
}
/// Returns a new DictionaryArray<T> holding only the keys of `data_array`
/// that are selected by `filter_array`. The dictionary values are taken over
/// from `data_array`.
pub fn filter_dictionary_array<T>(
    data_array: &DictionaryArray<T>,
    filter_array: &BooleanArray,
) -> Result<DictionaryArray<T>>
where
    T: ArrowNumericType,
{
    let context = FilterContext::new(filter_array)?;
    context.filter_dictionary_array(data_array)
}
/// Returns a new RecordBatch whose columns contain only the rows selected by
/// `filter_array`. One FilterContext is built and shared by all columns, so
/// the filter is pre-processed only once.
pub fn filter_record_batch(
    record_batch: &RecordBatch,
    filter_array: &BooleanArray,
) -> Result<RecordBatch> {
    let context = FilterContext::new(filter_array)?;
    let columns = record_batch.columns();
    let mut filtered_columns: Vec<ArrayRef> = Vec::with_capacity(columns.len());
    for column in columns {
        // stop at the first column that cannot be filtered
        filtered_columns.push(context.filter(column.as_ref())?);
    }
    RecordBatch::try_new(record_batch.schema(), filtered_columns)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::buffer::Buffer;
use crate::datatypes::ToByteSlice;
// Generates a test named `$test` that filters the four-element array `$data`
// (expected to hold 1, 2, 3, 4) with the mask [true, false, true, false] and
// checks that exactly the values 1 and 3 remain in an `$array_type`.
macro_rules! def_temporal_test {
    ($test:ident, $array_type: ident, $data: expr) => {
        #[test]
        fn $test() {
            let input = $data;
            let mask = BooleanArray::from(vec![true, false, true, false]);
            let filtered = filter(&input, &mask).unwrap();
            let result = filtered
                .as_ref()
                .as_any()
                .downcast_ref::<$array_type>()
                .unwrap();
            assert_eq!(2, result.len());
            assert_eq!(1, result.value(0));
            assert_eq!(3, result.value(1));
        }
    };
}
// Instantiate the shared filter test once per date, time, duration and
// timestamp array type; every instance filters the values 1, 2, 3, 4.
def_temporal_test!(
test_filter_date32,
Date32Array,
Date32Array::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
test_filter_date64,
Date64Array,
Date64Array::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
test_filter_time32_second,
Time32SecondArray,
Time32SecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
test_filter_time32_millisecond,
Time32MillisecondArray,
Time32MillisecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
test_filter_time64_microsecond,
Time64MicrosecondArray,
Time64MicrosecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
test_filter_time64_nanosecond,
Time64NanosecondArray,
Time64NanosecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
test_filter_duration_second,
DurationSecondArray,
DurationSecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
test_filter_duration_millisecond,
DurationMillisecondArray,
DurationMillisecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
test_filter_duration_microsecond,
DurationMicrosecondArray,
DurationMicrosecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
test_filter_duration_nanosecond,
DurationNanosecondArray,
DurationNanosecondArray::from(vec![1, 2, 3, 4])
);
// timestamp arrays are built with `from_vec`, passing `None` as second argument
def_temporal_test!(
test_filter_timestamp_second,
TimestampSecondArray,
TimestampSecondArray::from_vec(vec![1, 2, 3, 4], None)
);
def_temporal_test!(
test_filter_timestamp_millisecond,
TimestampMillisecondArray,
TimestampMillisecondArray::from_vec(vec![1, 2, 3, 4], None)
);
def_temporal_test!(
test_filter_timestamp_microsecond,
TimestampMicrosecondArray,
TimestampMicrosecondArray::from_vec(vec![1, 2, 3, 4], None)
);
def_temporal_test!(
test_filter_timestamp_nanosecond,
TimestampNanosecondArray,
TimestampNanosecondArray::from_vec(vec![1, 2, 3, 4], None)
);
// Filtering a plain Int32 array keeps exactly the selected values, in order.
#[test]
fn test_filter_array() {
    let input = Int32Array::from(vec![5, 6, 7, 8, 9]);
    let mask = BooleanArray::from(vec![true, false, false, true, false]);
    let filtered = filter(&input, &mask).unwrap();
    let result = filtered
        .as_ref()
        .as_any()
        .downcast_ref::<Int32Array>()
        .unwrap();
    assert_eq!(2, result.len());
    assert_eq!(5, result.value(0));
    assert_eq!(8, result.value(1));
}
// Filtering a sliced data array must honour the slice offset.
#[test]
fn test_filter_array_slice() {
    let sliced = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
    let input = sliced.as_ref();
    // the filter itself is left unsliced:
    // filtering with sliced filter array is not currently supported
    let mask = BooleanArray::from(vec![true, false, false, true]);
    let filtered = filter(input, &mask).unwrap();
    let result = filtered
        .as_ref()
        .as_any()
        .downcast_ref::<Int32Array>()
        .unwrap();
    assert_eq!(2, result.len());
    assert_eq!(6, result.value(0));
    assert_eq!(9, result.value(1));
}
// Exercises the "all 0s" branch of the filter algorithm: the first 64 filter
// bits are unset, so the whole first batch is skipped.
#[test]
fn test_filter_array_low_density() {
    let mut data_values: Vec<i32> = (1..=65).collect();
    let mut filter_values: Vec<bool> = (1..=65).map(|i| matches!(i % 65, 0)).collect();
    // two more entries following the first batch
    data_values.extend_from_slice(&[66, 67]);
    filter_values.extend_from_slice(&[false, true]);
    let input = Int32Array::from(data_values);
    let mask = BooleanArray::from(filter_values);
    let filtered = filter(&input, &mask).unwrap();
    let result = filtered
        .as_ref()
        .as_any()
        .downcast_ref::<Int32Array>()
        .unwrap();
    assert_eq!(2, result.len());
    assert_eq!(65, result.value(0));
    assert_eq!(67, result.value(1));
}
// Exercises the "all 1s" branch of the filter algorithm: the first 64 filter
// bits are set, so the first batch is copied in one go, nulls included.
#[test]
fn test_filter_array_high_density() {
    let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
    let mut filter_values: Vec<bool> =
        (1..=65).map(|i| !matches!(i % 65, 0)).collect();
    // the second data value becomes a null
    data_values[1] = None;
    // four more entries following the first batch
    data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
    filter_values.extend_from_slice(&[false, true, true, true]);
    let input = Int32Array::from(data_values);
    let mask = BooleanArray::from(filter_values);
    let filtered = filter(&input, &mask).unwrap();
    let result = filtered
        .as_ref()
        .as_any()
        .downcast_ref::<Int32Array>()
        .unwrap();
    assert_eq!(67, result.len());
    assert_eq!(3, result.null_count());
    assert_eq!(1, result.value(0));
    assert_eq!(true, result.is_null(1));
    assert_eq!(64, result.value(63));
    assert_eq!(true, result.is_null(64));
    assert_eq!(67, result.value(65));
}
// Filtering a string array keeps exactly the selected strings, in order.
#[test]
fn test_filter_string_array() {
    let input = StringArray::from(vec!["hello", " ", "world", "!"]);
    let mask = BooleanArray::from(vec![true, false, true, false]);
    let filtered = filter(&input, &mask).unwrap();
    let result = filtered
        .as_ref()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(2, result.len());
    assert_eq!("hello", result.value(0));
    assert_eq!("world", result.value(1));
}
// A selected null in a primitive array stays null in the output.
#[test]
fn test_filter_primative_array_with_null() {
    let input = Int32Array::from(vec![Some(5), None]);
    let mask = BooleanArray::from(vec![false, true]);
    let filtered = filter(&input, &mask).unwrap();
    let result = filtered
        .as_ref()
        .as_any()
        .downcast_ref::<Int32Array>()
        .unwrap();
    assert_eq!(1, result.len());
    assert_eq!(true, result.is_null(0));
}
#[test]
fn test_filter_string_array_with_null() {
    // Select one valid string (index 0) and one null entry (index 3).
    let strings = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
    let mask = BooleanArray::from(vec![true, false, false, true]);

    let filtered = filter(&strings, &mask).unwrap();
    let result = filtered.as_any().downcast_ref::<StringArray>().unwrap();

    assert_eq!(result.len(), 2);
    assert_eq!(result.value(0), "hello");
    assert!(!result.is_null(0));
    assert!(result.is_null(1));
}
#[test]
fn test_filter_binary_array_with_null() {
    // Same selection as the string variant, but over raw byte values.
    let bytes: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
    let input = BinaryArray::from(bytes);
    let mask = BooleanArray::from(vec![true, false, false, true]);

    let filtered = filter(&input, &mask).unwrap();
    let result = filtered.as_any().downcast_ref::<BinaryArray>().unwrap();

    assert_eq!(2, result.len());
    assert_eq!(b"hello", result.value(0));
    assert!(!result.is_null(0));
    assert!(result.is_null(1));
}
#[test]
fn test_filter_array_slice_with_null() {
    // The data array is a slice (offset 1, length 4) of a longer array, so
    // the values seen by the filter are [null, 7, 8, 9].
    let sliced =
        Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
    let mask = BooleanArray::from(vec![true, false, false, true]);
    // filtering with sliced filter array is not currently supported
    // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4);
    // let b = b_slice.as_any().downcast_ref().unwrap();

    let filtered = filter(sliced.as_ref(), &mask).unwrap();
    let result = filtered.as_any().downcast_ref::<Int32Array>().unwrap();

    assert_eq!(result.len(), 2);
    assert!(result.is_null(0));
    assert!(!result.is_null(1));
    assert_eq!(result.value(1), 9);
}
#[test]
fn test_filter_dictionary_array() {
    let entries = vec![Some("hello"), None, Some("world"), Some("!")];
    let dict: Int8DictionaryArray = entries.iter().copied().collect();
    let mask = BooleanArray::from(vec![false, true, true, false]);

    let filtered = filter(&dict, &mask).unwrap();
    let result = filtered
        .as_any()
        .downcast_ref::<Int8DictionaryArray>()
        .unwrap();

    // the dictionary values are carried over unchanged into the filtered array
    let dict_values = result.values();
    let dict_strings = dict_values
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(dict_strings.len(), 3);

    // only the keys are filtered
    assert_eq!(result.len(), 2);
    assert!(result.is_null(0));
    let second_key = result.keys().value(1) as usize;
    assert_eq!(dict_strings.value(second_key), "world");
}
#[test]
fn test_filter_string_array_with_negated_boolean_array() {
    let strings = StringArray::from(vec!["hello", " ", "world", "!"]);

    // Build [false, true, false, true] and negate it, so the filter actually
    // applied is [true, false, true, false].
    let mut builder = BooleanBuilder::new(2);
    for flag in [false, true, false, true].iter() {
        builder.append_value(*flag).unwrap();
    }
    let inverted = builder.finish();
    let mask = crate::compute::not(&inverted).unwrap();

    let filtered = filter(&strings, &mask).unwrap();
    let kept = filtered.as_any().downcast_ref::<StringArray>().unwrap();

    assert_eq!(kept.len(), 2);
    assert_eq!(kept.value(0), "hello");
    assert_eq!(kept.value(1), "world");
}
#[test]
fn test_filter_list_array() {
    // child values referenced by the list entries: 0 through 7
    let child_data = ArrayData::builder(DataType::Int32)
        .len(8)
        .add_buffer(Buffer::from(&[0, 1, 2, 3, 4, 5, 6, 7].to_byte_slice()))
        .build();
    // i64 offsets delimiting the four list entries
    let offsets = Buffer::from(&[0i64, 3, 6, 8, 8].to_byte_slice());
    let item_type = NullableDataType::new(DataType::Int32, false);
    let list_data = ArrayData::builder(DataType::LargeList(Box::new(item_type)))
        .len(4)
        .add_buffer(offsets)
        .add_child_data(child_data)
        // validity bits: the first three entries are valid, the fourth is null
        .null_bit_buffer(Buffer::from([0b00000111]))
        .build();

    // input = [[0, 1, 2], [3, 4, 5], [6, 7], null]
    let input = LargeListArray::from(list_data);
    let mask = BooleanArray::from(vec![false, true, false, true]);

    let filtered = filter(&input, &mask).unwrap();
    let result = filtered
        .as_any()
        .downcast_ref::<LargeListArray>()
        .unwrap();

    assert_eq!(result.value_type(), DataType::Int32);

    // expected output: [[3, 4, 5], null]
    assert_eq!(result.len(), 2);
    assert_eq!(result.null_count(), 1);
    assert!(result.is_null(1));

    assert_eq!(result.value_offset(0), 0);
    assert_eq!(result.value_length(0), 3);
    assert_eq!(result.value_offset(1), 3);
    assert_eq!(result.value_length(1), 0);

    // the child buffer holds only the surviving values and the offsets are
    // rebased to start at zero
    let child_buffer = result.values().data().buffers()[0].clone();
    assert_eq!(Buffer::from(&[3, 4, 5].to_byte_slice()), child_buffer);
    let offset_buffer = result.data().buffers()[0].clone();
    assert_eq!(Buffer::from(&[0i64, 3, 3].to_byte_slice()), offset_buffer);

    let first_entry = result.value(0);
    let first_entry = first_entry.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(first_entry.len(), 3);
    assert_eq!(first_entry.null_count(), 0);
    assert_eq!(first_entry, &Int32Array::from(vec![3, 4, 5]));
}
}