blob: c5efaddd971b4c6473c38e5b14f5ccbc04684f06 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
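//! Defines the `filter` kernels, which select the slots of arrays (or of every column of a
//! `RecordBatch`) where a boolean predicate is true.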
use crate::buffer::buffer_bin_and;
use crate::datatypes::DataType;
use crate::error::Result;
use crate::record_batch::RecordBatch;
use crate::{array::*, util::bit_chunk_iterator::BitChunkIterator};
use std::iter::Enumerate;
/// Function that can filter arbitrary arrays
pub type Filter<'a> = Box<Fn(&ArrayData) -> ArrayData + 'a>;
/// Internal state of [SlicesIterator].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Iterating over the bits of a mixed chunk mask (`u64`), one slot per step.
    Bits(u64),
    /// Iterating over whole chunks, 64 slots per step.
    Chunks,
    /// Iterating over the remaining bits that do not fill a whole chunk, one slot per step.
    Remainder,
    /// Nothing more to iterate.
    Finish,
}
/// An iterator of `(usize, usize)` pairs, each representing a half-open interval
/// `[start, end)` of consecutive slots of a [BooleanArray] whose values are true. Each
/// interval corresponds to a contiguous region of memory to be "taken" from an array to be
/// filtered.
///
/// The filter's value bits are scanned 64 slots at a time. All-zero and all-one chunks are
/// handled as a whole; only mixed chunks and the trailing remainder are scanned bit by bit.
#[derive(Debug)]
pub struct SlicesIterator<'a> {
    /// Enumerated 64-bit chunks of the filter's value bits, yielded as `(chunk index, mask)`.
    iter: Enumerate<BitChunkIterator<'a>>,
    /// Which phase of the scan the iterator is currently in.
    state: State,
    /// The filter array being scanned.
    filter: &'a BooleanArray,
    /// Bits of the trailing slots that do not fill a whole 64-bit chunk.
    remainder_mask: u64,
    /// Number of meaningful bits in `remainder_mask`.
    remainder_len: usize,
    /// Number of full 64-bit chunks.
    chunk_len: usize,
    /// Length (in slots) of the region currently being accumulated.
    len: usize,
    /// First slot of the region currently being accumulated.
    start: usize,
    /// Whether a run of set bits is currently open.
    on_region: bool,
    /// Index of the chunk currently being scanned.
    current_chunk: usize,
    /// Bit position within the current chunk (or the remainder) being scanned.
    current_bit: usize,
}
impl<'a> SlicesIterator<'a> {
    /// Creates an iterator over the contiguous runs of `true` slots in `filter`.
    ///
    /// Bits are read from the filter's first buffer starting at `filter.offset()` for
    /// `filter.len()` slots. The null buffer is not consulted: the value bit stored in each
    /// slot is used as-is.
    pub fn new(filter: &'a BooleanArray) -> Self {
        let values = &filter.data_ref().buffers()[0];
        let chunks = values.bit_chunks(filter.offset(), filter.len());
        Self {
            iter: chunks.iter().enumerate(),
            state: State::Chunks,
            filter,
            remainder_len: chunks.remainder_len(),
            chunk_len: chunks.chunk_len(),
            remainder_mask: chunks.remainder_bits(),
            len: 0,
            start: 0,
            on_region: false,
            current_chunk: 0,
            current_bit: 0,
        }
    }

    /// Counts the number of set bits in the filter array (honouring its offset).
    fn filter_count(&self) -> usize {
        let values = self.filter.values();
        values.count_set_bits_offset(self.filter.offset(), self.filter.len())
    }

    /// Absolute slot index of the bit currently being looked at.
    #[inline]
    fn current_start(&self) -> usize {
        self.current_chunk * 64 + self.current_bit
    }

    /// Scans bits `current_bit..max` of `mask`, opening or extending the current region on
    /// set bits.
    ///
    /// Returns the region as soon as a cleared bit closes it, leaving `current_bit` just past
    /// that bit so the next call resumes there. Returns `None` once `max` is reached and
    /// resets `current_bit` to 0; a region still open at that point carries over to the next
    /// chunk.
    #[inline]
    fn iterate_bits(&mut self, mask: u64, max: usize) -> Option<(usize, usize)> {
        while self.current_bit < max {
            if (mask & (1u64 << self.current_bit)) != 0 {
                if !self.on_region {
                    self.start = self.current_start();
                    self.on_region = true;
                }
                self.len += 1;
            } else if self.on_region {
                let result = (self.start, self.start + self.len);
                self.len = 0;
                self.on_region = false;
                self.current_bit += 1;
                return Some(result);
            }
            self.current_bit += 1;
        }
        self.current_bit = 0;
        None
    }

    /// Iterates over whole 64-bit chunks.
    ///
    /// All-zero chunks close any open region and yield it; all-one chunks open or extend a
    /// region by 64 slots. On the first mixed chunk, switches to [State::Bits] and returns
    /// `None`. When the chunks are exhausted, switches to [State::Remainder] and returns
    /// `None`.
    #[inline]
    fn iterate_chunks(&mut self) -> Option<(usize, usize)> {
        // `while let` instead of `for`: the body calls `self.current_start()`, which borrows
        // all of `self` and would conflict with a loop-long borrow of `self.iter`.
        while let Some((i, mask)) = self.iter.next() {
            self.current_chunk = i;
            match mask {
                0 => {
                    if self.on_region {
                        let result = (self.start, self.start + self.len);
                        self.len = 0;
                        self.on_region = false;
                        return Some(result);
                    }
                }
                u64::MAX => {
                    if !self.on_region {
                        self.start = self.current_start();
                        self.on_region = true;
                    }
                    self.len += 64;
                }
                _ => {
                    // there is a chunk that has a non-trivial mask => iterate over bits.
                    self.state = State::Bits(mask);
                    return None;
                }
            }
        }
        // no more chunks => start iterating over the remainder
        self.current_chunk = self.chunk_len;
        self.state = State::Remainder;
        None
    }
}
impl<'a> Iterator for SlicesIterator<'a> {
    type Item = (usize, usize);

    /// Drives the chunk / bits / remainder state machine until a slice is produced or the
    /// filter is exhausted.
    ///
    /// Each phase that finishes without producing a slice moves to the next phase and the
    /// loop goes around again. Once the remainder is consumed, a region that is still open
    /// is emitted as the final slice.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            match self.state {
                State::Finish => return None,
                State::Chunks => {
                    if let Some(slice) = self.iterate_chunks() {
                        return Some(slice);
                    }
                    // no slice from whole chunks: move on to whatever phase was selected
                    self.current_bit = 0;
                }
                State::Bits(mask) => {
                    if let Some(slice) = self.iterate_bits(mask, 64) {
                        return Some(slice);
                    }
                    // the mixed chunk is done: resume scanning whole chunks
                    self.state = State::Chunks;
                }
                State::Remainder => {
                    if let Some(slice) =
                        self.iterate_bits(self.remainder_mask, self.remainder_len)
                    {
                        return Some(slice);
                    }
                    self.state = State::Finish;
                    return if self.on_region {
                        Some((self.start, self.start + self.len))
                    } else {
                        None
                    };
                }
            }
        }
    }
}
/// Returns a prepared function optimized to filter multiple arrays.
/// Creating this function requires time, but using it is faster than [filter] when the
/// same filter needs to be applied to multiple arrays (e.g. a multi-column `RecordBatch`).
/// WARNING: the nulls of `filter` are ignored and the value on its slot is considered.
/// Therefore, it is considered undefined behavior to pass `filter` with null values.
pub fn build_filter(filter: &BooleanArray) -> Result<Filter> {
let iter = SlicesIterator::new(filter);
let filter_count = iter.filter_count();
let chunks = iter.collect::<Vec<_>>();
Ok(Box::new(move |array: &ArrayData| {
match filter_count {
// return all
len if len == array.len() => array.clone(),
0 => ArrayData::new_empty(array.data_type()),
_ => {
let mut mutable = MutableArrayData::new(vec![array], false, filter_count);
chunks
.iter()
.for_each(|(start, end)| mutable.extend(0, *start, *end));
mutable.freeze()
}
}
}))
}
/// Removes null values by doing a bitmask AND of the validity bits with the boolean values.
///
/// Both buffers are read at `filter.offset()` for `filter.len()` slots. The result is a
/// new, null-free [BooleanArray] (offset 0) in which every null slot of `filter` is `false`.
///
/// # Panics
/// Panics if `filter` has no null buffer. Callers only invoke this after checking
/// `null_count() > 0`.
fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
    let validity = filter.data_ref().null_buffer().unwrap();
    let values = filter.values();
    let (start, slots) = (filter.offset(), filter.len());
    let combined = buffer_bin_and(values, start, validity, start, slots);
    BooleanArray::from(
        ArrayData::builder(DataType::Boolean)
            .len(slots)
            .add_buffer(combined)
            .build(),
    )
}
/// Filters an [Array], returning elements matching the filter (i.e. where the values are true).
///
/// # Example
/// ```rust
/// # use arrow::array::{Int32Array, BooleanArray};
/// # use arrow::error::Result;
/// # use arrow::compute::kernels::filter::filter;
/// # fn main() -> Result<()> {
/// let array = Int32Array::from(vec![5, 6, 7, 8, 9]);
/// let filter_array = BooleanArray::from(vec![true, false, false, true, false]);
/// let c = filter(&array, &filter_array)?;
/// let c = c.as_any().downcast_ref::<Int32Array>().unwrap();
/// assert_eq!(c, &Int32Array::from(vec![5, 8]));
/// # Ok(())
/// # }
/// ```
pub fn filter(array: &Array, predicate: &BooleanArray) -> Result<ArrayRef> {
if predicate.null_count() > 0 {
// this greatly simplifies subsequent filtering code
// now we only have a boolean mask to deal with
let predicate = prep_null_mask_filter(predicate);
return filter(array, &predicate);
}
let iter = SlicesIterator::new(predicate);
let filter_count = iter.filter_count();
match filter_count {
0 => {
// return empty
Ok(new_empty_array(array.data_type()))
}
len if len == array.len() => {
// return all
let data = array.data().clone();
Ok(make_array(data))
}
_ => {
// actually filter
let mut mutable =
MutableArrayData::new(vec![array.data_ref()], false, filter_count);
iter.for_each(|(start, end)| mutable.extend(0, start, end));
let data = mutable.freeze();
Ok(make_array(data))
}
}
}
/// Returns a new [RecordBatch] whose columns contain only the rows where `predicate` is true.
///
/// Null slots of `predicate` are treated as `false`. The filter is prepared once with
/// [build_filter] and then applied to every column, and the original schema is kept.
pub fn filter_record_batch(
    record_batch: &RecordBatch,
    predicate: &BooleanArray,
) -> Result<RecordBatch> {
    if predicate.null_count() == 0 {
        let apply = build_filter(predicate)?;
        let columns = record_batch
            .columns()
            .iter()
            .map(|column| make_array(apply(&column.data())))
            .collect();
        RecordBatch::try_new(record_batch.schema(), columns)
    } else {
        // Masking the nulls up front leaves a plain boolean mask for the code above.
        let masked = prep_null_mask_filter(predicate);
        filter_record_batch(record_batch, &masked)
    }
}
// Unit tests for the `filter` kernel and `SlicesIterator`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datatypes::Int64Type;
    use crate::{
        buffer::Buffer,
        datatypes::{DataType, Field},
    };

    // Generates a test that filters `$data` (expected to hold the values 1, 2, 3, 4) with the
    // mask [true, false, true, false] and checks that exactly the values 1 and 3 remain.
    macro_rules! def_temporal_test {
        ($test:ident, $array_type: ident, $data: expr) => {
            #[test]
            fn $test() {
                let a = $data;
                let b = BooleanArray::from(vec![true, false, true, false]);
                let c = filter(&a, &b).unwrap();
                let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap();
                assert_eq!(2, d.len());
                assert_eq!(1, d.value(0));
                assert_eq!(3, d.value(1));
            }
        };
    }

    def_temporal_test!(
        test_filter_date32,
        Date32Array,
        Date32Array::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_date64,
        Date64Array,
        Date64Array::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time32_second,
        Time32SecondArray,
        Time32SecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time32_millisecond,
        Time32MillisecondArray,
        Time32MillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time64_microsecond,
        Time64MicrosecondArray,
        Time64MicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time64_nanosecond,
        Time64NanosecondArray,
        Time64NanosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_second,
        DurationSecondArray,
        DurationSecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_millisecond,
        DurationMillisecondArray,
        DurationMillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_microsecond,
        DurationMicrosecondArray,
        DurationMicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_nanosecond,
        DurationNanosecondArray,
        DurationNanosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_second,
        TimestampSecondArray,
        TimestampSecondArray::from_vec(vec![1, 2, 3, 4], None)
    );
    def_temporal_test!(
        test_filter_timestamp_millisecond,
        TimestampMillisecondArray,
        TimestampMillisecondArray::from_vec(vec![1, 2, 3, 4], None)
    );
    def_temporal_test!(
        test_filter_timestamp_microsecond,
        TimestampMicrosecondArray,
        TimestampMicrosecondArray::from_vec(vec![1, 2, 3, 4], None)
    );
    def_temporal_test!(
        test_filter_timestamp_nanosecond,
        TimestampNanosecondArray,
        TimestampNanosecondArray::from_vec(vec![1, 2, 3, 4], None)
    );

    // Filtering a sliced (non-zero offset) data array.
    #[test]
    fn test_filter_array_slice() {
        let a_slice = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
        let a = a_slice.as_ref();
        let b = BooleanArray::from(vec![true, false, false, true]);
        // NOTE(review): this variant was previously described as unsupported, but
        // `SlicesIterator::new` reads the filter bits starting at `filter.offset()`;
        // consider enabling the sliced-predicate version below and verifying it.
        // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4);
        // let b = b_slice.as_any().downcast_ref().unwrap();
        let c = filter(a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!(6, d.value(0));
        assert_eq!(9, d.value(1));
    }

    #[test]
    fn test_filter_array_low_density() {
        // this test exercises the all 0's branch of the filter algorithm
        // (the first 64 filter slots are all false, so the first chunk mask is 0)
        let mut data_values = (1..=65).collect::<Vec<i32>>();
        let mut filter_values =
            (1..=65).map(|i| matches!(i % 65, 0)).collect::<Vec<bool>>();
        // set up two more values after the batch
        data_values.extend_from_slice(&[66, 67]);
        filter_values.extend_from_slice(&[false, true]);
        let a = Int32Array::from(data_values);
        let b = BooleanArray::from(filter_values);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!(65, d.value(0));
        assert_eq!(67, d.value(1));
    }

    #[test]
    fn test_filter_array_high_density() {
        // this test exercises the all 1's branch of the filter algorithm
        // (the first 64 filter slots are all true, so the first chunk mask is all ones)
        let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
        let mut filter_values = (1..=65)
            .map(|i| !matches!(i % 65, 0))
            .collect::<Vec<bool>>();
        // set second data value to null
        data_values[1] = None;
        // set up two more values after the batch
        data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
        filter_values.extend_from_slice(&[false, true, true, true]);
        let a = Int32Array::from(data_values);
        let b = BooleanArray::from(filter_values);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(67, d.len());
        assert_eq!(3, d.null_count());
        assert_eq!(1, d.value(0));
        assert_eq!(true, d.is_null(1));
        assert_eq!(64, d.value(63));
        assert_eq!(true, d.is_null(64));
        assert_eq!(67, d.value(65));
    }

    #[test]
    fn test_filter_string_array_simple() {
        let a = StringArray::from(vec!["hello", " ", "world", "!"]);
        let b = BooleanArray::from(vec![true, false, true, false]);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!("hello", d.value(0));
        assert_eq!("world", d.value(1));
    }

    // Null data slots that are selected stay null in the output.
    #[test]
    fn test_filter_primative_array_with_null() {
        let a = Int32Array::from(vec![Some(5), None]);
        let b = BooleanArray::from(vec![false, true]);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(1, d.len());
        assert_eq!(true, d.is_null(0));
    }

    #[test]
    fn test_filter_string_array_with_null() {
        let a = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
        let b = BooleanArray::from(vec![true, false, false, true]);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!("hello", d.value(0));
        assert_eq!(false, d.is_null(0));
        assert_eq!(true, d.is_null(1));
    }

    #[test]
    fn test_filter_binary_array_with_null() {
        let data: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
        let a = BinaryArray::from(data);
        let b = BooleanArray::from(vec![true, false, false, true]);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<BinaryArray>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!(b"hello", d.value(0));
        assert_eq!(false, d.is_null(0));
        assert_eq!(true, d.is_null(1));
    }

    // Filtering a sliced data array that contains nulls.
    #[test]
    fn test_filter_array_slice_with_null() {
        let a_slice =
            Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
        let a = a_slice.as_ref();
        let b = BooleanArray::from(vec![true, false, false, true]);
        // NOTE(review): this variant was previously described as unsupported, but
        // `SlicesIterator::new` reads the filter bits starting at `filter.offset()`;
        // consider enabling the sliced-predicate version below and verifying it.
        // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4);
        // let b = b_slice.as_any().downcast_ref().unwrap();
        let c = filter(a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!(true, d.is_null(0));
        assert_eq!(false, d.is_null(1));
        assert_eq!(9, d.value(1));
    }

    #[test]
    fn test_filter_dictionary_array() {
        let values = vec![Some("hello"), None, Some("world"), Some("!")];
        let a: Int8DictionaryArray = values.iter().copied().collect();
        let b = BooleanArray::from(vec![false, true, true, false]);
        let c = filter(&a, &b).unwrap();
        let d = c
            .as_ref()
            .as_any()
            .downcast_ref::<Int8DictionaryArray>()
            .unwrap();
        let value_array = d.values();
        let values = value_array.as_any().downcast_ref::<StringArray>().unwrap();
        // values are cloned in the filtered dictionary array
        assert_eq!(3, values.len());
        // but keys are filtered
        assert_eq!(2, d.len());
        assert_eq!(true, d.is_null(0));
        assert_eq!("world", values.value(d.keys().value(1) as usize));
    }

    // Filtering with a predicate produced by another kernel (`not`).
    #[test]
    fn test_filter_string_array_with_negated_boolean_array() {
        let a = StringArray::from(vec!["hello", " ", "world", "!"]);
        let mut bb = BooleanBuilder::new(2);
        bb.append_value(false).unwrap();
        bb.append_value(true).unwrap();
        bb.append_value(false).unwrap();
        bb.append_value(true).unwrap();
        let b = bb.finish();
        let b = crate::compute::not(&b).unwrap();
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!("hello", d.value(0));
        assert_eq!("world", d.value(1));
    }

    // Filtering a nested (LargeList) array, including a null list entry.
    #[test]
    fn test_filter_list_array() {
        let value_data = ArrayData::builder(DataType::Int32)
            .len(8)
            .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7]))
            .build();
        let value_offsets = Buffer::from_slice_ref(&[0i64, 3, 6, 8, 8]);
        let list_data_type =
            DataType::LargeList(Box::new(Field::new("item", DataType::Int32, false)));
        let list_data = ArrayData::builder(list_data_type)
            .len(4)
            .add_buffer(value_offsets)
            .add_child_data(value_data)
            .null_bit_buffer(Buffer::from([0b00000111]))
            .build();
        // a = [[0, 1, 2], [3, 4, 5], [6, 7], null]
        let a = LargeListArray::from(list_data);
        let b = BooleanArray::from(vec![false, true, false, true]);
        let result = filter(&a, &b).unwrap();
        // expected: [[3, 4, 5], null]
        let value_data = ArrayData::builder(DataType::Int32)
            .len(3)
            .add_buffer(Buffer::from_slice_ref(&[3, 4, 5]))
            .build();
        let value_offsets = Buffer::from_slice_ref(&[0i64, 3, 3]);
        let list_data_type =
            DataType::LargeList(Box::new(Field::new("item", DataType::Int32, false)));
        let expected = ArrayData::builder(list_data_type)
            .len(2)
            .add_buffer(value_offsets)
            .add_child_data(value_data)
            .null_bit_buffer(Buffer::from([0b00000001]))
            .build();
        assert_eq!(&make_array(expected), &result);
    }

    // A single set bit inside a mixed 64-bit chunk yields one one-slot interval.
    #[test]
    fn test_slice_iterator_bits() {
        let filter_values = (0..64).map(|i| i == 1).collect::<Vec<bool>>();
        let filter = BooleanArray::from(filter_values);
        let iter = SlicesIterator::new(&filter);
        let filter_count = iter.filter_count();
        let chunks = iter.collect::<Vec<_>>();
        assert_eq!(chunks, vec![(1, 2)]);
        assert_eq!(filter_count, 1);
    }

    // A single cleared bit splits the chunk into two intervals, the last running to the end.
    #[test]
    fn test_slice_iterator_bits1() {
        let filter_values = (0..64).map(|i| i != 1).collect::<Vec<bool>>();
        let filter = BooleanArray::from(filter_values);
        let iter = SlicesIterator::new(&filter);
        let filter_count = iter.filter_count();
        let chunks = iter.collect::<Vec<_>>();
        assert_eq!(chunks, vec![(0, 1), (2, 64)]);
        assert_eq!(filter_count, 64 - 1);
    }

    // Intervals spanning chunk boundaries and the trailing remainder.
    #[test]
    fn test_slice_iterator_chunk_and_bits() {
        let filter_values = (0..130).map(|i| i % 62 != 0).collect::<Vec<bool>>();
        let filter = BooleanArray::from(filter_values);
        let iter = SlicesIterator::new(&filter);
        let filter_count = iter.filter_count();
        let chunks = iter.collect::<Vec<_>>();
        assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]);
        assert_eq!(filter_count, 61 + 61 + 5);
    }

    // A null predicate slot behaves like `false` (null == null yields a null predicate slot).
    #[test]
    fn test_null_mask() -> Result<()> {
        use crate::compute::kernels::comparison;
        let a: PrimitiveArray<Int64Type> =
            PrimitiveArray::from(vec![Some(1), Some(2), None]);

        let mask0 = comparison::eq(&a, &a)?;
        let out0 = filter(&a, &mask0)?;
        let out_arr0 = out0
            .as_any()
            .downcast_ref::<PrimitiveArray<Int64Type>>()
            .unwrap();

        let mask1 = BooleanArray::from(vec![Some(true), Some(true), None]);
        let out1 = filter(&a, &mask1)?;
        let out_arr1 = out1
            .as_any()
            .downcast_ref::<PrimitiveArray<Int64Type>>()
            .unwrap();
        assert_eq!(mask0, mask1);
        assert_eq!(out_arr0, out_arr1);
        Ok(())
    }

    // The all-true and all-false shortcuts in `filter`.
    #[test]
    fn test_fast_path() -> Result<()> {
        let a: PrimitiveArray<Int64Type> =
            PrimitiveArray::from(vec![Some(1), Some(2), None]);

        // all true
        let mask = BooleanArray::from(vec![true, true, true]);
        let out = filter(&a, &mask)?;
        let b = out
            .as_any()
            .downcast_ref::<PrimitiveArray<Int64Type>>()
            .unwrap();
        assert_eq!(&a, b);

        // all false
        let mask = BooleanArray::from(vec![false, false, false]);
        let out = filter(&a, &mask)?;
        assert_eq!(out.len(), 0);
        assert_eq!(out.data_type(), &DataType::Int64);
        Ok(())
    }
}