blob: c6cc7969cfc4b9e8e8ec1bbabcb644bc41919c54 [file] [log] [blame]
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
any::Any,
fmt::{Debug, Formatter},
sync::{atomic::AtomicUsize, Arc},
};
use arrow::{array::*, datatypes::*};
use datafusion::{
common::{Result, ScalarValue},
physical_expr::PhysicalExpr,
};
use datafusion_ext_commons::{df_execution_err, downcast_any};
use crate::agg::{
acc::{
AccumInitialValue, AccumStateRow, AccumStateValAddr, AggDynList, AggDynValue,
RefAccumStateRow,
},
Agg, WithAggBufAddrs, WithMemTracking,
};
/// Aggregate that gathers the non-null values of its child expression into a
/// list, keeping the growing list in a dynamic accumulator state slot.
pub struct AggCollectList {
/// Input expression whose values are collected; returned by `exprs()` and
/// printed by the `Debug` implementation.
child: Arc<dyn PhysicalExpr>,
/// Type reported by `data_type()`; also used to build the result when an
/// accumulator holds no list or when there are no accumulators at all.
data_type: DataType,
/// Type handed to the dynamic list state and used as the element type when
/// the final list scalar is built.
arg_type: DataType,
/// Initial accumulator layout: a single `DynList` slot built from `arg_type`.
accum_initial: [AccumInitialValue; 1],
/// Address of this aggregate's state inside an accumulator row; assigned by
/// `set_accum_state_val_addrs` (default-initialised until then).
accum_state_val_addr: AccumStateValAddr,
/// Counter exposed through `WithMemTracking`; starts at zero.
mem_used_tracker: AtomicUsize,
}
impl WithAggBufAddrs for AggCollectList {
    /// Remembers where this aggregate's state lives inside an accumulator row.
    ///
    /// Only the first entry of `accum_state_val_addrs` is used, because this
    /// aggregate declares a single accumulator value.
    ///
    /// # Panics
    ///
    /// Panics if `accum_state_val_addrs` is empty.
    fn set_accum_state_val_addrs(&mut self, accum_state_val_addrs: &[AccumStateValAddr]) {
        let first_addr = accum_state_val_addrs[0];
        self.accum_state_val_addr = first_addr;
    }
}
impl WithMemTracking for AggCollectList {
    /// Gives access to the counter that records the memory used by this
    /// aggregate's accumulator states.
    fn mem_used_tracker(&self) -> &AtomicUsize {
        let Self {
            mem_used_tracker, ..
        } = self;
        mem_used_tracker
    }
}
impl AggCollectList {
    /// Creates a collect-list aggregate over `child`.
    ///
    /// `data_type` is the type reported for the aggregate result and
    /// `arg_type` is the type used for the dynamic list state and its
    /// elements. The state address starts at its default value and the memory
    /// counter at zero.
    ///
    /// # Errors
    ///
    /// The visible implementation never fails; the `Result` return type is
    /// part of the existing interface.
    pub fn try_new(
        child: Arc<dyn PhysicalExpr>,
        data_type: DataType,
        arg_type: DataType,
    ) -> Result<Self> {
        // Built up front so `arg_type` itself can be moved into the struct.
        let accum_initial = [AccumInitialValue::DynList(arg_type.clone())];
        let collector = Self {
            child,
            data_type,
            arg_type,
            accum_initial,
            accum_state_val_addr: Default::default(),
            mem_used_tracker: AtomicUsize::default(),
        };
        Ok(collector)
    }

