blob: ff178e115f0c7712ac4a75cd1aef143af6ea8872 [file] [log] [blame]
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
any::Any,
fmt::{Debug, Formatter},
sync::Arc,
};
use arrow::{array::*, datatypes::*};
use datafusion::{
common::{Result, ScalarValue},
physical_expr::PhysicalExpr,
};
use datafusion_ext_commons::{downcast_any, scalar_value::compacted_scalar_value_from_array};
use crate::{
agg::{
acc::{AccBytes, AccColumn, AccColumnRef, AccGenericColumn},
agg::IdxSelection,
Agg,
},
common::SliceAsRawBytes,
idx_for_zipped,
};
/// Aggregate that keeps, per accumulator slot, the first non-null value
/// produced by `child`.
pub struct AggFirstIgnoresNull {
    /// Input expression whose values are aggregated.
    child: Arc<dyn PhysicalExpr>,
    /// Data type used to create the accumulator column and the final output array.
    data_type: DataType,
}
impl AggFirstIgnoresNull {
    /// Builds the aggregate from its input expression and its output data type.
    ///
    /// No validation is performed here, so the returned `Result` is always `Ok`.
    pub fn try_new(child: Arc<dyn PhysicalExpr>, data_type: DataType) -> Result<Self> {
        let agg = Self { child, data_type };
        Ok(agg)
    }
}
impl Debug for AggFirstIgnoresNull {
    /// Renders the aggregate as `FirstIgnoresNull(<debug form of the child expression>)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self { child, .. } = self;
        f.write_fmt(format_args!("FirstIgnoresNull({:?})", child))
    }
}
// `Agg` implementation that fills each accumulator slot with the first
// non-null value it is offered; once a slot holds a value it is never
// overwritten, and null inputs never initialize a slot.
impl Agg for AggFirstIgnoresNull {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the single child expression whose values are aggregated.
    fn exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        vec![self.child.clone()]
    }

    /// Rebuilds this aggregate around `exprs[0]`, keeping the same data type.
    ///
    /// # Panics
    /// Panics if `exprs` is empty.
    fn with_new_exprs(&self, exprs: Vec<Arc<dyn PhysicalExpr>>) -> Result<Arc<dyn Agg>> {
        Ok(Arc::new(Self::try_new(
            exprs[0].clone(),
            self.data_type.clone(),
        )?))
    }

    /// Data type of the aggregated value.
    fn data_type(&self) -> &DataType {
        &self.data_type
    }

    /// Always reports the result as nullable.
    fn nullable(&self) -> bool {
        true
    }

    /// Creates a generic accumulator column for `self.data_type` with `num_rows` slots.
    fn create_acc_column(&self, num_rows: usize) -> AccColumnRef {
        Box::new(AccGenericColumn::new(&self.data_type, num_rows))
    }

    /// Feeds input rows into the accumulators.
    ///
    /// `acc_idx` and `partial_arg_idx` are walked in lockstep; for each pair the
    /// accumulator slot is set from `partial_args[0]` only when the slot is still
    /// unset and the input row is valid (non-null).
    ///
    /// # Panics
    /// Panics if `partial_args` is empty or if `accs` is not an `AccGenericColumn`.
    fn partial_update(
        &self,
        accs: &mut AccColumnRef,
        acc_idx: IdxSelection<'_>,
        partial_args: &[ArrayRef],
        partial_arg_idx: IdxSelection<'_>,
    ) -> Result<()> {
        let partial_arg = &partial_args[0];
        let accs = downcast_any!(accs, mut AccGenericColumn).unwrap();
        accs.ensure_size(acc_idx);

        // Snapshot taken before the update so that only the delta is reported below.
        let old_heap_mem_used = accs.items_heap_mem_used(acc_idx);

        // Shared body for Utf8/Binary inputs: downcast to `<$ty>Array` and copy the
        // bytes of the first valid row into every slot that has no value yet.
        macro_rules! handle_bytes {
            ($ty:ident) => {{
                type TArray = paste::paste! {[<$ty Array>]};
                let partial_arg = downcast_any!(partial_arg, TArray).unwrap();
                idx_for_zipped! {
                    ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
                        if accs.bytes_value(acc_idx).is_none() && partial_arg.is_valid(partial_arg_idx) {
                            accs.set_bytes_value(acc_idx, Some(AccBytes::from(partial_arg.value(partial_arg_idx).as_ref())));
                        }
                    }
                }
            }}
        }

        // Primitive arrays are handled by the first arm; Utf8/Binary go through
        // `handle_bytes!`; every other type is stored as a `ScalarValue`.
        downcast_primitive_array! {
            partial_arg => {
                idx_for_zipped! {
                    ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
                        if !accs.prim_valid(acc_idx) && partial_arg.is_valid(partial_arg_idx) {
                            accs.set_prim_valid(acc_idx, true);
                            accs.set_prim_value(acc_idx, partial_arg.value(partial_arg_idx));
                        }
                    }
                }
            }
            DataType::Utf8 => handle_bytes!(String),
            DataType::Binary => handle_bytes!(Binary),
            _other => {
                idx_for_zipped! {
                    ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
                        if accs.scalar_values()[acc_idx].is_null() && partial_arg.is_valid(partial_arg_idx) {
                            accs.scalar_values_mut()[acc_idx] = compacted_scalar_value_from_array(partial_arg, partial_arg_idx)?;
                        }
                    }
                }
            }
        }

        let new_heap_mem_used = accs.items_heap_mem_used(acc_idx);
        accs.add_heap_mem_used(new_heap_mem_used - old_heap_mem_used);
        Ok(())
    }

    /// Merges `merging_accs` into `accs`.
    ///
    /// `acc_idx` and `merging_acc_idx` are walked in lockstep; a slot in `accs`
    /// receives the value from `merging_accs` only when it is still unset and the
    /// merging slot holds a value. Bytes and scalar values are moved out of
    /// `merging_accs` (leaving the default / `ScalarValue::Null` behind) rather
    /// than cloned; primitive values are copied as raw bytes.
    ///
    /// # Panics
    /// Panics if either column is not an `AccGenericColumn`, or if the two
    /// columns are not the same variant.
    fn partial_merge(
        &self,
        accs: &mut AccColumnRef,
        acc_idx: IdxSelection<'_>,
        merging_accs: &mut AccColumnRef,
        merging_acc_idx: IdxSelection<'_>,
    ) -> Result<()> {
        let accs = downcast_any!(accs, mut AccGenericColumn).unwrap();
        accs.ensure_size(acc_idx);
        let merging_accs = downcast_any!(merging_accs, mut AccGenericColumn).unwrap();

        // Snapshot taken before the merge so that only the delta is reported below.
        let old_heap_mem_used = accs.items_heap_mem_used(acc_idx);

        // A plain reborrow is enough here: `accs` is not used inside the match and
        // the bindings created by the patterns end with the match, so `accs` is
        // usable again afterwards without any `unsafe`.
        match (&mut *accs, &mut *merging_accs) {
            (
                AccGenericColumn::Prim {
                    raw,
                    valids,
                    prim_size,
                    ..
                },
                AccGenericColumn::Prim {
                    raw: other_raw,
                    valids: other_valids,
                    ..
                },
            ) => {
                idx_for_zipped! {
                    ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
                        if !valids[acc_idx] && other_valids[merging_acc_idx] {
                            valids.set(acc_idx, true);
                            let acc_offset = *prim_size * acc_idx;
                            let merging_acc_offset = *prim_size * merging_acc_idx;
                            raw.as_raw_bytes_mut()[acc_offset..][..*prim_size]
                                .copy_from_slice(&other_raw.as_raw_bytes()[merging_acc_offset..][..*prim_size]);
                        }
                    }
                }
            }
            (
                AccGenericColumn::Bytes { items, .. },
                AccGenericColumn::Bytes {
                    items: other_items, ..
                },
            ) => {
                idx_for_zipped! {
                    ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
                        let item = &mut items[acc_idx];
                        let other_item = &mut other_items[merging_acc_idx];
                        if item.is_none() && other_item.is_some() {
                            *item = std::mem::take(other_item);
                        }
                    }
                }
            }
            (
                AccGenericColumn::Scalar { items, .. },
                AccGenericColumn::Scalar {
                    items: other_items, ..
                },
            ) => {
                idx_for_zipped! {
                    ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
                        let item = &mut items[acc_idx];
                        let other_item = &mut other_items[merging_acc_idx];
                        if item.is_null() && !other_item.is_null() {
                            *item = std::mem::replace(other_item, ScalarValue::Null);
                        }
                    }
                }
            }
            _ => unreachable!(),
        }

        let new_heap_mem_used = accs.items_heap_mem_used(acc_idx);
        accs.add_heap_mem_used(new_heap_mem_used - old_heap_mem_used);
        Ok(())
    }

    /// Produces the final output array for the selected accumulator slots by
    /// converting them to an array of `self.data_type`.
    ///
    /// # Panics
    /// Panics if `accs` is not an `AccGenericColumn`.
    fn final_merge(&self, accs: &mut AccColumnRef, acc_idx: IdxSelection<'_>) -> Result<ArrayRef> {
        let accs = downcast_any!(accs, mut AccGenericColumn).unwrap();
        accs.to_array(acc_idx, &self.data_type)
    }
}