blob: a811493839387a92a6b961f7376509ce489b120f [file] [log] [blame]
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
fmt::{Debug, Formatter},
sync::Arc,
};
use arrow::{
array::{Array, ArrayRef, BinaryArray, BinaryBuilder},
datatypes::{DataType, Field, Fields, Schema, SchemaRef},
record_batch::{RecordBatch, RecordBatchOptions},
row::{RowConverter, Rows, SortField},
};
use blaze_jni_bridge::{
conf,
conf::{DoubleConf, IntConf},
};
use datafusion::{
common::{cast::as_binary_array, Result},
physical_expr::PhysicalExprRef,
};
use datafusion_ext_commons::df_execution_err;
use once_cell::sync::OnceCell;
use parking_lot::Mutex;
use crate::{
agg::{
acc::{
create_acc_from_initial_value, create_dyn_loaders_from_initial_value,
create_dyn_savers_from_initial_value, AccumInitialValue, AccumStateRow, LoadFn,
OwnedAccumStateRow, RefAccumStateRow, SaveFn,
},
Agg, AggExecMode, AggExpr, AggMode, GroupingExpr, AGG_BUF_COLUMN_NAME,
},
common::cached_exprs_evaluator::CachedExprsEvaluator,
};
/// Shared, immutable context for executing one aggregation operator.
///
/// Built once by [`AggContext::try_new`] and then consulted by the
/// partial-update / partial-merge / final-merge code paths below.
pub struct AggContext {
    // Execution mode of this aggregation (see `AggExecMode`).
    pub exec_mode: AggExecMode,
    // True when at least one agg expr runs in `AggMode::Partial`.
    pub need_partial_update: bool,
    // True when at least one agg expr runs in a non-`Partial` mode.
    pub need_partial_merge: bool,
    // True when at least one agg expr runs in `AggMode::Final`
    // (try_new asserts final mode never mixes with other modes).
    pub need_final_merge: bool,
    // Partial-mode aggs, paired with their index into `aggs`.
    pub need_partial_update_aggs: Vec<(usize, Arc<dyn Agg>)>,
    // Non-partial-mode aggs, paired with their index into `aggs`.
    pub need_partial_merge_aggs: Vec<(usize, Arc<dyn Agg>)>,
    pub input_schema: SchemaRef,
    // Schema of the grouping key columns, derived from `groupings`.
    pub grouping_schema: SchemaRef,
    // Agg output columns: one field per agg in final mode, otherwise a
    // single binary column named `AGG_BUF_COLUMN_NAME`.
    pub agg_schema: SchemaRef,
    // Output layout: grouping columns followed by agg columns.
    pub output_schema: SchemaRef,
    // Converts grouping columns to/from arrow row format; shared behind a
    // mutex because RowConverter requires `&mut self`.
    pub grouping_row_converter: Arc<Mutex<RowConverter>>,
    pub groupings: Vec<GroupingExpr>,
    pub aggs: Vec<AggExpr>,
    // Template accumulator row covering ALL aggs; cloned per group.
    pub initial_acc: OwnedAccumStateRow,
    // Template accumulator row covering only the partial-merge aggs; the
    // merge input may have fewer fields than the processing acc (see the
    // distinct-aggregation note in `try_new`).
    pub initial_input_acc: OwnedAccumStateRow,
    pub initial_input_buffer_offset: usize,
    // Whether partial-aggregation skipping may be applied; when false the
    // two thresholds below are left at their defaults.
    pub supports_partial_skipping: bool,
    pub partial_skipping_ratio: f64,
    pub partial_skipping_min_rows: usize,
    // Evaluates the flattened child exprs of all partial-mode aggs in one pass.
    pub agg_expr_evaluator: CachedExprsEvaluator,
    // Dynamic-value loaders/savers for (de)serializing accumulator state rows.
    pub acc_dyn_loaders: Vec<LoadFn>,
    pub acc_dyn_savers: Vec<SaveFn>,
}
impl Debug for AggContext {
    /// Renders a compact `[groupings=..., aggs=...]` summary.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let groupings = &self.groupings;
        let aggs = &self.aggs;
        write!(f, "[groupings={groupings:?}, aggs={aggs:?}]")
    }
}
impl AggContext {
    /// Builds the aggregation context.
    ///
    /// Derives grouping/agg/output schemas from `input_schema`, builds the
    /// row converter for grouping keys, allocates the template accumulator
    /// rows plus their dynamic load/save functions, wires each agg's slice of
    /// accumulator-state addresses into its `Agg` implementation, and (when
    /// partial skipping is supported) reads the skipping thresholds from conf.
    ///
    /// # Errors
    /// Propagates any error from expression type resolution, row-converter or
    /// accumulator construction, evaluator creation, or configuration lookup.
    pub fn try_new(
        exec_mode: AggExecMode,
        input_schema: SchemaRef,
        groupings: Vec<GroupingExpr>,
        mut aggs: Vec<AggExpr>,
        initial_input_buffer_offset: usize,
        supports_partial_skipping: bool,
    ) -> Result<Self> {
        // schema of the grouping key columns, resolved against the input schema
        let grouping_schema = Arc::new(Schema::new(
            groupings
                .iter()
                .map(|grouping: &GroupingExpr| {
                    Ok(Field::new(
                        grouping.field_name.as_str(),
                        grouping.expr.data_type(&input_schema)?,
                        grouping.expr.nullable(&input_schema)?,
                    ))
                })
                .collect::<Result<Fields>>()?,
        ));
        // converter that turns grouping columns into comparable row bytes
        let grouping_row_converter = Arc::new(Mutex::new(RowConverter::new(
            grouping_schema
                .fields()
                .iter()
                .map(|field| SortField::new(field.data_type().clone()))
                .collect(),
        )?));

        // final aggregates may not exist along with partial/partial-merge
        let need_partial_update = aggs.iter().any(|agg| agg.mode == AggMode::Partial);
        let need_partial_merge = aggs.iter().any(|agg| agg.mode != AggMode::Partial);
        let need_final_merge = aggs.iter().any(|agg| agg.mode == AggMode::Final);
        assert!(!(need_final_merge && aggs.iter().any(|agg| agg.mode != AggMode::Final)));

        // partition the aggs by mode, remembering each one's index into `aggs`
        let need_partial_update_aggs: Vec<(usize, Arc<dyn Agg>)> = aggs
            .iter()
            .enumerate()
            .filter(|(_idx, agg)| agg.mode.is_partial())
            .map(|(idx, agg)| (idx, agg.agg.clone()))
            .collect();
        let need_partial_merge_aggs: Vec<(usize, Arc<dyn Agg>)> = aggs
            .iter()
            .enumerate()
            .filter(|(_idx, agg)| !agg.mode.is_partial())
            .map(|(idx, agg)| (idx, agg.agg.clone()))
            .collect();

        // agg output columns: one field per agg when finally merging,
        // otherwise a single non-null binary column holding the serialized
        // accumulator state
        let mut agg_fields = vec![];
        if need_final_merge {
            for agg in &aggs {
                agg_fields.push(Field::new(
                    &agg.field_name,
                    agg.agg.data_type().clone(),
                    agg.agg.nullable(),
                ));
            }
        } else {
            agg_fields.push(Field::new(AGG_BUF_COLUMN_NAME, DataType::Binary, false));
        }
        let agg_schema = Arc::new(Schema::new(agg_fields));
        // output layout: grouping columns followed by agg columns
        let output_schema = Arc::new(Schema::new(
            [
                grouping_schema.fields().to_vec(),
                agg_schema.fields().to_vec(),
            ]
            .concat(),
        ));

        // template accumulator row covering ALL aggs, plus the per-value
        // state addresses and the dynamic (de)serialization functions
        let initial_accums: Box<[AccumInitialValue]> = aggs
            .iter()
            .flat_map(|agg: &AggExpr| agg.agg.accums_initial())
            .cloned()
            .collect();
        let (initial_acc, accum_state_val_addrs) = create_acc_from_initial_value(&initial_accums)?;
        let acc_dyn_loaders = create_dyn_loaders_from_initial_value(&initial_accums)?;
        let acc_dyn_savers = create_dyn_savers_from_initial_value(&initial_accums)?;

        // in distinct aggregations, partial and partial-merge may happen at the same
        // time, i.e:
        //
        // Agg [groupings=[], aggs=[
        //     AggExpr { field_name: "#747", mode: PartialMerge, agg: Count(...) },
        //     AggExpr { field_name: "#748", mode: Partial, agg: Count(Column { name:
        //     "#640", index: 0 }) } ]]
        //   Agg [groupings=[GroupingExpr { field_name: "#640", ...], aggs=[
        //     AggExpr { field_name: "#747", mode: PartialMerge, agg: Count(...) }
        //   ]]
        //
        // in this situation, the processing acc has more fields than input. so we
        // need to maintain a standalone acc for the input.
        // the addrs is not used because the extra fields are always in the last. the
        // processing addrs can be reused.
        let initial_input_accums: Box<[AccumInitialValue]> = need_partial_merge_aggs
            .iter()
            .flat_map(|(_, agg)| agg.accums_initial())
            .cloned()
            .collect();
        let (initial_input_acc, _input_accum_state_val_addrs) =
            create_acc_from_initial_value(&initial_input_accums)?;

        // hand each agg its contiguous slice of accumulator-state addresses
        let mut offset = 0;
        for agg in &mut aggs {
            let len = agg.agg.accums_initial().len();
            unsafe {
                // safety: accum_state_val_addrs is guaranteed not to be used at this time
                Arc::get_mut_unchecked(&mut agg.agg)
                    .set_accum_state_val_addrs(&accum_state_val_addrs[offset..][..len]);
            }
            offset += len;
        }

        // flatten the child exprs of all partial-mode aggs so they can be
        // evaluated in a single pass by the cached evaluator; the per-agg
        // split is undone later in `create_input_arrays`
        let agg_exprs_flatten: Vec<PhysicalExprRef> = aggs
            .iter()
            .filter(|agg| agg.mode.is_partial())
            .flat_map(|agg| agg.agg.exprs())
            .collect();
        let agg_expr_evaluator_output_schema = Arc::new(Schema::new(
            agg_exprs_flatten
                .iter()
                .map(|e| {
                    // field names are irrelevant here — columns are consumed by position
                    Ok(Field::new(
                        "",
                        e.data_type(&input_schema)?,
                        e.nullable(&input_schema)?,
                    ))
                })
                .collect::<Result<Fields>>()?,
        ));
        let agg_expr_evaluator = CachedExprsEvaluator::try_new(
            vec![],
            agg_exprs_flatten,
            agg_expr_evaluator_output_schema,
        )?;

        // skipping thresholds are only read from conf when the feature is on;
        // otherwise both stay at their (zero) defaults
        let (partial_skipping_ratio, partial_skipping_min_rows) = if supports_partial_skipping {
            (
                conf::PARTIAL_AGG_SKIPPING_RATIO.value()?,
                conf::PARTIAL_AGG_SKIPPING_MIN_ROWS.value()? as usize,
            )
        } else {
            Default::default()
        };

        Ok(Self {
            exec_mode,
            need_partial_update,
            need_partial_merge,
            need_final_merge,
            need_partial_update_aggs,
            need_partial_merge_aggs,
            input_schema,
            output_schema,
            grouping_schema,
            grouping_row_converter,
            agg_schema,
            groupings,
            aggs,
            initial_acc,
            initial_input_acc,
            acc_dyn_loaders,
            acc_dyn_savers,
            agg_expr_evaluator,
            initial_input_buffer_offset,
            supports_partial_skipping,
            partial_skipping_ratio,
            partial_skipping_min_rows,
        })
    }

