blob: 32570fb22abe1b4ac6abdd471d530de53a5b5d5d [file] [log] [blame]
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
any::Any,
fmt::{Debug, Formatter},
sync::{atomic::AtomicUsize, Arc},
};
use arrow::{array::*, datatypes::*};
use datafusion::{
common::{
cast::{as_decimal128_array, as_int64_array},
Result, ScalarValue,
},
physical_expr::PhysicalExpr,
};
use datafusion_ext_commons::df_unimplemented_err;
use crate::agg::{
acc::{AccumInitialValue, AccumStateValAddr, RefAccumStateRow},
count::AggCount,
sum::AggSum,
Agg, WithAggBufAddrs, WithMemTracking,
};
/// Average aggregate, implemented as a sum aggregate and a count aggregate
/// whose results are divided when the aggregation is finalized.
pub struct AggAvg {
    /// Input expression being averaged.
    child: Arc<dyn PhysicalExpr>,

    /// Data type reported by this aggregate; inputs are cast to it before
    /// being accumulated.
    data_type: DataType,

    /// Inner aggregate accumulating the sum of the input values.
    agg_sum: AggSum,

    /// Inner aggregate accumulating the number of input values.
    agg_count: AggCount,

    /// Initial accumulator values: those of the sum followed by those of the
    /// count.
    accums_initial: Vec<AccumInitialValue>,

    /// Scalar finalizer combining a `(sum, count)` pair into the average.
    final_merger: fn(ScalarValue, i64) -> ScalarValue,

