blob: a5e789cd545ddac055c598149a394c1d53e2218d [file] [log] [blame]
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use arrow::{
array::{ArrayRef, PrimitiveArray, UInt32Array},
datatypes::SchemaRef,
error::Result as ArrowResult,
record_batch::{RecordBatch, RecordBatchOptions},
};
use datafusion::common::Result;
pub fn take_batch<T: num::PrimInt>(
batch: RecordBatch,
indices: impl IntoIterator<Item = T>,
) -> Result<RecordBatch> {
let indices: UInt32Array =
PrimitiveArray::from_iter(indices.into_iter().map(|idx| idx.to_u32().unwrap()));
take_batch_internal(batch, indices)
}
pub fn take_batch_opt<T: num::PrimInt>(
batch: RecordBatch,
indices: impl IntoIterator<Item = Option<T>>,
) -> Result<RecordBatch> {
let indices: UInt32Array = PrimitiveArray::from_iter(
indices
.into_iter()
.map(|opt| opt.map(|idx| idx.to_u32().unwrap())),
);
take_batch_internal(batch, indices)
}
pub fn take_cols<T: num::PrimInt>(
cols: &[ArrayRef],
indices: impl IntoIterator<Item = T>,
) -> Result<Vec<ArrayRef>> {
let indices: UInt32Array =
PrimitiveArray::from_iter(indices.into_iter().map(|idx| idx.to_u32().unwrap()));
take_cols_internal(cols, &indices)
}
pub fn take_cols_opt<T: num::PrimInt>(
cols: &[ArrayRef],
indices: impl IntoIterator<Item = Option<T>>,
) -> Result<Vec<ArrayRef>> {
let indices: UInt32Array = PrimitiveArray::from_iter(
indices
.into_iter()
.map(|opt| opt.map(|idx| idx.to_u32().unwrap())),
);
take_cols_internal(cols, &indices)
}
/// Builds a new batch from the rows of `batch` addressed by `indices`.
///
/// The result reuses the schema of `batch`, and its row count is set explicitly to
/// the number of indices.
fn take_batch_internal(batch: RecordBatch, indices: UInt32Array) -> Result<RecordBatch> {
    let options = RecordBatchOptions::new().with_row_count(Some(indices.len()));
    let taken_cols = take_cols_internal(batch.columns(), &indices)?;
    // The index array is no longer needed once the columns have been taken.
    drop(indices);
    Ok(RecordBatch::try_new_with_options(
        batch.schema(),
        taken_cols,
        &options,
    )?)
}
/// Runs `arrow::compute::take` with `indices` over each array in `cols`.
///
/// Returns the taken arrays in input order, or the first error encountered.
fn take_cols_internal(cols: &[ArrayRef], indices: &UInt32Array) -> Result<Vec<ArrayRef>> {
    let mut taken = Vec::with_capacity(cols.len());
    for col in cols {
        // `?` converts the Arrow error into the DataFusion error type and stops early.
        taken.push(arrow::compute::take(col.as_ref(), indices, None)?);
    }
    Ok(taken)
}
/// Builds one batch by picking rows from several batches.
///
/// Each `(batch_idx, row_idx)` pair in `indices` is passed to
/// `arrow::compute::interleave`, column by column. The output uses `schema` and has
/// `indices.len()` rows.
///
/// # Errors
///
/// Forwards any error from `arrow::compute::interleave` or from constructing the
/// resulting batch.
pub fn interleave_batches(
    schema: SchemaRef,
    batches: &[RecordBatch],
    indices: &[(usize, usize)],
) -> Result<RecordBatch> {
    // Regroup the columns: one list per schema field, holding that column's array
    // from every batch in batch order.
    let mut arrays_by_col: Vec<Vec<ArrayRef>> =
        std::iter::repeat_with(|| Vec::with_capacity(batches.len()))
            .take(schema.fields().len())
            .collect();
    for batch in batches {
        for (col_idx, column) in batch.columns().iter().enumerate() {
            arrays_by_col[col_idx].push(column.clone());
        }
    }

    // Interleave every column independently, stopping at the first failure.
    let mut interleaved = Vec::with_capacity(arrays_by_col.len());
    for arrays in &arrays_by_col {
        let sources = arrays
            .iter()
            .map(|array| array.as_ref())
            .collect::<Vec<_>>();
        interleaved.push(arrow::compute::interleave(&sources, indices)?);
    }

    let options = RecordBatchOptions::new().with_row_count(Some(indices.len()));
    Ok(RecordBatch::try_new_with_options(
        schema,
        interleaved,
        &options,
    )?)
}