blob: 3f174d25a5d7de06fdda1df477ca9ed2ac4b737d [file] [log] [blame]
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
any::Any,
fmt::{Debug, Formatter},
sync::{atomic::AtomicUsize, Arc},
};
use arrow::{array::*, datatypes::*};
use datafusion::{
common::{Result, ScalarValue},
physical_expr::PhysicalExpr,
};
use datafusion_ext_commons::downcast_any;
use paste::paste;
use crate::agg::{
acc::{
AccumInitialValue, AccumStateRow, AccumStateValAddr, AggDynBinary, AggDynScalar, AggDynStr,
AggDynValue, RefAccumStateRow,
},
default_final_batch_merge_with_addr, default_final_merge_with_addr, Agg, WithAggBufAddrs,
WithMemTracking,
};
pub struct AggFirstIgnoresNull {
child: Arc<dyn PhysicalExpr>,
data_type: DataType,
accums_initial: Vec<AccumInitialValue>,
accum_state_val_addr: AccumStateValAddr,
partial_updater: fn(&AggFirstIgnoresNull, &mut RefAccumStateRow, &ArrayRef, usize),
partial_buf_merger: fn(&AggFirstIgnoresNull, &mut RefAccumStateRow, &mut RefAccumStateRow),
mem_used_tracker: AtomicUsize,
}
impl WithAggBufAddrs for AggFirstIgnoresNull {
fn set_accum_state_val_addrs(&mut self, accum_state_val_addrs: &[AccumStateValAddr]) {
self.accum_state_val_addr = accum_state_val_addrs[0];
}
}
impl WithMemTracking for AggFirstIgnoresNull {
fn mem_used_tracker(&self) -> &AtomicUsize {
&self.mem_used_tracker
}
}
impl AggFirstIgnoresNull {
pub fn try_new(child: Arc<dyn PhysicalExpr>, data_type: DataType) -> Result<Self> {
let accums_initial = vec![AccumInitialValue::Scalar(ScalarValue::try_from(
&data_type,
)?)];
let partial_updater = get_partial_updater(&data_type)?;
let partial_buf_merger = get_partial_buf_merger(&data_type)?;
Ok(Self {
child,
data_type,
accums_initial,
accum_state_val_addr: AccumStateValAddr::default(),
partial_updater,
partial_buf_merger,
mem_used_tracker: AtomicUsize::new(0),
})
}
}
impl Debug for AggFirstIgnoresNull {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "FirstIgnoresNull({:?})", self.child)
}
}
impl Agg for AggFirstIgnoresNull {
fn as_any(&self) -> &dyn Any {
self
}
fn exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.child.clone()]
}
fn with_new_exprs(&self, exprs: Vec<Arc<dyn PhysicalExpr>>) -> Result<Arc<dyn Agg>> {
Ok(Arc::new(Self::try_new(
exprs[0].clone(),
self.data_type.clone(),
)?))
}
fn data_type(&self) -> &DataType {
&self.data_type
}
fn nullable(&self) -> bool {
true
}
fn accums_initial(&self) -> &[AccumInitialValue] {
&self.accums_initial
}
fn increase_acc_mem_used(&self, acc: &mut RefAccumStateRow) {
if self.data_type.is_primitive()
|| matches!(self.data_type, DataType::Null | DataType::Boolean)
{
return;
}
if let Some(v) = acc.dyn_value(self.accum_state_val_addr) {
self.add_mem_used(v.mem_size());
}
}
fn partial_update(
&self,
acc: &mut RefAccumStateRow,
values: &[ArrayRef],
row_idx: usize,
) -> Result<()> {
let partial_updater = self.partial_updater;
let value = &values[0];
partial_updater(self, acc, value, row_idx);
Ok(())
}
fn partial_update_all(&self, acc: &mut RefAccumStateRow, values: &[ArrayRef]) -> Result<()> {
let partial_updater = self.partial_updater;
let value = &values[0];
for i in 0..value.len() {
if value.is_valid(i) {
partial_updater(self, acc, value, i);
}
}
Ok(())
}
fn partial_merge(
&self,
acc1: &mut RefAccumStateRow,
acc2: &mut RefAccumStateRow,
) -> Result<()> {
let partial_buf_merger = self.partial_buf_merger;
partial_buf_merger(self, acc1, acc2);
Ok(())
}
fn final_merge(&self, acc: &mut RefAccumStateRow) -> Result<ScalarValue> {
default_final_merge_with_addr(self, acc, self.accum_state_val_addr)
}
fn final_batch_merge(&self, accs: &mut [RefAccumStateRow]) -> Result<ArrayRef> {
default_final_batch_merge_with_addr(self, accs, self.accum_state_val_addr)
}
}
fn get_partial_updater(
dt: &DataType,
) -> Result<fn(&AggFirstIgnoresNull, &mut RefAccumStateRow, &ArrayRef, usize)> {
macro_rules! fn_fixed {
($ty:ident) => {{
Ok(|this, acc, v, i| {
if !acc.is_fixed_valid(this.accum_state_val_addr) && v.is_valid(i) {
let value = v.as_any().downcast_ref::<paste! {[<$ty Array>]}>().unwrap();
acc.set_fixed_value(this.accum_state_val_addr, value.value(i));
acc.set_fixed_valid(this.accum_state_val_addr, true);
}
})
}};
}
match dt {
DataType::Null => Ok(|_, _, _, _| ()),
DataType::Boolean => fn_fixed!(Boolean),
DataType::Float32 => fn_fixed!(Float32),
DataType::Float64 => fn_fixed!(Float64),
DataType::Int8 => fn_fixed!(Int8),
DataType::Int16 => fn_fixed!(Int16),
DataType::Int32 => fn_fixed!(Int32),
DataType::Int64 => fn_fixed!(Int64),
DataType::UInt8 => fn_fixed!(UInt8),
DataType::UInt16 => fn_fixed!(UInt16),
DataType::UInt32 => fn_fixed!(UInt32),
DataType::UInt64 => fn_fixed!(UInt64),
DataType::Date32 => fn_fixed!(Date32),
DataType::Date64 => fn_fixed!(Date64),
DataType::Timestamp(TimeUnit::Second, _) => fn_fixed!(TimestampSecond),
DataType::Timestamp(TimeUnit::Millisecond, _) => fn_fixed!(TimestampMillisecond),
DataType::Timestamp(TimeUnit::Microsecond, _) => fn_fixed!(TimestampMicrosecond),
DataType::Timestamp(TimeUnit::Nanosecond, _) => fn_fixed!(TimestampNanosecond),
DataType::Decimal128(..) => fn_fixed!(Decimal128),
DataType::Utf8 => Ok(
|this: &AggFirstIgnoresNull, acc: &mut RefAccumStateRow, v: &ArrayRef, i: usize| {
if v.is_valid(i) {
let v = downcast_any!(v, StringArray).unwrap().value(i);
let w = acc.dyn_value_mut(this.accum_state_val_addr);
if w.is_none() {
let new = AggDynStr::from_str(v);
this.add_mem_used(new.mem_size());
*w = Some(Box::new(new));
}
}
},
),
DataType::Binary => Ok(
|this: &AggFirstIgnoresNull, acc: &mut RefAccumStateRow, v: &ArrayRef, i: usize| {
if v.is_valid(i) {
let v = downcast_any!(v, BinaryArray).unwrap().value(i);
let w = acc.dyn_value_mut(this.accum_state_val_addr);
if w.is_none() {
let new = AggDynBinary::from_slice(v);
this.add_mem_used(new.mem_size());
*w = Some(Box::new(new));
}
}
},
),
_other => Ok(
|this: &AggFirstIgnoresNull, acc: &mut RefAccumStateRow, v: &ArrayRef, i: usize| {
if v.is_valid(i) {
let w = acc.dyn_value_mut(this.accum_state_val_addr);
if w.is_none() {
let new =
AggDynScalar::new(ScalarValue::try_from_array(v, i).expect(
"FirstIgnoresNull::partial_update error creating ScalarValue",
));
this.add_mem_used(new.mem_size());
*w = Some(Box::new(new));
}
}
},
),
}
}
fn get_partial_buf_merger(
dt: &DataType,
) -> Result<fn(&AggFirstIgnoresNull, &mut RefAccumStateRow, &mut RefAccumStateRow)> {
macro_rules! fn_fixed {
($ty:ident) => {{
Ok(|this, acc1, acc2| {
type TType = paste! {[<$ty Type>]};
type TNative = <TType as ArrowPrimitiveType>::Native;
if !acc1.is_fixed_valid(this.accum_state_val_addr)
&& acc2.is_fixed_valid(this.accum_state_val_addr)
{
acc1.set_fixed_value(
this.accum_state_val_addr,
acc2.fixed_value::<TNative>(this.accum_state_val_addr),
);
acc1.set_fixed_valid(this.accum_state_val_addr, true);
}
})
}};
}
match dt {
DataType::Null => Ok(|_, _, _| ()),
DataType::Boolean => Ok(|this, acc1, acc2| {
if !acc1.is_fixed_valid(this.accum_state_val_addr)
&& acc2.is_fixed_valid(this.accum_state_val_addr)
{
acc1.set_fixed_value(
this.accum_state_val_addr,
acc2.fixed_value::<bool>(this.accum_state_val_addr),
);
acc1.set_fixed_valid(this.accum_state_val_addr, true);
}
}),
DataType::Float32 => fn_fixed!(Float32),
DataType::Float64 => fn_fixed!(Float64),
DataType::Int8 => fn_fixed!(Int8),
DataType::Int16 => fn_fixed!(Int16),
DataType::Int32 => fn_fixed!(Int32),
DataType::Int64 => fn_fixed!(Int64),
DataType::UInt8 => fn_fixed!(UInt8),
DataType::UInt16 => fn_fixed!(UInt16),
DataType::UInt32 => fn_fixed!(UInt32),
DataType::UInt64 => fn_fixed!(UInt64),
DataType::Date32 => fn_fixed!(Date32),
DataType::Date64 => fn_fixed!(Date64),
DataType::Timestamp(TimeUnit::Second, _) => fn_fixed!(TimestampSecond),
DataType::Timestamp(TimeUnit::Millisecond, _) => fn_fixed!(TimestampMillisecond),
DataType::Timestamp(TimeUnit::Microsecond, _) => fn_fixed!(TimestampMicrosecond),
DataType::Timestamp(TimeUnit::Nanosecond, _) => fn_fixed!(TimestampNanosecond),
DataType::Decimal128(..) => fn_fixed!(Decimal128),
DataType::Utf8 | DataType::Binary | _ => Ok(|this, acc1, acc2| {
let w = acc1.dyn_value_mut(this.accum_state_val_addr);
let v = acc2.dyn_value_mut(this.accum_state_val_addr);
if w.is_none() && v.is_some() {
*w = std::mem::take(v);
} else {
if let Some(v) = v {
// v will be dropped
this.sub_mem_used(v.mem_size());
}
}
}),
}
}