blob: 095ef85a884f476470d21282b42e70ae24785a4c [file] [log] [blame]
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use arrow::{
array::{
downcast_primitive, Array, ArrayRef, ArrowPrimitiveType, BinaryArray, BooleanBufferBuilder,
BufferBuilder, GenericByteArray, PrimitiveArray, StringArray,
},
buffer::{MutableBuffer, NullBuffer, OffsetBuffer},
datatypes::{ArrowNativeType, ByteArrayType},
record_batch::{RecordBatch, RecordBatchOptions},
};
use arrow_schema::DataType;
use datafusion::common::Result;
use crate::{downcast_any, prefetch_read_data};
/// Builds a new batch from the rows of `batch` selected by `indices`.
///
/// `indices` is converted into a `PrimitiveArray<T>` and applied to every
/// column. The result keeps the schema of `batch`, and its row count is set
/// explicitly to the number of indices, so a batch without columns still
/// reports the right number of rows.
///
/// # Errors
/// Propagates any error from taking a column or from assembling the batch.
pub fn take_batch<T: ArrowPrimitiveType>(
    batch: RecordBatch,
    indices: impl Into<PrimitiveArray<T>>,
) -> Result<RecordBatch> {
    let indices = indices.into();
    let options = RecordBatchOptions::new().with_row_count(Some(indices.len()));
    let taken_cols = take_cols_internal(batch.columns(), &indices)?;
    let taken = RecordBatch::try_new_with_options(batch.schema(), taken_cols, &options)?;
    Ok(taken)
}
/// Applies `indices` to each array in `cols` and returns the taken arrays in
/// the same order.
///
/// `indices` may be anything convertible into a `PrimitiveArray<T>`; the
/// conversion happens once and the result is shared by all columns.
///
/// # Errors
/// Returns the first error produced while taking a column.
pub fn take_cols<T: ArrowPrimitiveType>(
    cols: &[ArrayRef],
    indices: impl Into<PrimitiveArray<T>>,
) -> Result<Vec<ArrayRef>> {
    let indices: PrimitiveArray<T> = indices.into();
    take_cols_internal(cols, &indices)
}
/// Shared implementation of [`take_batch`] and [`take_cols`]: runs
/// `arrow::compute::take` with default options on every column.
///
/// Stops at, and returns, the first column that fails.
fn take_cols_internal<T: ArrowPrimitiveType>(
    cols: &[ArrayRef],
    indices: &PrimitiveArray<T>,
) -> Result<Vec<ArrayRef>> {
    let mut taken = Vec::with_capacity(cols.len());
    for col in cols {
        taken.push(arrow::compute::take(col, indices, None)?);
    }
    Ok(taken)
}
/// Boxed, sendable closure that maps a slice of `(usize, usize)` pairs to a
/// single output array. For closures built by `create_array_interleaver`, each
/// pair is `(index of the source array, row index inside that array)`.
pub type ArrayInterleaver = Box<dyn Fn(&[(usize, usize)]) -> Result<ArrayRef> + Send>;
/// Boxed, sendable closure that maps a slice of `(usize, usize)` pairs to a
/// record batch. For closures built by `create_batch_interleaver`, each pair is
/// `(index of the source batch, row index inside that batch)`.
pub type BatchInterleaver = Box<dyn Fn(&[(usize, usize)]) -> Result<RecordBatch> + Send>;
#[inline]
pub fn create_batch_interleaver(
batches: &[RecordBatch],
with_prefetching: bool,
) -> Result<BatchInterleaver> {
let batch_schema = batches[0].schema();
let mut col_arrays = vec![vec![]; batches[0].num_columns()];
for batch in batches {
for (col_idx, col) in batch.columns().iter().enumerate() {
col_arrays[col_idx].push(col.clone());
}
}
let col_interleavers = col_arrays
.iter()
.map(|arrays| create_array_interleaver(arrays, with_prefetching))
.collect::<Result<Vec<_>>>()?;
Ok(Box::new(move |indices| {
let cols = col_interleavers
.iter()
.map(|col_interleaver| col_interleaver(indices))
.collect::<Result<Vec<_>>>();
let batch = RecordBatch::try_new_with_options(
batch_schema.clone(),
cols?,
&RecordBatchOptions::new().with_row_count(Some(indices.len())),
)?;
Ok(batch)
}))
}
#[inline]
pub fn create_array_interleaver(
values: &[ArrayRef],
with_prefetching: bool,
) -> Result<ArrayInterleaver> {
struct Interleave<T> {
arrays: Vec<T>,
has_nulls: bool,
}
impl<T: Array> Interleave<T> {
fn new(arrays: Vec<T>) -> Self {
let has_nulls = arrays.iter().any(|v| v.null_count() > 0);
Self { arrays, has_nulls }
}
fn nulls(&self, indices: &[(usize, usize)]) -> Option<NullBuffer> {
let nulls = match self.has_nulls {
true => {
let mut builder = BooleanBufferBuilder::new(indices.len());
for (a, b) in indices {
let v = self.arrays[*a].is_valid(*b);
builder.append(v)
}
Some(NullBuffer::new(builder.finish()))
}
false => None,
};
nulls
}
}
#[inline]
fn interleave_primitive<T: ArrowPrimitiveType, const WITH_PREFETCHING: bool>(
interleaver: &Interleave<PrimitiveArray<T>>,
indices: &[(usize, usize)],
dt: &DataType,
) -> Result<ArrayRef> {
let nulls = interleaver.nulls(indices);
let mut values = Vec::with_capacity(indices.len());
for (i, (a, b)) in indices.iter().enumerate() {
if WITH_PREFETCHING {
const PREFETCH_AHEAD: usize = 4;
if i + PREFETCH_AHEAD < indices.len() {
let (pa, pb) = indices[i + PREFETCH_AHEAD];
prefetch_read_data!({
let array = interleaver.arrays.get_unchecked(pa);
let ptr = array.values().as_ptr().wrapping_add(array.offset() + pb);
ptr
});
}
}
let array = &interleaver.arrays[*a];
if array.is_valid(*b) {
let v = interleaver.arrays[*a].value(*b);
values.push(v)
} else {
values.push(Default::default());
}
}
let array = PrimitiveArray::<T>::new(values.into(), nulls);
Ok(Arc::new(array.with_data_type(dt.clone())))
}
#[inline]
fn interleave_bytes<T: ByteArrayType, const WITH_PREFETCHING: bool>(
interleaver: &Interleave<GenericByteArray<T>>,
indices: &[(usize, usize)],
) -> Result<ArrayRef> {
let nulls = interleaver.nulls(indices);
let mut capacity = 0;
let mut offsets = BufferBuilder::<T::Offset>::new(indices.len() + 1);
offsets.append(T::Offset::from_usize(0).unwrap());
for (i, (a, b)) in indices.iter().enumerate() {
if WITH_PREFETCHING {
const PREFETCH_AHEAD: usize = 4;
if i + PREFETCH_AHEAD < indices.len() {
let (pa, pb) = indices[i + PREFETCH_AHEAD];
prefetch_read_data!({
let array = interleaver.arrays.get_unchecked(pa);
array.value_offsets().get_unchecked(pb)
});
}
}
let array = &interleaver.arrays[*a];
if array.is_valid(*b) {
let o = array.value_offsets();
let element_len = o[*b + 1].as_usize() - o[*b].as_usize();
capacity += element_len;
}
offsets.append(T::Offset::from_usize(capacity).expect("overflow"));
}
let mut values = MutableBuffer::new(capacity);
for (i, (a, b)) in indices.iter().enumerate() {
if WITH_PREFETCHING {
const PREFETCH_AHEAD: usize = 4;
if i + PREFETCH_AHEAD < indices.len() {
let (pa, pb) = indices[i + PREFETCH_AHEAD];
prefetch_read_data!({
let array = interleaver.arrays.get_unchecked(pa);
let start = *array.value_offsets().get_unchecked(pb);
let ptr = array.values().as_ptr().wrapping_add(start.as_usize());
ptr
});
}
}
let array = &interleaver.arrays[*a];
if array.is_valid(*b) {
values.extend_from_slice(interleaver.arrays[*a].value(*b).as_ref());
}
}
// Safety: safe by construction
let array = unsafe {
let offsets = OffsetBuffer::new_unchecked(offsets.finish().into());
GenericByteArray::<T>::new_unchecked(offsets, values.into(), nulls)
};
Ok(Arc::new(array))
}
if !values.is_empty() {
let dt = values[0].data_type();
macro_rules! primitive_helper {
($t:ty, $dt:ident) => {{
let interleaver = Interleave::new(
values
.iter()
.map(|v| downcast_any!(v, PrimitiveArray<$t>).unwrap().clone())
.collect::<Vec<_>>(),
);
let dt = $dt.clone();
return Ok(Box::new(move |indices| {
if with_prefetching {
interleave_primitive::<_, true>(&interleaver, indices, &dt)
} else {
interleave_primitive::<_, false>(&interleaver, indices, &dt)
}
}));
}};
}
downcast_primitive! {
dt => (primitive_helper, dt),
DataType::Utf8 => {
let interleaver = Interleave::new(values
.iter()
.map(|v| downcast_any!(v, StringArray).unwrap().clone())
.collect::<Vec<_>>(),
);
return Ok(Box::new(move |indices| if with_prefetching {
interleave_bytes::<_, true>(&interleaver, indices)
} else {
interleave_bytes::<_, false>(&interleaver, indices)
}));
}
DataType::Binary => {
let interleaver = Interleave::new(values
.iter()
.map(|v| downcast_any!(v, BinaryArray).unwrap().clone())
.collect::<Vec<_>>(),
);
return Ok(Box::new(move |indices| if with_prefetching {
interleave_bytes::<_, true>(&interleaver, indices)
} else {
interleave_bytes::<_, false>(&interleaver, indices)
}));
}
_ => {},
}
}
// fallback to arrow's implementation
let values_cloned = values.to_vec();
Ok(Box::new(move |indices| {
let value_refs = values_cloned.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
Ok(arrow::compute::interleave(&value_refs, indices)?)
}))
}