    /// Memory usage counter exposed through `WithMemTracking`.
    mem_used_tracker: AtomicUsize,
}
impl WithAggBufAddrs for AggAvg {
    /// Hands the accumulator state addresses to both inner aggregates.
    ///
    /// The sum aggregate receives the whole slice; the count aggregate
    /// receives the slice without its first entry.
    /// NOTE(review): this presumably relies on the sum using exactly one
    /// accumulator slot — confirm against `AggSum`.
    ///
    /// # Panics
    ///
    /// Panics if `accum_state_val_addrs` is empty, because the sub-slice
    /// starting at index 1 cannot be taken.
    fn set_accum_state_val_addrs(&mut self, accum_state_val_addrs: &[AccumStateValAddr]) {
        self.agg_sum
            .set_accum_state_val_addrs(accum_state_val_addrs);

        // Everything after the first address belongs to the count.
        let count_addrs = &accum_state_val_addrs[1..];
        self.agg_count.set_accum_state_val_addrs(count_addrs);
    }
}
impl WithMemTracking for AggAvg {
    /// Returns the counter this aggregate uses for memory tracking.
    fn mem_used_tracker(&self) -> &AtomicUsize {
        let tracker = &self.mem_used_tracker;
        tracker
    }
}
impl AggAvg {
    /// Creates an average aggregate over `child` that reports `data_type`.
    ///
    /// Internally a sum aggregate (typed as `data_type`) and a count
    /// aggregate (typed as `Int64`) are built over the same child expression.
    ///
    /// # Errors
    ///
    /// Returns an error if either inner aggregate cannot be constructed, or
    /// if `data_type` has no final merger (see `get_final_merger`).
    pub fn try_new(child: Arc<dyn PhysicalExpr>, data_type: DataType) -> Result<Self> {
        let sum_agg = AggSum::try_new(Arc::clone(&child), data_type.clone())?;
        let count_agg = AggCount::try_new(Arc::clone(&child), DataType::Int64)?;

        // Accumulator layout: the sum's initial values first, then the count's.
        let mut initial_values = sum_agg.accums_initial().to_vec();
        initial_values.extend_from_slice(count_agg.accums_initial());

        let merger = get_final_merger(&data_type)?;

        Ok(Self {
            child,
            data_type,
            agg_sum: sum_agg,
            agg_count: count_agg,
            accums_initial: initial_values,
            final_merger: merger,
            mem_used_tracker: AtomicUsize::new(0),
        })
    }
}
impl Debug for AggAvg {
    /// Formats the aggregate as `Avg(<child expression>)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("Avg({:?})", self.child))
    }
}
impl Agg for AggAvg {
fn as_any(&self) -> &dyn Any {
self
}
fn exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.child.clone()]
}
fn with_new_exprs(&self, exprs: Vec<Arc<dyn PhysicalExpr>>) -> Result<Arc<dyn Agg>> {
Ok(Arc::new(Self::try_new(
exprs[0].clone(),
self.data_type.clone(),
)?))
}
fn data_type(&self) -> &DataType {
&self.data_type
}
fn nullable(&self) -> bool {
true
}
fn accums_initial(&self) -> &[AccumInitialValue] {
&self.accums_initial
}
fn prepare_partial_args(&self, partial_inputs: &[ArrayRef]) -> Result<Vec<ArrayRef>> {
// cast arg1 to target data type
Ok(vec![datafusion_ext_commons::cast::cast(
&partial_inputs[0],
&self.data_type,
)?])
}
fn increase_acc_mem_used(&self, _acc: &mut RefAccumStateRow) {
// do nothing
}
fn partial_update(
&self,
acc: &mut RefAccumStateRow,
values: &[ArrayRef],
row_idx: usize,
) -> Result<()> {
self.agg_sum.partial_update(acc, values, row_idx)?;
self.agg_count.partial_update(acc, values, row_idx)?;
Ok(())
}
fn partial_batch_update(
&self,
accs: &mut [RefAccumStateRow],
values: &[ArrayRef],
) -> Result<()> {
self.agg_sum.partial_batch_update(accs, values)?;
self.agg_count.partial_batch_update(accs, values)?;
Ok(())
}
fn partial_update_all(&self, acc: &mut RefAccumStateRow, values: &[ArrayRef]) -> Result<()> {
self.agg_sum.partial_update_all(acc, values)?;
self.agg_count.partial_update_all(acc, values)?;
Ok(())
}
fn partial_merge(
&self,
acc1: &mut RefAccumStateRow,
acc2: &mut RefAccumStateRow,
) -> Result<()> {
self.agg_sum.partial_merge(acc1, acc2)?;
self.agg_count.partial_merge(acc1, acc2)?;
Ok(())
}
fn partial_batch_merge(
&self,
accs: &mut [RefAccumStateRow],
merging_accs: &mut [RefAccumStateRow],
) -> Result<()> {
self.agg_sum.partial_batch_merge(accs, merging_accs)?;
self.agg_count.partial_batch_merge(accs, merging_accs)?;
Ok(())
}
fn final_merge(&self, acc: &mut RefAccumStateRow) -> Result<ScalarValue> {
let sum = self.agg_sum.final_merge(acc)?;
let count = match self.agg_count.final_merge(acc)? {
ScalarValue::Int64(Some(count)) => count,
_ => unreachable!(),
};
let final_merger = self.final_merger;
Ok(final_merger(sum, count))
}
fn final_batch_merge(&self, accs: &mut [RefAccumStateRow]) -> Result<ArrayRef> {
let sums = self.agg_sum.final_batch_merge(accs)?;
let counts = self.agg_count.final_batch_merge(accs)?;
let counts_zero_free: Int64Array = as_int64_array(&counts)?.unary_opt(|count| {
let not_zero = !count.is_zero();
not_zero.then_some(count)
});
if let &DataType::Decimal128(prec, scale) = self.data_type() {
let sums = as_decimal128_array(&sums)?;
let counts = counts_zero_free;
let avgs =
arrow::compute::binary::<_, _, _, Decimal128Type>(&sums, &counts, |sum, count| {
sum.checked_div_euclid(count as i128).unwrap_or_default()
})?;
Ok(Arc::new(avgs.with_precision_and_scale(prec, scale)?))
} else {
let counts = counts_zero_free;
Ok(arrow::compute::kernels::numeric::div(
&arrow::compute::cast(&sums, &DataType::Float64)?,
&arrow::compute::cast(&counts, &DataType::Float64)?,
)?)
}
}
}
fn get_final_merger(dt: &DataType) -> Result<fn(ScalarValue, i64) -> ScalarValue> {
macro_rules! get_fn {
($ty:ident,f64) => {{
Ok(|sum: ScalarValue, count: i64| {
let avg = match sum {
ScalarValue::$ty(sum, ..) => ScalarValue::Float64(if !count.is_zero() {
sum.map(|sum| sum as f64 / (count as f64))
} else {
None
}),
_ => unreachable!(),
};
avg
})
}};
(Decimal128) => {{
Ok(|sum: ScalarValue, count: i64| {
let avg = match sum {
ScalarValue::Decimal128(sum, prec, scale) => ScalarValue::Decimal128(
if !count.is_zero() {
sum.map(|sum| sum / (count as i128))
} else {
None
},
prec,
scale,
),
_ => unreachable!(),
};
avg
})
}};
}
match dt {
DataType::Null => Ok(|_, _| ScalarValue::Null),
DataType::Float32 => get_fn!(Float32, f64),
DataType::Float64 => get_fn!(Float64, f64),
DataType::Int8 => get_fn!(Int8, f64),
DataType::Int16 => get_fn!(Int16, f64),
DataType::Int32 => get_fn!(Int32, f64),
DataType::Int64 => get_fn!(Int64, f64),
DataType::UInt8 => get_fn!(UInt8, f64),
DataType::UInt16 => get_fn!(UInt16, f64),
DataType::UInt32 => get_fn!(UInt32, f64),
DataType::UInt64 => get_fn!(UInt64, f64),
DataType::Decimal128(..) => get_fn!(Decimal128),
other => df_unimplemented_err!("unsupported data type in avg(): {other}"),
}
}