    /// Evaluates the grouping expressions over `input_batch` and converts the
    /// resulting columns into row-format grouping keys.
    pub fn create_grouping_rows(&self, input_batch: &RecordBatch) -> Result<Rows> {
        let grouping_arrays: Vec<ArrayRef> = self
            .groupings
            .iter()
            .map(|grouping| grouping.expr.evaluate(&input_batch))
            .map(|r| r.and_then(|columnar| columnar.into_array(input_batch.num_rows())))
            .collect::<Result<_>>()
            .map_err(|err| err.context("agg: evaluating grouping arrays error"))?;
        Ok(self
            .grouping_row_converter
            .lock()
            .convert_columns(&grouping_arrays)?)
    }

    /// Evaluates the partial-mode agg argument expressions over `input_batch`.
    ///
    /// Returns one `Vec<ArrayRef>` per agg, index-aligned with `self.aggs`
    /// (non-partial aggs get an empty vec). Returns an empty outer vec when
    /// no partial update is needed at all.
    pub fn create_input_arrays(&self, input_batch: &RecordBatch) -> Result<Vec<Vec<ArrayRef>>> {
        if !self.need_partial_update {
            return Ok(vec![]);
        }
        // evaluate all flattened exprs in one pass ...
        let agg_exprs_batch = self.agg_expr_evaluator.filter_project(input_batch)?;

        // ... then slice the result columns back out per agg, in the same
        // order they were flattened in `try_new`
        let mut input_arrays = Vec::with_capacity(self.aggs.len());
        let mut offset = 0;
        for agg in &self.aggs {
            if agg.mode.is_partial() {
                let num_agg_exprs = agg.agg.exprs().len();
                let prepared = agg
                    .agg
                    .prepare_partial_args(&agg_exprs_batch.columns()[offset..][..num_agg_exprs])?;
                input_arrays.push(prepared);
                offset += num_agg_exprs;
            } else {
                input_arrays.push(vec![]);
            }
        }
        Ok(input_arrays)
    }

