Speed up bound checking in `take` (#281)
* WIP improve take performance
* WIP
* Bound checking speed
* Simplify
* fmt
* Improve formatting
diff --git a/arrow/benches/take_kernels.rs b/arrow/benches/take_kernels.rs
index 2853eb5..b1d03d7 100644
--- a/arrow/benches/take_kernels.rs
+++ b/arrow/benches/take_kernels.rs
@@ -23,7 +23,7 @@
extern crate arrow;
-use arrow::compute::take;
+use arrow::compute::{take, TakeOptions};
use arrow::datatypes::*;
use arrow::util::test_util::seedable_rng;
use arrow::{array::*, util::bench_util::*};
@@ -46,6 +46,12 @@
criterion::black_box(take(values, &indices, None).unwrap());
}
+fn bench_take_bounds_check(values: &dyn Array, indices: &UInt32Array) {
+ criterion::black_box(
+ take(values, &indices, Some(TakeOptions { check_bounds: true })).unwrap(),
+ );
+}
+
fn add_benchmark(c: &mut Criterion) {
let values = create_primitive_array::<Int32Type>(512, 0.0);
let indices = create_random_index(512, 0.0);
@@ -56,6 +62,17 @@
b.iter(|| bench_take(&values, &indices))
});
+ let values = create_primitive_array::<Int32Type>(512, 0.0);
+ let indices = create_random_index(512, 0.0);
+ c.bench_function("take check bounds i32 512", |b| {
+ b.iter(|| bench_take_bounds_check(&values, &indices))
+ });
+ let values = create_primitive_array::<Int32Type>(1024, 0.0);
+ let indices = create_random_index(1024, 0.0);
+ c.bench_function("take check bounds i32 1024", |b| {
+ b.iter(|| bench_take_bounds_check(&values, &indices))
+ });
+
let indices = create_random_index(512, 0.5);
c.bench_function("take i32 nulls 512", |b| {
b.iter(|| bench_take(&values, &indices))
diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs
index 0217573..d325ce4 100644
--- a/arrow/src/compute/kernels/take.rs
+++ b/arrow/src/compute/kernels/take.rs
@@ -100,17 +100,30 @@
let options = options.unwrap_or_default();
if options.check_bounds {
let len = values.len();
- for i in 0..indices.len() {
- if indices.is_valid(i) {
- let ix = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| {
+ if indices.null_count() > 0 {
+ indices.iter().flatten().try_for_each(|index| {
+ let ix = ToPrimitive::to_usize(&index).ok_or_else(|| {
ArrowError::ComputeError("Cast to usize failed".to_string())
})?;
if ix >= len {
return Err(ArrowError::ComputeError(
- format!("Array index out of bounds, cannot get item at index {} from {} entries", ix, len))
- );
+ format!("Array index out of bounds, cannot get item at index {} from {} entries", ix, len))
+ );
}
- }
+ Ok(())
+ })?;
+ } else {
+ indices.values().iter().try_for_each(|index| {
+ let ix = ToPrimitive::to_usize(index).ok_or_else(|| {
+ ArrowError::ComputeError("Cast to usize failed".to_string())
+ })?;
+ if ix >= len {
+ return Err(ArrowError::ComputeError(
+ format!("Array index out of bounds, cannot get item at index {} from {} entries", ix, len))
+ );
+ }
+ Ok(())
+ })?
}
}
match values.data_type() {