    /// Returns the type that was passed as `arg_type` at construction.
    pub fn arg_type(&self) -> &DataType {
        let Self { arg_type, .. } = self;
        arg_type
    }
}
impl Debug for AggCollectList {
    /// Renders the aggregate as `CollectList(<child>)`, with the child
    /// expression printed through its own `Debug` implementation.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let rendered = format_args!("CollectList({:?})", self.child);
        f.write_fmt(rendered)
    }
}
/// Aggregation logic: non-null input values are appended to an `AggDynList`
/// stored in the accumulator row, and the list is turned into a list scalar
/// when the group is finalised. Every change in list size is mirrored in the
/// memory tracker.
impl Agg for AggCollectList {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the single child expression feeding this aggregate.
    fn exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        vec![self.child.clone()]
    }

    /// Builds a copy of this aggregate that reads from the first of `exprs`,
    /// keeping the current result and argument types. Extra expressions are
    /// ignored.
    ///
    /// # Errors
    ///
    /// Returns an execution error when `exprs` is empty.
    fn with_new_exprs(&self, exprs: Vec<Arc<dyn PhysicalExpr>>) -> Result<Arc<dyn Agg>> {
        // Report a missing child as an error rather than panicking on an
        // out-of-bounds index; moving it out of the vector avoids a clone.
        let child = match exprs.into_iter().next() {
            Some(child) => child,
            None => {
                return df_execution_err!(
                    "CollectList expects one child expression, but none was provided"
                )
            }
        };
        Ok(Arc::new(Self::try_new(
            child,
            self.data_type.clone(),
            self.arg_type.clone(),
        )?))
    }

    fn data_type(&self) -> &DataType {
        &self.data_type
    }

    /// Always reports the result as non-nullable.
    ///
    /// NOTE(review): `final_merge` builds its result from
    /// `ScalarValue::try_from(&self.data_type)` when the accumulator holds no
    /// list; confirm that value is consistent with a non-nullable result.
    fn nullable(&self) -> bool {
        false
    }

    /// Returns the initial accumulator layout (one dynamic list slot).
    fn accums_initial(&self) -> &[AccumInitialValue] {
        &self.accum_initial
    }

    /// Adds the size of the state currently stored in `acc` (if any) to the
    /// memory tracker.
    fn increase_acc_mem_used(&self, acc: &mut RefAccumStateRow) {
        if let Some(v) = acc.dyn_value(self.accum_state_val_addr) {
            self.add_mem_used(v.mem_size());
        }
    }

    /// Appends the value at `row_idx` of the first input column to the list
    /// held in `acc`, creating the list on first use. Null inputs are skipped.
    ///
    /// # Errors
    ///
    /// Fails if the value cannot be converted to a `ScalarValue` or if the
    /// stored state is not an `AggDynList`.
    fn partial_update(
        &self,
        acc: &mut RefAccumStateRow,
        values: &[ArrayRef],
        row_idx: usize,
    ) -> Result<()> {
        let input = &values[0];
        if !input.is_valid(row_idx) {
            return Ok(());
        }

        // Extract the scalar before touching the memory tracker, so a
        // conversion error cannot leave the tracker under-counted.
        let value = ScalarValue::try_from_array(input, row_idx)?;

        match acc.dyn_value_mut(self.accum_state_val_addr) {
            Some(dyn_list) => {
                let list = downcast_any!(dyn_list, mut AggDynList)?;
                // Replace the old size with the new one in the tracker.
                self.sub_mem_used(list.mem_size());
                list.append(&value, false);
                self.add_mem_used(list.mem_size());
            }
            w => {
                let mut new_list = AggDynList::default();
                new_list.append(&value, false);
                self.add_mem_used(new_list.mem_size());
                *w = Some(Box::new(new_list));
            }
        }
        Ok(())
    }

    /// Appends every non-null value of the first input column to the list
    /// held in `acc`, creating the list on first use.
    ///
    /// # Errors
    ///
    /// Fails if a value cannot be converted to a `ScalarValue` or if the
    /// stored state is not an `AggDynList`. Values appended before the
    /// failure stay in the list and remain accounted for in the tracker.
    fn partial_update_all(&self, acc: &mut RefAccumStateRow, values: &[ArrayRef]) -> Result<()> {
        let dyn_list = match acc.dyn_value_mut(self.accum_state_val_addr) {
            Some(dyn_list) => dyn_list,
            w => {
                let new_list = AggDynList::default();
                self.add_mem_used(new_list.mem_size());
                *w = Some(Box::new(new_list));
                w.as_mut().unwrap()
            }
        };
        let list = downcast_any!(dyn_list, mut AggDynList)?;
        let input = &values[0];

        let mem_before = list.mem_size();
        let appended = (0..input.len())
            .filter(|&i| input.is_valid(i))
            .try_for_each(|i| -> Result<()> {
                list.append(&ScalarValue::try_from_array(input, i)?, false);
                Ok(())
            });

        // Reconcile the tracker whether or not the loop failed, so it always
        // matches the real size of the list.
        self.sub_mem_used(mem_before);
        self.add_mem_used(list.mem_size());
        appended
    }

    /// Merges the list held in `merging_acc` into the one held in `acc`.
    ///
    /// If `acc` has no list yet, the list is moved over from `merging_acc`;
    /// if `merging_acc` has none, nothing happens.
    ///
    /// # Errors
    ///
    /// Fails if either stored state is not an `AggDynList`.
    fn partial_merge(
        &self,
        acc: &mut RefAccumStateRow,
        merging_acc: &mut RefAccumStateRow,
    ) -> Result<()> {
        match (
            acc.dyn_value_mut(self.accum_state_val_addr),
            merging_acc.dyn_value_mut(self.accum_state_val_addr),
        ) {
            (Some(w), Some(v)) => {
                let w = downcast_any!(w, mut AggDynList)?;
                let v = downcast_any!(v, mut AggDynList)?;
                // Both sizes are removed and only the merged size is added
                // back.
                self.sub_mem_used(w.mem_size());
                self.sub_mem_used(v.mem_size());
                w.merge(v);
                self.add_mem_used(w.mem_size());
            }
            // The list only changes owner here, so the tracker is left alone.
            (w_none, v @ Some(_)) => *w_none = std::mem::take(v),
            (None, _) => {}
            (_, None) => {}
        }
        Ok(())
    }

    /// Takes the list out of `acc` and converts it into a list scalar with
    /// `arg_type` elements, releasing its size from the memory tracker.
    ///
    /// When `acc` holds no list, the result is whatever
    /// `ScalarValue::try_from(&self.data_type)` produces.
    ///
    /// # Errors
    ///
    /// Fails if the stored state is not an `AggDynList`.
    fn final_merge(&self, acc: &mut RefAccumStateRow) -> Result<ScalarValue> {
        match std::mem::take(acc.dyn_value_mut(self.accum_state_val_addr)) {
            Some(w) => {
                let list = w
                    .as_any_boxed()
                    .downcast::<AggDynList>()
                    .or_else(|_| df_execution_err!("error downcasting to AggDynList"))?;
                self.sub_mem_used(list.mem_size());
                Ok(ScalarValue::List(ScalarValue::new_list(
                    &list
                        .into_values(self.arg_type.clone(), false)
                        .collect::<Vec<_>>(),
                    &self.arg_type,
                )))
            }
            None => ScalarValue::try_from(&self.data_type),
        }
    }

    /// Finalises every accumulator in `accs` and assembles the results into
    /// one array. An empty `accs` yields an empty array of the result type.
    ///
    /// # Errors
    ///
    /// Propagates any error from `final_merge` or from building the array.
    fn final_batch_merge(&self, accs: &mut [RefAccumStateRow]) -> Result<ArrayRef> {
        let values: Vec<ScalarValue> = accs
            .iter_mut()
            .map(|acc| self.final_merge(acc))
            .collect::<Result<_>>()?;
        // Handle the empty case explicitly instead of handing an empty
        // iterator to `iter_to_array`.
        if values.is_empty() {
            return Ok(new_empty_array(self.data_type()));
        }
        ScalarValue::iter_to_array(values)
    }
}