    /// Returns the serialized-accumulator column of `input_batch` (its last
    /// column) when partial merging is needed, otherwise a shared empty array.
    pub fn get_input_acc_array<'a>(&self, input_batch: &'a RecordBatch) -> Result<&'a BinaryArray> {
        if self.need_partial_merge {
            as_binary_array(input_batch.columns().last().unwrap())
        } else {
            // lazily-initialized, process-wide empty placeholder so callers
            // always get a `&BinaryArray` regardless of mode
            static EMPTY_BINARY_ARRAY: OnceCell<BinaryArray> = OnceCell::new();
            Ok(EMPTY_BINARY_ARRAY.get_or_init(|| BinaryArray::from_iter_values([[]; 0])))
        }
    }

    /// Produces the agg output columns for a set of (grouping-key, acc) records.
    ///
    /// In final-merge mode each agg emits its merged value column; otherwise
    /// all accumulator rows are serialized into a single binary column.
    pub fn build_agg_columns(
        &self,
        mut records: Vec<(impl AsRef<[u8]>, RefAccumStateRow)>,
    ) -> Result<Vec<ArrayRef>> {
        let mut agg_columns = vec![];
        if self.need_final_merge {
            // output final merged value
            let mut accs = records.into_iter().map(|(_, acc)| acc).collect::<Vec<_>>();
            for agg in self.aggs.iter() {
                let values = agg.agg.final_batch_merge(&mut accs)?;
                agg_columns.push(values);
            }
        } else {
            // output acc as a binary column
            let mut binary_array = BinaryBuilder::with_capacity(records.len(), 0);
            for (_, acc) in &mut records {
                let acc_bytes = acc.save_to_bytes(&self.acc_dyn_savers)?;
                binary_array.append_value(acc_bytes);
            }
            agg_columns.push(Arc::new(binary_array.finish()));
        }
        Ok(agg_columns)
    }

    /// Converts (grouping-key, acc) records into an output `RecordBatch`:
    /// grouping columns are reconstructed from the row-format keys, agg
    /// columns come from [`Self::build_agg_columns`].
    pub fn convert_records_to_batch(
        &self,
        records: Vec<(impl AsRef<[u8]>, RefAccumStateRow)>,
    ) -> Result<RecordBatch> {
        let row_count = records.len();
        let grouping_row_converter = self.grouping_row_converter.lock();
        let grouping_row_parser = grouping_row_converter.parser();
        let grouping_columns = grouping_row_converter.convert_rows(
            records
                .iter()
                .map(|(key, _)| grouping_row_parser.parse(key.as_ref())),
        )?;
        let agg_columns = self.build_agg_columns(records)?;
        // explicit row count keeps the batch valid even with zero columns
        // (e.g. grouping-less aggregation before final merge)
        Ok(RecordBatch::try_new_with_options(
            self.output_schema.clone(),
            [grouping_columns, agg_columns].concat(),
            &RecordBatchOptions::new().with_row_count(Some(row_count)),
        )?)
    }

    /// Applies one input row (at `row_idx` of `input_arrays`) to `acc` for
    /// every partial-mode agg. No-op unless partial updating is needed.
    pub fn partial_update_input(
        &self,
        acc: &mut RefAccumStateRow,
        input_arrays: &[Vec<ArrayRef>],
        row_idx: usize,
    ) -> Result<()> {
        if self.need_partial_update {
            for (idx, agg) in &self.need_partial_update_aggs {
                agg.partial_update(acc, &input_arrays[*idx], row_idx)?;
            }
        }
        Ok(())
    }

    /// Batch variant of [`Self::partial_update_input`]: row i of the input
    /// updates `accs[i]`. No-op unless partial updating is needed.
    pub fn partial_batch_update_input(
        &self,
        accs: &mut [RefAccumStateRow],
        input_arrays: &[Vec<ArrayRef>],
    ) -> Result<()> {
        if self.need_partial_update {
            for (idx, agg) in &self.need_partial_update_aggs {
                agg.partial_batch_update(accs, &input_arrays[*idx])?;
            }
        }
        Ok(())
    }

    /// Folds ALL input rows into the single accumulator `acc` (used when
    /// every row belongs to one group). No-op unless partial updating is
    /// needed.
    pub fn partial_update_input_all(
        &self,
        acc: &mut RefAccumStateRow,
        input_arrays: &[Vec<ArrayRef>],
    ) -> Result<()> {
        if self.need_partial_update {
            for (idx, agg) in &self.need_partial_update_aggs {
                agg.partial_update_all(acc, &input_arrays[*idx])?;
            }
        }
        Ok(())
    }

    /// Deserializes the accumulator state at `row_idx` of `acc_array` into a
    /// fresh input acc and merges it into `acc` for every partial-merge agg.
    /// No-op unless partial merging is needed.
    pub fn partial_merge_input(
        &self,
        acc: &mut RefAccumStateRow,
        acc_array: &BinaryArray,
        row_idx: usize,
    ) -> Result<()> {
        if self.need_partial_merge {
            let mut input_acc = self.initial_input_acc.clone();
            input_acc.load_from_bytes(acc_array.value(row_idx), &self.acc_dyn_loaders)?;
            for (_, agg) in &self.need_partial_merge_aggs {
                // account for dynamic memory before the state is consumed by the merge
                agg.increase_acc_mem_used(&mut input_acc.as_mut());
                agg.partial_merge(acc, &mut input_acc.as_mut())?;
            }
        }
        Ok(())
    }

    /// Batch variant of [`Self::partial_merge_input`]: entry i of `acc_array`
    /// is deserialized and merged into `accs[i]`. No-op unless partial
    /// merging is needed.
    pub fn partial_batch_merge_input(
        &self,
        accs: &mut [RefAccumStateRow],
        acc_array: &BinaryArray,
    ) -> Result<()> {
        if self.need_partial_merge {
            // deserialize every incoming accumulator state up front
            let mut input_accs = acc_array
                .iter()
                .map(|value| {
                    let mut input_acc = self.initial_input_acc.clone();
                    input_acc.load_from_bytes(value.unwrap(), &self.acc_dyn_loaders)?;
                    Ok(input_acc)
                })
                .collect::<Result<Vec<_>>>()?;
            let mut input_ref_accs = input_accs
                .iter_mut()
                .map(|acc| acc.as_mut())
                .collect::<Vec<_>>();
            for (_, agg) in &self.need_partial_merge_aggs {
                for input_acc in &mut input_ref_accs {
                    agg.increase_acc_mem_used(input_acc);
                }
                agg.partial_batch_merge(accs, &mut input_ref_accs)?;
            }
        }
        Ok(())
    }

    /// Merges EVERY entry of `acc_array` into the single accumulator `acc`
    /// (used when all rows belong to one group), reusing one input acc across
    /// rows. No-op unless partial merging is needed.
    pub fn partial_merge_input_all(
        &self,
        acc: &mut RefAccumStateRow,
        acc_array: &BinaryArray,
    ) -> Result<()> {
        if self.need_partial_merge {
            let mut input_acc = self.initial_input_acc.clone();
            for row_idx in 0..acc_array.len() {
                input_acc.load_from_bytes(acc_array.value(row_idx), &self.acc_dyn_loaders)?;
                for (_, agg) in &self.need_partial_merge_aggs {
                    agg.increase_acc_mem_used(&mut input_acc.as_mut());
                    agg.partial_merge(acc, &mut input_acc.as_mut())?;
                }
            }
        }
        Ok(())
    }

    /// Merges `merging_acc` into `acc` across ALL aggs (used when two
    /// in-memory accumulators for the same group key are combined),
    /// wrapping any failure in an execution error with context.
    pub fn partial_merge(
        &self,
        acc: &mut RefAccumStateRow,
        merging_acc: &mut RefAccumStateRow,
    ) -> Result<()> {
        for agg in &self.aggs {
            agg.agg
                .partial_merge(acc, merging_acc)
                .or_else(|err| df_execution_err!("agg: executing partial_merge() error: {err}"))?;
        }
        Ok(())
    }

    /// Total dynamic memory currently reported as used by all aggs.
    pub fn acc_dyn_mem_used(&self) -> usize {
        self.aggs
            .iter()
            .map(|agg| agg.agg.mem_used())
            .sum::<usize>()
    }
}