// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use crate::utils::{is_valid_decimal_precision, unlikely};
use arrow::{
array::BooleanBufferBuilder,
buffer::{BooleanBuffer, NullBuffer},
};
use arrow_array::{
cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, Decimal128Array,
};
use arrow_schema::{DataType, Field};
use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator};
use datafusion_common::{DataFusionError, Result as DFResult, ScalarValue};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::Volatility::Immutable;
use datafusion_expr::{AggregateUDFImpl, ReversedUDAF, Signature};
use std::{any::Any, ops::BitAnd, sync::Arc};
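/// Spark-compatible SUM aggregate over `Decimal128` values. In addition to the
/// running sum it tracks whether any non-null input was seen and whether the sum
/// overflowed its precision, so that overflow yields null as Spark does in
/// non-ANSI mode.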
#[derive(Debug)]
pub struct SumDecimal {
/// Aggregate function signature
signature: Signature,
/// The data type of the SUM result. This will always be a decimal type
/// with the same precision and scale as specified in this struct
result_type: DataType,
/// Decimal precision
precision: u8,
/// Decimal scale
scale: i8,
}
impl SumDecimal {
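    /// Creates a `SumDecimal` from the SUM result type computed on the Spark side;
    /// only `Decimal128` is accepted. A minimal usage sketch (the precision and
    /// scale here are arbitrary):
    ///
    /// ```ignore
    /// let udaf = SumDecimal::try_new(DataType::Decimal128(20, 2))?;
    /// ```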
pub fn try_new(data_type: DataType) -> DFResult<Self> {
        // The `data_type` is the SUM result type passed from the Spark side
let (precision, scale) = match data_type {
DataType::Decimal128(p, s) => (p, s),
_ => {
return Err(DataFusionError::Internal(
"Invalid data type for SumDecimal".into(),
))
}
};
Ok(Self {
signature: Signature::user_defined(Immutable),
result_type: data_type,
precision,
scale,
})
}
}
impl AggregateUDFImpl for SumDecimal {
fn as_any(&self) -> &dyn Any {
self
}
fn accumulator(&self, _args: AccumulatorArgs) -> DFResult<Box<dyn Accumulator>> {
Ok(Box::new(SumDecimalAccumulator::new(
self.precision,
self.scale,
)))
}
fn state_fields(&self, _args: StateFieldsArgs) -> DFResult<Vec<Field>> {
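        // The intermediate state mirrors the accumulator's `state()`: the running
        // sum (stored as null once the sum overflows) plus an `is_empty` flag.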
let fields = vec![
Field::new(self.name(), self.result_type.clone(), self.is_nullable()),
Field::new("is_empty", DataType::Boolean, false),
];
Ok(fields)
}
fn name(&self) -> &str {
"sum"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(self.result_type.clone())
}
fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
true
}
fn create_groups_accumulator(
&self,
_args: AccumulatorArgs,
) -> DFResult<Box<dyn GroupsAccumulator>> {
Ok(Box::new(SumDecimalGroupsAccumulator::new(
self.result_type.clone(),
self.precision,
)))
}
fn default_value(&self, _data_type: &DataType) -> DFResult<ScalarValue> {
ScalarValue::new_primitive::<Decimal128Type>(
None,
&DataType::Decimal128(self.precision, self.scale),
)
}
fn reverse_expr(&self) -> ReversedUDAF {
ReversedUDAF::Identical
}
fn is_nullable(&self) -> bool {
// SumDecimal is always nullable because overflows can cause null values
true
}
}
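/// Row-at-a-time accumulator for [`SumDecimal`] (the non-grouping path). The pair
/// (`is_empty`, `is_not_null`) distinguishes "no input yet" from "overflowed":
/// overflow is encoded as `is_empty == false` with `is_not_null == false`.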
#[derive(Debug)]
struct SumDecimalAccumulator {
sum: i128,
is_empty: bool,
is_not_null: bool,
precision: u8,
scale: i8,
}
impl SumDecimalAccumulator {
fn new(precision: u8, scale: i8) -> Self {
Self {
sum: 0,
is_empty: true,
is_not_null: true,
precision,
scale,
}
}
    fn update_single(&mut self, values: &Decimal128Array, idx: usize) {
        if unlikely(!self.is_empty && !self.is_not_null) {
            // A previous value already overflowed; keep the accumulator null
            // rather than letting a later value reset `is_not_null`.
            return;
        }
        let v = unsafe { values.value_unchecked(idx) };
        let (new_sum, is_overflow) = self.sum.overflowing_add(v);
        if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
            // Overflow: set the buffer accumulator to null
            self.is_not_null = false;
            return;
        }
        self.sum = new_sum;
        self.is_not_null = true;
    }
}
impl Accumulator for SumDecimalAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
assert_eq!(
values.len(),
1,
"Expect only one element in 'values' but found {}",
values.len()
);
        if !self.is_empty && !self.is_not_null {
            // The sum overflowed in a previous batch, so the result is already
            // null; skip the rest of the computation.
            return Ok(());
        }
let values = &values[0];
let data = values.as_primitive::<Decimal128Type>();
self.is_empty = self.is_empty && values.len() == values.null_count();
if values.null_count() == 0 {
for i in 0..data.len() {
self.update_single(data, i);
}
} else {
for i in 0..data.len() {
if data.is_null(i) {
continue;
}
self.update_single(data, i);
}
}
Ok(())
}
    fn evaluate(&mut self) -> DFResult<ScalarValue> {
        // 1. If `is_empty` is true, either no value was seen or all values were
        //    null; the result is null.
        // 2. If `is_empty` is false but `is_not_null` is also false, the sum
        //    overflowed. In non-ANSI mode Spark returns null on overflow.
if self.is_empty || !self.is_not_null {
ScalarValue::new_primitive::<Decimal128Type>(
None,
&DataType::Decimal128(self.precision, self.scale),
)
} else {
ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale)
}
}
fn size(&self) -> usize {
std::mem::size_of_val(self)
}
fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
let sum = if self.is_not_null {
ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale)?
} else {
ScalarValue::new_primitive::<Decimal128Type>(
None,
&DataType::Decimal128(self.precision, self.scale),
)?
};
Ok(vec![sum, ScalarValue::from(self.is_empty)])
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
assert_eq!(
states.len(),
2,
"Expect two element in 'states' but found {}",
states.len()
);
assert_eq!(states[0].len(), 1);
assert_eq!(states[1].len(), 1);
let that_sum = states[0].as_primitive::<Decimal128Type>();
let that_is_empty = states[1].as_any().downcast_ref::<BooleanArray>().unwrap();
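        // Overflow is encoded in the incoming state as a null sum combined with
        // `is_empty == false` (see `state()` above).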
let this_overflow = !self.is_empty && !self.is_not_null;
let that_overflow = !that_is_empty.value(0) && that_sum.is_null(0);
self.is_not_null = !this_overflow && !that_overflow;
self.is_empty = self.is_empty && that_is_empty.value(0);
if self.is_not_null {
self.sum += that_sum.value(0);
}
Ok(())
}
}
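/// Vectorized accumulator for [`SumDecimal`] used with GROUP BY. Per-group state
/// is a running `i128` sum plus two bitmaps: `is_empty` (group has seen no
/// non-null input) and `is_not_null` (the buffered sum has not overflowed).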
struct SumDecimalGroupsAccumulator {
    // Whether the aggregate buffer for a particular group is null; true means it
    // is not null.
    is_not_null: BooleanBufferBuilder,
    // Whether a particular group has seen any non-null value; true means it has
    // not (i.e. the group is still empty).
    is_empty: BooleanBufferBuilder,
sum: Vec<i128>,
result_type: DataType,
precision: u8,
}
impl SumDecimalGroupsAccumulator {
fn new(result_type: DataType, precision: u8) -> Self {
Self {
is_not_null: BooleanBufferBuilder::new(0),
is_empty: BooleanBufferBuilder::new(0),
sum: Vec::new(),
result_type,
precision,
}
}
fn is_overflow(&self, index: usize) -> bool {
!self.is_empty.get_bit(index) && !self.is_not_null.get_bit(index)
}
fn update_single(&mut self, group_index: usize, value: i128) {
        if unlikely(self.is_overflow(group_index)) {
            // This group already overflowed, so its result is already null; skip
            // the rest of the computation.
            return;
        }
self.is_empty.set_bit(group_index, false);
let (new_sum, is_overflow) = self.sum[group_index].overflowing_add(value);
if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
// Overflow: set buffer accumulator to null
self.is_not_null.set_bit(group_index, false);
return;
}
self.sum[group_index] = new_sum;
        self.is_not_null.set_bit(group_index, true);
}
}
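/// Extends `builder` with `true` bits up to `capacity`. New groups start out both
/// empty and non-null, so `true` is the correct initial value for both bitmaps.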
fn ensure_bit_capacity(builder: &mut BooleanBufferBuilder, capacity: usize) {
if builder.len() < capacity {
let additional = capacity - builder.len();
builder.append_n(additional, true);
}
}
/// Build a boolean buffer from the state and reset the state, based on the emit_to
/// strategy.
fn build_bool_state(state: &mut BooleanBufferBuilder, emit_to: &EmitTo) -> BooleanBuffer {
let bool_state: BooleanBuffer = state.finish();
match emit_to {
EmitTo::All => bool_state,
EmitTo::First(n) => {
// split off the first N values in bool_state
let first_n_bools: BooleanBuffer = bool_state.iter().take(*n).collect();
            // `finish` above reset the builder, so push the remaining bits back
            for seen in bool_state.iter().skip(*n) {
state.append(seen);
}
first_n_bools
}
}
}
impl GroupsAccumulator for SumDecimalGroupsAccumulator {
fn update_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) -> DFResult<()> {
assert!(opt_filter.is_none(), "opt_filter is not supported yet");
assert_eq!(values.len(), 1);
let values = values[0].as_primitive::<Decimal128Type>();
let data = values.values();
        // Ensure the per-group state is sized for all groups seen so far
self.sum.resize(total_num_groups, 0);
ensure_bit_capacity(&mut self.is_empty, total_num_groups);
ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
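        // `data` is the raw value buffer and may contain arbitrary values at null
        // slots, so the null-aware branch below skips those by index.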
let iter = group_indices.iter().zip(data.iter());
if values.null_count() == 0 {
for (&group_index, &value) in iter {
self.update_single(group_index, value);
}
} else {
for (idx, (&group_index, &value)) in iter.enumerate() {
if values.is_null(idx) {
continue;
}
self.update_single(group_index, value);
}
}
Ok(())
}
fn evaluate(&mut self, emit_to: EmitTo) -> DFResult<ArrayRef> {
        // For each group:
        // 1. If `is_empty` is true, either no value was seen or all values for the
        //    group were null; the result is null.
        // 2. If `is_empty` is false but `is_not_null` is also false, the sum
        //    overflowed. In non-ANSI mode Spark returns null on overflow.
        let nulls = build_bool_state(&mut self.is_not_null, &emit_to);
        let is_empty = build_bool_state(&mut self.is_empty, &emit_to);
        // A group is valid only if it saw at least one value and did not overflow.
        let validity = (!&is_empty).bitand(&nulls);
        let result = emit_to.take_needed(&mut self.sum);
        let result = Decimal128Array::new(result.into(), Some(NullBuffer::new(validity)))
            .with_data_type(self.result_type.clone());
Ok(Arc::new(result))
}
fn state(&mut self, emit_to: EmitTo) -> DFResult<Vec<ArrayRef>> {
let nulls = build_bool_state(&mut self.is_not_null, &emit_to);
let nulls = Some(NullBuffer::new(nulls));
let sum = emit_to.take_needed(&mut self.sum);
let sum = Decimal128Array::new(sum.into(), nulls.clone())
.with_data_type(self.result_type.clone());
let is_empty = build_bool_state(&mut self.is_empty, &emit_to);
let is_empty = BooleanArray::new(is_empty, None);
Ok(vec![
Arc::new(sum) as ArrayRef,
Arc::new(is_empty) as ArrayRef,
])
}
fn merge_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) -> DFResult<()> {
assert_eq!(
values.len(),
2,
"Expected two arrays: 'sum' and 'is_empty', but found {}",
values.len()
);
assert!(opt_filter.is_none(), "opt_filter is not supported yet");
// Make sure we have enough capacity for the additional groups
self.sum.resize(total_num_groups, 0);
ensure_bit_capacity(&mut self.is_empty, total_num_groups);
ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
let that_sum = &values[0];
let that_sum = that_sum.as_primitive::<Decimal128Type>();
let that_is_empty = &values[1];
let that_is_empty = that_is_empty
.as_any()
.downcast_ref::<BooleanArray>()
.unwrap();
group_indices
.iter()
.enumerate()
.for_each(|(idx, &group_index)| unsafe {
let this_overflow = self.is_overflow(group_index);
let that_is_empty = that_is_empty.value_unchecked(idx);
let that_overflow = !that_is_empty && that_sum.is_null(idx);
let is_overflow = this_overflow || that_overflow;
// This part follows the logic in Spark:
// `org.apache.spark.sql.catalyst.expressions.aggregate.Sum`
self.is_not_null.set_bit(group_index, !is_overflow);
self.is_empty.set_bit(
group_index,
self.is_empty.get_bit(group_index) && that_is_empty,
);
                if !is_overflow {
                    // Neither side overflowed, so the incoming sum for this slot
                    // cannot be null; merge it into the running sum.
                    self.sum[group_index] += that_sum.value_unchecked(idx);
                }
});
Ok(())
}
fn size(&self) -> usize {
self.sum.capacity() * std::mem::size_of::<i128>()
+ self.is_empty.capacity() / 8
+ self.is_not_null.capacity() / 8
}
}
#[cfg(test)]
mod tests {
use super::*;
use arrow::datatypes::*;
use arrow_array::builder::{Decimal128Builder, StringBuilder};
use arrow_array::RecordBatch;
use datafusion::execution::TaskContext;
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_common::Result;
use datafusion_expr::AggregateUDF;
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalExpr;
use futures::StreamExt;
#[test]
fn invalid_data_type() {
assert!(SumDecimal::try_new(DataType::Int32).is_err());
}
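    /// Runs a partial group-by SUM over ten copies of an 8192-row batch and drains
    /// the result stream, verifying only that execution completes without error.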
#[tokio::test]
async fn sum_no_overflow() -> Result<()> {
let num_rows = 8192;
let batch = create_record_batch(num_rows);
let mut batches = Vec::new();
for _ in 0..10 {
batches.push(batch.clone());
}
let partitions = &[batches];
let c0: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c0", 0));
let c1: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c1", 1));
let data_type = DataType::Decimal128(8, 2);
        let schema = partitions[0][0].schema();
let scan: Arc<dyn ExecutionPlan> =
Arc::new(MemoryExec::try_new(partitions, Arc::clone(&schema), None).unwrap());
let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(SumDecimal::try_new(
data_type.clone(),
)?));
let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1])
.schema(Arc::clone(&schema))
.alias("sum")
.with_ignore_nulls(false)
.with_distinct(false)
.build()?;
let aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
PhysicalGroupBy::new_single(vec![(c0, "c0".to_string())]),
vec![aggr_expr.into()],
vec![None], // no filter expressions
scan,
Arc::clone(&schema),
)?);
let mut stream = aggregate
.execute(0, Arc::new(TaskContext::default()))
.unwrap();
while let Some(batch) = stream.next().await {
let _batch = batch?;
}
Ok(())
}
fn create_record_batch(num_rows: usize) -> RecordBatch {
let mut decimal_builder = Decimal128Builder::with_capacity(num_rows);
let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows * 32);
for i in 0..num_rows {
decimal_builder.append_value(i as i128);
string_builder.append_value(format!("this is string #{}", i % 1024));
}
let decimal_array = Arc::new(decimal_builder.finish());
let string_array = Arc::new(string_builder.finish());
let mut fields = vec![];
let mut columns: Vec<ArrayRef> = vec![];
// string column
fields.push(Field::new("c0", DataType::Utf8, false));
columns.push(string_array);
// decimal column
fields.push(Field::new("c1", DataType::Decimal128(38, 10), false));
columns.push(decimal_array);
let schema = Schema::new(fields);
RecordBatch::try_new(Arc::new(schema), columns).unwrap()
}
}