blob: 1ddb549ae87d506336714769b0883a215376b5bd [file]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! Defines `Avg` & `Mean` aggregate & accumulators
use arrow::array::{
Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, AsArray,
BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
};
use arrow::compute::sum;
use arrow::datatypes::{
ArrowNativeType, DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE,
DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE, DECIMAL128_MAX_PRECISION,
DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DataType,
Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, DecimalType,
DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type, i256,
};
use datafusion_common::types::{NativeType, logical_float64};
use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
Accumulator, AggregateUDFImpl, Coercion, Documentation, EmitTo, Expr,
GroupsAccumulator, ReversedUDAF, Signature, TypeSignature, TypeSignatureClass,
Volatility,
};
use datafusion_functions_aggregate_common::aggregate::avg_distinct::{
DecimalDistinctAvgAccumulator, Float64DistinctAvgAccumulator,
};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState;
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{
filtered_null_mask, set_nulls,
};
use datafusion_functions_aggregate_common::utils::DecimalAverager;
use datafusion_macros::user_doc;
use log::debug;
use std::any::Any;
use std::fmt::Debug;
use std::mem::{size_of, size_of_val};
use std::sync::Arc;
// Registers `Avg` through `make_udaf_expr_and_func!` (defined outside this file).
// The arguments name an `avg` function taking one `expression` argument, its
// description, and `avg_udaf`, which `avg_distinct` in this file calls as `avg_udaf()`.
// NOTE(review): presumably `avg` is the generated expression builder — confirm
// against the macro definition.
make_udaf_expr_and_func!(
Avg,
avg,
expression,
"Returns the avg of a group of values.",
avg_udaf
);
/// Builds an `avg(DISTINCT expr)` aggregate expression using the `avg_udaf()` function.
pub fn avg_distinct(expr: Expr) -> Expr {
    use datafusion_expr::expr::AggregateFunction;

    // The third argument is the DISTINCT flag; the remaining optional
    // arguments are left at their empty values.
    let distinct = true;
    let aggregate =
        AggregateFunction::new_udf(avg_udaf(), vec![expr], distinct, None, vec![], None);
    Expr::AggregateFunction(aggregate)
}
// The `user_doc` attribute carries the user-facing documentation for `avg`
// (section, description, syntax and SQL example). `documentation()` in the
// `AggregateUDFImpl` impl returns it through `self.doc()`.
#[user_doc(
doc_section(label = "General Functions"),
description = "Returns the average of numeric values in the specified column.",
syntax_example = "avg(expression)",
sql_example = r#"```sql
> SELECT avg(column_name) FROM table_name;
+---------------------------+
| avg(column_name) |
+---------------------------+
| 42.75 |
+---------------------------+
```"#,
standard_argument(name = "expression",)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Avg {
// Accepted argument types; built in `Avg::new`.
signature: Signature,
// Alternative names for the function; `Avg::new` sets this to `["mean"]`.
aliases: Vec<String>,
}
impl Avg {
    /// Creates the `avg` aggregate with its accepted argument types and the
    /// `mean` alias.
    pub fn new() -> Self {
        // Supported types smallint, int, bigint, real, double precision, decimal, or interval
        // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc

        // Decimal and duration inputs are accepted as-is
        let decimal =
            TypeSignature::Coercible(vec![Coercion::new_exact(TypeSignatureClass::Decimal)]);
        let duration =
            TypeSignature::Coercible(vec![Coercion::new_exact(TypeSignatureClass::Duration)]);
        // Integer and float inputs are implicitly coerced to Float64
        let float64 = TypeSignature::Coercible(vec![Coercion::new_implicit(
            TypeSignatureClass::Native(logical_float64()),
            vec![TypeSignatureClass::Integer, TypeSignatureClass::Float],
            NativeType::Float64,
        )]);

        let signature =
            Signature::one_of(vec![decimal, duration, float64], Volatility::Immutable);
        let aliases = vec![String::from("mean")];
        Self { signature, aliases }
    }
}
impl Default for Avg {
fn default() -> Self {
Self::new()
}
}
impl AggregateUDFImpl for Avg {
/// Returns `self` as `&dyn Any`.
fn as_any(&self) -> &dyn Any {
self
}
/// Returns the function's primary name, `"avg"`.
fn name(&self) -> &str {
"avg"
}
/// Returns the signature built in `Avg::new`.
fn signature(&self) -> &Signature {
&self.signature
}
/// Computes the result type of `avg` from the type of its single argument.
///
/// * Decimal inputs keep their width; precision and scale each grow by 4,
///   capped at the maximum precision / scale of that decimal width.
/// * Duration inputs return a duration with the same time unit.
/// * Every other input type returns `Float64`.
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
    // The "+4" widening follows Spark's rule for decimal averages; here the
    // cap is the maximum of each decimal width rather than a fixed 38.
    // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
    let result_type = match &arg_types[0] {
        DataType::Decimal32(precision, scale) => DataType::Decimal32(
            (*precision + 4).min(DECIMAL32_MAX_PRECISION),
            (*scale + 4).min(DECIMAL32_MAX_SCALE),
        ),
        DataType::Decimal64(precision, scale) => DataType::Decimal64(
            (*precision + 4).min(DECIMAL64_MAX_PRECISION),
            (*scale + 4).min(DECIMAL64_MAX_SCALE),
        ),
        DataType::Decimal128(precision, scale) => DataType::Decimal128(
            (*precision + 4).min(DECIMAL128_MAX_PRECISION),
            (*scale + 4).min(DECIMAL128_MAX_SCALE),
        ),
        DataType::Decimal256(precision, scale) => DataType::Decimal256(
            (*precision + 4).min(DECIMAL256_MAX_PRECISION),
            (*scale + 4).min(DECIMAL256_MAX_SCALE),
        ),
        DataType::Duration(time_unit) => DataType::Duration(*time_unit),
        _ => DataType::Float64,
    };
    Ok(result_type)
}
/// Creates the row-oriented [`Accumulator`] for this `avg` call.
///
/// The implementation is chosen from the pair (input type, return type).
/// `DISTINCT` calls use the distinct accumulators (Float64 and decimals only);
/// other calls use `AvgAccumulator`, `DecimalAvgAccumulator` or
/// `DurationAvgAccumulator`.
///
/// # Errors
/// Returns an execution error when the type pair has no matching accumulator.
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
// Type of the (already coerced) input expression
let data_type = acc_args.expr_fields[0].data_type();
use DataType::*;
// instantiate specialized accumulator based for the type
if acc_args.is_distinct {
match (data_type, acc_args.return_type()) {
// Integer and float inputs reach this point as Float64: the signature built in
// `Avg::new` implicitly coerces them
(Float64, _) => Ok(Box::new(Float64DistinctAvgAccumulator::default())),
// Decimal variants: the input scale plus the target precision/scale of the result
(
Decimal32(_, scale),
Decimal32(target_precision, target_scale),
) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal32Type>::with_decimal_params(
*scale,
*target_precision,
*target_scale,
))),
(
Decimal64(_, scale),
Decimal64(target_precision, target_scale),
) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal64Type>::with_decimal_params(
*scale,
*target_precision,
*target_scale,
))),
(
Decimal128(_, scale),
Decimal128(target_precision, target_scale),
) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal128Type>::with_decimal_params(
*scale,
*target_precision,
*target_scale,
))),
(
Decimal256(_, scale),
Decimal256(target_precision, target_scale),
) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal256Type>::with_decimal_params(
*scale,
*target_precision,
*target_scale,
))),
// Everything else (including Duration) has no distinct accumulator
(dt, return_type) => exec_err!(
"AVG(DISTINCT) for ({} --> {}) not supported",
dt,
return_type
),
}
} else {
match (&data_type, acc_args.return_type()) {
(Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
// Decimal variants: the sum keeps the input precision/scale, the result uses
// the target precision/scale of the return type
(
Decimal32(sum_precision, sum_scale),
Decimal32(target_precision, target_scale),
) => Ok(Box::new(DecimalAvgAccumulator::<Decimal32Type> {
sum: None,
count: 0,
sum_scale: *sum_scale,
sum_precision: *sum_precision,
target_precision: *target_precision,
target_scale: *target_scale,
})),
(
Decimal64(sum_precision, sum_scale),
Decimal64(target_precision, target_scale),
) => Ok(Box::new(DecimalAvgAccumulator::<Decimal64Type> {
sum: None,
count: 0,
sum_scale: *sum_scale,
sum_precision: *sum_precision,
target_precision: *target_precision,
target_scale: *target_scale,
})),
(
Decimal128(sum_precision, sum_scale),
Decimal128(target_precision, target_scale),
) => Ok(Box::new(DecimalAvgAccumulator::<Decimal128Type> {
sum: None,
count: 0,
sum_scale: *sum_scale,
sum_precision: *sum_precision,
target_precision: *target_precision,
target_scale: *target_scale,
})),
(
Decimal256(sum_precision, sum_scale),
Decimal256(target_precision, target_scale),
) => Ok(Box::new(DecimalAvgAccumulator::<Decimal256Type> {
sum: None,
count: 0,
sum_scale: *sum_scale,
sum_precision: *sum_precision,
target_precision: *target_precision,
target_scale: *target_scale,
})),
// Durations: the sum uses the input unit, the result uses the return type's unit
(Duration(time_unit), Duration(result_unit)) => {
Ok(Box::new(DurationAvgAccumulator {
sum: None,
count: 0,
time_unit: *time_unit,
result_unit: *result_unit,
}))
}
// Any other type pair is unsupported
(dt, return_type) => {
exec_err!("AvgAccumulator for ({} --> {})", dt, return_type)
}
}
}
}
/// Describes the intermediate state emitted by the accumulators.
///
/// * Non-distinct: two nullable fields, a `UInt64` count and a sum with the
///   input's data type.
/// * Distinct: a single non-nullable list field of nullable elements.
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
    let input_type = args.input_fields[0].data_type();

    if !args.is_distinct {
        let count = Field::new(
            format_state_name(args.name, "count"),
            DataType::UInt64,
            true,
        );
        let sum = Field::new(
            format_state_name(args.name, "sum"),
            input_type.clone(),
            true,
        );
        return Ok(vec![Arc::new(count), Arc::new(sum)]);
    }

    // Decimal accumulator actually uses a different precision during accumulation,
    // see DecimalDistinctAvgAccumulator::with_decimal_params
    let element_type = match input_type {
        DataType::Decimal32(_, scale) => {
            DataType::Decimal32(DECIMAL32_MAX_PRECISION, *scale)
        }
        DataType::Decimal64(_, scale) => {
            DataType::Decimal64(DECIMAL64_MAX_PRECISION, *scale)
        }
        DataType::Decimal128(_, scale) => {
            DataType::Decimal128(DECIMAL128_MAX_PRECISION, *scale)
        }
        DataType::Decimal256(_, scale) => {
            DataType::Decimal256(DECIMAL256_MAX_PRECISION, *scale)
        }
        _ => args.return_type().clone(),
    };

    // Similar to datafusion_functions_aggregate::sum::Sum::state_fields
    // since the accumulator uses DistinctSumAccumulator internally.
    let distinct_values = Field::new_list(
        format_state_name(args.name, "avg distinct"),
        Field::new_list_field(element_type, true),
        false,
    );
    Ok(vec![distinct_values.into()])
}
/// Reports whether `create_groups_accumulator` can handle this call:
/// non-distinct aggregation returning `Float64`, a decimal, or a duration.
fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
    // DISTINCT is never handled by the groups accumulator
    if args.is_distinct {
        return false;
    }
    matches!(
        args.return_field.data_type(),
        DataType::Float64
            | DataType::Decimal32(_, _)
            | DataType::Decimal64(_, _)
            | DataType::Decimal128(_, _)
            | DataType::Decimal256(_, _)
            | DataType::Duration(_)
    )
}
fn create_groups_accumulator(
&self,
args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
use DataType::*;
let data_type = args.expr_fields[0].data_type();
// instantiate specialized accumulator based for the type
match (data_type, args.return_field.data_type()) {
(Float64, Float64) => {
Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
data_type,
args.return_field.data_type(),
|sum: f64, count: u64| Ok(sum / count as f64),
)))
}
(
Decimal32(_sum_precision, sum_scale),
Decimal32(target_precision, target_scale),
) => {
let decimal_averager = DecimalAverager::<Decimal32Type>::try_new(
*sum_scale,
*target_precision,
*target_scale,
)?;
let avg_fn =
move |sum: i32, count: u64| decimal_averager.avg(sum, count as i32);
Ok(Box::new(AvgGroupsAccumulator::<Decimal32Type, _>::new(
data_type,
args.return_field.data_type(),
avg_fn,
)))
}
(
Decimal64(_sum_precision, sum_scale),
Decimal64(target_precision, target_scale),
) => {
let decimal_averager = DecimalAverager::<Decimal64Type>::try_new(
*sum_scale,
*target_precision,
*target_scale,
)?;
let avg_fn =
move |sum: i64, count: u64| decimal_averager.avg(sum, count as i64);
Ok(Box::new(AvgGroupsAccumulator::<Decimal64Type, _>::new(
data_type,
args.return_field.data_type(),
avg_fn,
)))
}
(
Decimal128(_sum_precision, sum_scale),
Decimal128(target_precision, target_scale),
) => {
let decimal_averager = DecimalAverager::<Decimal128Type>::try_new(
*sum_scale,
*target_precision,
*target_scale,
)?;
let avg_fn =
move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128);
Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new(
data_type,
args.return_field.data_type(),
avg_fn,
)))
}
(
Decimal256(_sum_precision, sum_scale),
Decimal256(target_precision, target_scale),
) => {
let decimal_averager = DecimalAverager::<Decimal256Type>::try_new(
*sum_scale,
*target_precision,
*target_scale,
)?;
let avg_fn = move |sum: i256, count: u64| {
decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap())
};
Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new(
data_type,
args.return_field.data_type(),
avg_fn,
)))
}
(Duration(time_unit), Duration(_result_unit)) => {
let avg_fn = move |sum: i64, count: u64| Ok(sum / count as i64);
match time_unit {
TimeUnit::Second => Ok(Box::new(AvgGroupsAccumulator::<
DurationSecondType,
_,
>::new(
data_type,
args.return_type(),
avg_fn,
))),
TimeUnit::Millisecond => Ok(Box::new(AvgGroupsAccumulator::<
DurationMillisecondType,
_,
>::new(
data_type,
args.return_type(),
avg_fn,
))),
TimeUnit::Microsecond => Ok(Box::new(AvgGroupsAccumulator::<
DurationMicrosecondType,
_,
>::new(
data_type,
args.return_type(),
avg_fn,
))),
TimeUnit::Nanosecond => Ok(Box::new(AvgGroupsAccumulator::<
DurationNanosecondType,
_,
>::new(
data_type,
args.return_type(),
avg_fn,
))),
}
}
_ => not_impl_err!(
"AvgGroupsAccumulator for ({} --> {})",
&data_type,
args.return_field.data_type()
),
}
}
/// Returns the alternative names of the function (set to `["mean"]` in `Avg::new`).
fn aliases(&self) -> &[String] {
&self.aliases
}
/// Reports `ReversedUDAF::Identical` for this aggregate.
// NOTE(review): presumably because an average does not depend on input order — confirm
// against the `ReversedUDAF` documentation.
fn reverse_expr(&self) -> ReversedUDAF {
ReversedUDAF::Identical
}
/// Returns the documentation obtained from `self.doc()`.
// NOTE(review): `doc()` is not defined in this file; presumably generated by the
// `user_doc` attribute on `Avg` — confirm.
fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}
/// An accumulator to compute the average of `Float64` values
#[derive(Debug, Default)]
pub struct AvgAccumulator {
/// Running sum of the non-null values; `None` until a non-null value is seen
sum: Option<f64>,
/// Number of non-null values accumulated so far
count: u64,
}
impl Accumulator for AvgAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = values[0].as_primitive::<Float64Type>();
self.count += (values.len() - values.null_count()) as u64;
if let Some(x) = sum(values) {
let v = self.sum.get_or_insert(0.);
*v += x;
}
Ok(())
}
fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(ScalarValue::Float64(
self.sum.map(|f| f / self.count as f64),
))
}
fn size(&self) -> usize {
size_of_val(self)
}
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.count),
ScalarValue::Float64(self.sum),
])
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
// counts are summed
self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
// sums are summed
if let Some(x) = sum(states[1].as_primitive::<Float64Type>()) {
let v = self.sum.get_or_insert(0.);
*v += x;
}
Ok(())
}
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = values[0].as_primitive::<Float64Type>();
self.count -= (values.len() - values.null_count()) as u64;
if let Some(x) = sum(values) {
self.sum = Some(self.sum.unwrap() - x);
}
Ok(())
}
fn supports_retract_batch(&self) -> bool {
true
}
}
/// An accumulator to compute the average for decimals
#[derive(Debug)]
struct DecimalAvgAccumulator<T: DecimalType + ArrowNumericType + Debug> {
/// Running sum of the non-null values; `None` until a non-null value is seen
sum: Option<T::Native>,
/// Number of non-null values accumulated so far
count: u64,
/// Scale of the sum (taken from the input type)
sum_scale: i8,
/// Precision of the sum (taken from the input type)
sum_precision: u8,
/// Precision of the returned average
target_precision: u8,
/// Scale of the returned average
target_scale: i8,
}
impl<T: DecimalType + ArrowNumericType + Debug> Accumulator for DecimalAvgAccumulator<T> {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = values[0].as_primitive::<T>();
self.count += (values.len() - values.null_count()) as u64;
if let Some(x) = sum(values) {
let v = self.sum.get_or_insert_with(T::Native::default);
self.sum = Some(v.add_wrapping(x));
}
Ok(())
}
fn evaluate(&mut self) -> Result<ScalarValue> {
let v = self
.sum
.map(|v| {
DecimalAverager::<T>::try_new(
self.sum_scale,
self.target_precision,
self.target_scale,
)?
.avg(v, T::Native::from_usize(self.count as usize).unwrap())
})
.transpose()?;
ScalarValue::new_primitive::<T>(
v,
&T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
)
}
fn size(&self) -> usize {
size_of_val(self)
}
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.count),
ScalarValue::new_primitive::<T>(
self.sum,
&T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale),
)?,
])
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
// counts are summed
self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
// sums are summed
if let Some(x) = sum(states[1].as_primitive::<T>()) {
let v = self.sum.get_or_insert_with(T::Native::default);
self.sum = Some(v.add_wrapping(x));
}
Ok(())
}
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = values[0].as_primitive::<T>();
self.count -= (values.len() - values.null_count()) as u64;
if let Some(x) = sum(values) {
self.sum = Some(self.sum.unwrap().sub_wrapping(x));
}
Ok(())
}
fn supports_retract_batch(&self) -> bool {
true
}
}
/// An accumulator to compute the average for duration values
#[derive(Debug)]
struct DurationAvgAccumulator {
/// Running sum of the non-null values, expressed in `time_unit`;
/// `None` until a non-null value is seen
sum: Option<i64>,
/// Number of non-null values accumulated so far
count: u64,
/// Time unit of the input arrays and of the sum in the intermediate state
time_unit: TimeUnit,
/// Time unit of the scalar returned by `evaluate`
result_unit: TimeUnit,
}
impl Accumulator for DurationAvgAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let array = &values[0];
self.count += (array.len() - array.null_count()) as u64;
let sum_value = match self.time_unit {
TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
};
if let Some(x) = sum_value {
let v = self.sum.get_or_insert(0);
*v += x;
}
Ok(())
}
fn evaluate(&mut self) -> Result<ScalarValue> {
let avg = self.sum.map(|sum| sum / self.count as i64);
match self.result_unit {
TimeUnit::Second => Ok(ScalarValue::DurationSecond(avg)),
TimeUnit::Millisecond => Ok(ScalarValue::DurationMillisecond(avg)),
TimeUnit::Microsecond => Ok(ScalarValue::DurationMicrosecond(avg)),
TimeUnit::Nanosecond => Ok(ScalarValue::DurationNanosecond(avg)),
}
}
fn size(&self) -> usize {
size_of_val(self)
}
fn state(&mut self) -> Result<Vec<ScalarValue>> {
let duration_value = match self.time_unit {
TimeUnit::Second => ScalarValue::DurationSecond(self.sum),
TimeUnit::Millisecond => ScalarValue::DurationMillisecond(self.sum),
TimeUnit::Microsecond => ScalarValue::DurationMicrosecond(self.sum),
TimeUnit::Nanosecond => ScalarValue::DurationNanosecond(self.sum),
};
Ok(vec![ScalarValue::from(self.count), duration_value])
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
let sum_value = match self.time_unit {
TimeUnit::Second => sum(states[1].as_primitive::<DurationSecondType>()),
TimeUnit::Millisecond => {
sum(states[1].as_primitive::<DurationMillisecondType>())
}
TimeUnit::Microsecond => {
sum(states[1].as_primitive::<DurationMicrosecondType>())
}
TimeUnit::Nanosecond => {
sum(states[1].as_primitive::<DurationNanosecondType>())
}
};
if let Some(x) = sum_value {
let v = self.sum.get_or_insert(0);
*v += x;
}
Ok(())
}
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let array = &values[0];
self.count -= (array.len() - array.null_count()) as u64;
let sum_value = match self.time_unit {
TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
};
if let Some(x) = sum_value {
self.sum = Some(self.sum.unwrap() - x);
}
Ok(())
}
fn supports_retract_batch(&self) -> bool {
true
}
}
/// An accumulator to compute the average of [`PrimitiveArray<T>`] per group.
/// Stores sums as native types; sums are accumulated with wrapping
/// arithmetic (`add_wrapping`), so overflow is not reported.
///
/// F: Function that calculates the average value from a sum of
/// T::Native and a total count
#[derive(Debug)]
struct AvgGroupsAccumulator<T, F>
where
T: ArrowNumericType + Send,
F: Fn(T::Native, u64) -> Result<T::Native> + Send + 'static,
{
/// The type of the internal sum
sum_data_type: DataType,
/// The type of the returned average
return_data_type: DataType,
/// Count per group (use u64 to make UInt64Array)
counts: Vec<u64>,
/// Sums per group, stored as the native type
sums: Vec<T::Native>,
/// Track nulls in the input / filters
null_state: NullState,
/// Function that computes the final average (value / count)
avg_fn: F,
}
impl<T, F> AvgGroupsAccumulator<T, F>
where
    T: ArrowNumericType + Send,
    F: Fn(T::Native, u64) -> Result<T::Native> + Send + 'static,
{
    /// Creates an empty accumulator.
    ///
    /// `sum_data_type` is the type of the intermediate sums, `return_data_type`
    /// the type of the emitted averages, and `avg_fn` computes one average from
    /// a `(sum, count)` pair.
    pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self {
        let arrow_type_name = std::any::type_name::<T>();
        debug!(
            "AvgGroupsAccumulator ({}, sum type: {sum_data_type}) --> {return_data_type}",
            arrow_type_name
        );

        Self {
            sum_data_type: sum_data_type.clone(),
            return_data_type: return_data_type.clone(),
            counts: Vec::new(),
            sums: Vec::new(),
            null_state: NullState::new(),
            avg_fn,
        }
    }
}
impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
where
T: ArrowNumericType + Send,
F: Fn(T::Native, u64) -> Result<T::Native> + Send + 'static,
{
/// Accumulates one batch of input values into the per-group sums and counts.
///
/// `values` must contain exactly one array (asserted). Which rows are passed
/// to the closure is decided by `NullState::accumulate`, which receives the
/// values, `group_indices` and `opt_filter`.
fn update_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 1, "single argument to update_batch");
let values = values[0].as_primitive::<T>();
// increment counts, update sums
// Grow both vectors so every group index below `total_num_groups` has a slot
self.counts.resize(total_num_groups, 0);
self.sums.resize(total_num_groups, T::default_value());
self.null_state.accumulate(
group_indices,
values,
opt_filter,
total_num_groups,
|group_index, new_value| {
// SAFETY: group_index is guaranteed to be in bounds
// NOTE(review): this relies on every entry of `group_indices` being
// < `total_num_groups` (the length `sums` was just resized to) — confirm
// the caller contract.
let sum = unsafe { self.sums.get_unchecked_mut(group_index) };
// Wrapping addition: overflow is not reported
*sum = sum.add_wrapping(new_value);
self.counts[group_index] += 1;
},
);
Ok(())
}
/// Emits the averages of the groups selected by `emit_to` as an array of the
/// return data type.
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
    let counts = emit_to.take_needed(&mut self.counts);
    let sums = emit_to.take_needed(&mut self.sums);
    let nulls = self.null_state.build(emit_to);

    if let Some(nulls) = &nulls {
        assert_eq!(nulls.len(), sums.len());
    }
    assert_eq!(counts.len(), sums.len());

    // don't evaluate averages with null inputs to avoid errors on null values
    let null_groups = nulls.as_ref().filter(|mask| mask.null_count() > 0);
    let array: PrimitiveArray<T> = if let Some(validity) = null_groups {
        // Some groups are null: compute the average only for the valid ones
        let mut builder = PrimitiveBuilder::<T>::with_capacity(validity.len())
            .with_data_type(self.return_data_type.clone());
        for ((sum, count), is_valid) in sums.into_iter().zip(counts).zip(validity.iter()) {
            if is_valid {
                builder.append_value((self.avg_fn)(sum, count)?);
            } else {
                builder.append_null();
            }
        }
        builder.finish()
    } else {
        // Every group is valid: average them all and reuse the buffer
        let mut averages: Vec<T::Native> = Vec::with_capacity(sums.len());
        for (sum, count) in sums.into_iter().zip(counts) {
            averages.push((self.avg_fn)(sum, count)?);
        }
        PrimitiveArray::new(averages.into(), nulls) // no copy
            .with_data_type(self.return_data_type.clone())
    };

    Ok(Arc::new(array))
}
/// Emits the intermediate state of the groups selected by `emit_to` as two
/// arrays, counts first and sums second, sharing one validity mask.
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
    // The mask is built before the values are taken
    let validity = self.null_state.build(emit_to);

    let count_values = emit_to.take_needed(&mut self.counts);
    let count_array = UInt64Array::new(count_values.into(), validity.clone()); // zero copy

    let sum_values = emit_to.take_needed(&mut self.sums);
    let sum_array = PrimitiveArray::<T>::new(sum_values.into(), validity) // zero copy
        .with_data_type(self.sum_data_type.clone());

    Ok(vec![
        Arc::new(count_array) as ArrayRef,
        Arc::new(sum_array) as ArrayRef,
    ])
}
/// Merges a batch of intermediate state into the per-group sums and counts.
///
/// `values` must contain exactly two arrays (asserted): partial counts
/// (`UInt64`) followed by partial sums. Counts and sums are accumulated in
/// two separate passes over `group_indices`.
fn merge_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 2, "two arguments to merge_batch");
// first batch is counts, second is partial sums
let partial_counts = values[0].as_primitive::<UInt64Type>();
let partial_sums = values[1].as_primitive::<T>();
// update counts with partial counts
self.counts.resize(total_num_groups, 0);
self.null_state.accumulate(
group_indices,
partial_counts,
opt_filter,
total_num_groups,
|group_index, partial_count| {
// SAFETY: group_index is guaranteed to be in bounds
// NOTE(review): relies on every entry of `group_indices` being
// < `total_num_groups` (the length `counts` was just resized to) — confirm.
let count = unsafe { self.counts.get_unchecked_mut(group_index) };
*count += partial_count;
},
);
// update sums
self.sums.resize(total_num_groups, T::default_value());
self.null_state.accumulate(
group_indices,
partial_sums,
opt_filter,
total_num_groups,
|group_index, new_value: <T as ArrowPrimitiveType>::Native| {
// SAFETY: group_index is guaranteed to be in bounds
// NOTE(review): same assumption as for the counts above, with `sums`
// resized to `total_num_groups` — confirm.
let sum = unsafe { self.sums.get_unchecked_mut(group_index) };
// Wrapping addition: overflow is not reported
*sum = sum.add_wrapping(new_value);
},
);
Ok(())
}
/// Converts a batch of raw input values directly into intermediate state:
/// the input becomes the sums array and each row gets a count of 1. Both
/// arrays receive the null mask computed by `filtered_null_mask` from
/// `opt_filter` and the input.
fn convert_to_state(
    &self,
    values: &[ArrayRef],
    opt_filter: Option<&BooleanArray>,
) -> Result<Vec<ArrayRef>> {
    let input = values[0].as_primitive::<T>();
    let sums = input.clone().with_data_type(self.sum_data_type.clone());
    let counts = UInt64Array::from_value(1, sums.len());

    // set nulls on the arrays
    let null_mask = filtered_null_mask(opt_filter, &sums);
    let counts = set_nulls(counts, null_mask.clone());
    let sums = set_nulls(sums, null_mask);

    Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)])
}
/// Always `true`: `convert_to_state` is implemented for this accumulator.
fn supports_convert_to_state(&self) -> bool {
true
}
fn size(&self) -> usize {
self.counts.capacity() * size_of::<u64>() + self.sums.capacity() * size_of::<T>()
}
}