| // Copyright 2022 The Blaze Authors |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| use std::{ |
| any::Any, |
| fmt::{Debug, Formatter}, |
| sync::Arc, |
| }; |
| |
| use arrow::{ |
| array::{RecordBatch, RecordBatchOptions}, |
| datatypes::SchemaRef, |
| }; |
| use blaze_jni_bridge::conf::{IntConf, UDAF_FALLBACK_NUM_UDAFS_TRIGGER_SORT_AGG}; |
| use datafusion::{ |
| common::{Result, Statistics}, |
| error::DataFusionError, |
| execution::context::TaskContext, |
| physical_expr::EquivalenceProperties, |
| physical_plan::{ |
| metrics::{ExecutionPlanMetricsSet, MetricsSet}, |
| DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, ExecutionPlanProperties, |
| PlanProperties, SendableRecordBatchStream, |
| }, |
| }; |
| use datafusion_ext_commons::{batch_size, downcast_any}; |
| use futures::StreamExt; |
| use once_cell::sync::OnceCell; |
| |
| use crate::{ |
| agg::{ |
| agg::IdxSelection, |
| agg_ctx::AggContext, |
| agg_table::{AggTable, OwnedKey}, |
| spark_udaf_wrapper::SparkUDAFWrapper, |
| AggExecMode, AggExpr, GroupingExpr, |
| }, |
| common::{execution_context::ExecutionContext, timer_helper::TimerHelper}, |
| expand_exec::ExpandExec, |
| memmgr::MemManager, |
| project_exec::ProjectExec, |
| sort_exec::create_default_ascending_sort_exec, |
| }; |
| |
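| /// Native aggregation operator. |
| /// |
| /// Holds the input plan and an [`AggContext`] describing the grouping and |
| /// aggregate expressions; `execute()` dispatches to hash, sorted, or |
| /// no-grouping aggregation via [`execute_agg`]. |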
| #[derive(Debug)] |
| pub struct AggExec { |
| input: Arc<dyn ExecutionPlan>, |
| agg_ctx: Arc<AggContext>, |
| metrics: ExecutionPlanMetricsSet, |
| props: OnceCell<PlanProperties>, |
| } |
| |
| impl AggExec { |
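| /// Builds the operator and its [`AggContext`]. Partial skipping is never |
| /// triggered when the input is an ExpandExec (or a ProjectExec over one). |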
| pub fn try_new( |
| exec_mode: AggExecMode, |
| groupings: Vec<GroupingExpr>, |
| aggs: Vec<AggExpr>, |
| supports_partial_skipping: bool, |
| input: Arc<dyn ExecutionPlan>, |
| ) -> Result<Self> { |
| // do not trigger partial skipping if the input is an ExpandExec (or a ProjectExec wrapping one) |
| let is_expand_agg = match &input { |
| e if downcast_any!(e, ExpandExec).is_ok() => true, |
| e if downcast_any!(e, ProjectExec).is_ok() => { |
| downcast_any!(&e.children()[0], ExpandExec).is_ok() |
| } |
| _ => false, |
| }; |
| |
| let agg_ctx = Arc::new(AggContext::try_new( |
| exec_mode, |
| input.schema(), |
| groupings, |
| aggs, |
| supports_partial_skipping, |
| is_expand_agg, |
| )?); |
| |
| Ok(Self { |
| input, |
| agg_ctx, |
| metrics: ExecutionPlanMetricsSet::new(), |
| props: OnceCell::new(), |
| }) |
| } |
| } |
| |
| impl ExecutionPlan for AggExec { |
| fn name(&self) -> &str { |
| "AggExec" |
| } |
| |
| fn as_any(&self) -> &dyn Any { |
| self |
| } |
| |
| fn schema(&self) -> SchemaRef { |
| self.agg_ctx.output_schema.clone() |
| } |
| |
| fn properties(&self) -> &PlanProperties { |
| self.props.get_or_init(|| { |
| PlanProperties::new( |
| EquivalenceProperties::new(self.schema()), |
| self.input.output_partitioning().clone(), |
| ExecutionMode::Bounded, |
| ) |
| }) |
| } |
| |
| fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { |
| vec![&self.input] |
| } |
| |
| fn with_new_children( |
| self: Arc<Self>, |
| children: Vec<Arc<dyn ExecutionPlan>>, |
| ) -> Result<Arc<dyn ExecutionPlan>> { |
| Ok(Arc::new(Self { |
| input: children[0].clone(), |
| agg_ctx: self.agg_ctx.clone(), |
| metrics: ExecutionPlanMetricsSet::new(), |
| props: OnceCell::new(), |
| })) |
| } |
| |
| fn execute( |
| &self, |
| partition: usize, |
| context: Arc<TaskContext>, |
| ) -> Result<SendableRecordBatchStream> { |
| let exec_ctx = ExecutionContext::new(context, partition, self.schema(), &self.metrics); |
| let output = execute_agg(self.input.clone(), exec_ctx.clone(), self.agg_ctx.clone())?; |
| Ok(exec_ctx.coalesce_with_default_batch_size(output)) |
| } |
| |
| fn metrics(&self) -> Option<MetricsSet> { |
| Some(self.metrics.clone_inner()) |
| } |
| |
| fn statistics(&self) -> Result<Statistics> { |
| todo!() |
| } |
| } |
| |
| impl DisplayAs for AggExec { |
| fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { |
| write!(f, "Agg {:?}", self.agg_ctx) |
| } |
| } |
| |
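| /// Selects an aggregation strategy for the plan: |
| /// - no grouping expressions: aggregate the whole input into a single record; |
| /// - `HashAgg`: hash aggregation, except that when the number of UDAFs reaches |
| /// `UDAF_FALLBACK_NUM_UDAFS_TRIGGER_SORT_AGG` the input is sorted on the |
| /// grouping keys first and sorted aggregation is used instead; |
| /// - `SortAgg`: sorted aggregation over input already sorted on the grouping keys. |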
| fn execute_agg( |
| input: Arc<dyn ExecutionPlan>, |
| exec_ctx: Arc<ExecutionContext>, |
| agg_ctx: Arc<AggContext>, |
| ) -> Result<SendableRecordBatchStream> { |
| if agg_ctx.groupings.is_empty() { |
| let input = exec_ctx.execute_with_input_stats(&input)?; |
| return execute_agg_no_grouping(input, exec_ctx, agg_ctx); |
| } |
| |
| Ok(match agg_ctx.exec_mode { |
| AggExecMode::HashAgg => { |
| let num_udafs_trigger_sort_agg = UDAF_FALLBACK_NUM_UDAFS_TRIGGER_SORT_AGG |
| .value() |
| .unwrap_or(1) as usize; |
| let num_udafs = agg_ctx |
| .aggs |
| .iter() |
| .filter(|agg| downcast_any!(agg.agg, SparkUDAFWrapper).is_ok()) |
| .count(); |
| if num_udafs >= num_udafs_trigger_sort_agg { |
| let input_sort_exec = create_default_ascending_sort_exec( |
| input, |
| &agg_ctx |
| .groupings |
| .iter() |
| .map(|g| g.expr.clone()) |
| .collect::<Vec<_>>(), |
| Some(exec_ctx.execution_plan_metrics().clone()), |
| false, // do not record output metric |
| ); |
| let input_sorted = exec_ctx.clone().execute(&input_sort_exec)?; |
| execute_agg_sorted(input_sorted, exec_ctx.clone(), agg_ctx)? |
| } else { |
| let input = exec_ctx.execute_with_input_stats(&input)?; |
| execute_agg_with_grouping_hash(input, exec_ctx, agg_ctx)? |
| } |
| } |
| AggExecMode::SortAgg => { |
| let input = exec_ctx.execute_with_input_stats(&input)?; |
| execute_agg_sorted(input, exec_ctx, agg_ctx)? |
| } |
| }) |
| } |
| |
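| /// Hash aggregation: feeds input batches into an in-mem [`AggTable`] registered |
| /// with the memory manager, handling the partial-skipping signals raised by the |
| /// table (flush the current table contents, then pass the remaining records |
| /// through without aggregation). |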
| fn execute_agg_with_grouping_hash( |
| input_stream: SendableRecordBatchStream, |
| exec_ctx: Arc<ExecutionContext>, |
| agg_ctx: Arc<AggContext>, |
| ) -> Result<SendableRecordBatchStream> { |
| // create the in-mem aggregation table and register it as a memory consumer |
| let tables = Arc::new(AggTable::try_new(agg_ctx.clone(), exec_ctx.clone())?); |
| MemManager::register_consumer(tables.clone(), true); |
| |
| // start processing input batches |
| let mut coalesced = exec_ctx.coalesce_with_default_batch_size(input_stream); |
| |
| Ok(exec_ctx |
| .clone() |
| .output_with_sender("Agg", |sender| async move { |
| let elapsed_compute = exec_ctx.baseline_metrics().elapsed_compute().clone(); |
| sender.exclude_time(&elapsed_compute); |
| let _timer = elapsed_compute.timer(); |
| |
| log::info!( |
| "start hash aggregating, supports_partial_skipping={}, num_groupings={}, num_partial={}, num_partial_merge={}, num_final={}", |
| agg_ctx.supports_partial_skipping, |
| agg_ctx.groupings.len(), |
| agg_ctx.aggs.iter().filter(|agg| agg.mode.is_partial()).count(), |
| agg_ctx.aggs.iter().filter(|agg| agg.mode.is_partial_merge()).count(), |
| agg_ctx.aggs.iter().filter(|agg| agg.mode.is_final()).count(), |
| ); |
| let mut partial_skipping_triggered = false; |
| |
| while let Some(batch) = elapsed_compute |
| .exclude_timer_async(coalesced.next()) |
| .await |
| .transpose()? |
| { |
| // output records without aggregation if partial skipping is triggered |
| if partial_skipping_triggered { |
| let exec_ctx = exec_ctx.clone(); |
| let sender = sender.clone(); |
| agg_ctx |
| .process_partial_skipped(batch, exec_ctx, sender) |
| .await?; |
| continue; |
| } |
| |
| // insert or update rows into in-mem table |
| match tables.process_input_batch(batch).await { |
| Ok(()) => {} |
| Err(DataFusionError::Execution(s)) if s == "AGG_TRIGGER_PARTIAL_SKIPPING" => { |
| // partial skipping triggered: flush the in-mem table and output the |
| // remaining records directly without aggregation |
| // note: the current batch has already been merged into the table |
| tables.output(sender.clone()).await?; |
| partial_skipping_triggered = true; |
| continue; |
| } |
| Err(DataFusionError::Execution(s)) if s == "AGG_SPILL_PARTIAL_SKIPPING" => { |
| // never spill when partial skipping is enabled: flush the in-mem table instead |
| // note: the current batch has already been merged into the table |
| tables.output(sender.clone()).await?; |
| continue; |
| } |
| Err(err) => return Err(err), |
| } |
| } |
| tables.output(sender.clone()).await?; |
| Ok(()) |
| })) |
| } |
| |
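| /// No-grouping aggregation: accumulates every input batch into a single |
| /// accumulator row and emits exactly one output record. |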
| fn execute_agg_no_grouping( |
| input_stream: SendableRecordBatchStream, |
| exec_ctx: Arc<ExecutionContext>, |
| agg_ctx: Arc<AggContext>, |
| ) -> Result<SendableRecordBatchStream> { |
| let mut acc_table = agg_ctx.create_acc_table(1); |
| |
| // start processing input batches |
| let mut coalesced = exec_ctx.coalesce_with_default_batch_size(input_stream); |
| |
| // output |
| // in no-grouping mode, we always output exactly one record, so precise |
| // accounting of the elapsed compute time is not critical. |
| Ok(exec_ctx |
| .clone() |
| .output_with_sender("Agg", move |sender| async move { |
| let elapsed_compute = exec_ctx.baseline_metrics().elapsed_compute().clone(); |
| sender.exclude_time(&elapsed_compute); |
| let _timer = elapsed_compute.timer(); |
| |
| while let Some(batch) = elapsed_compute |
| .exclude_timer_async(coalesced.next()) |
| .await |
| .transpose()? |
| { |
| agg_ctx.update_batch_to_acc_table( |
| &batch, |
| &mut acc_table, |
| IdxSelection::Single(0), |
| )?; |
| } |
| |
| let agg_columns = agg_ctx.build_agg_columns(&mut acc_table, IdxSelection::Single(0))?; |
| let batch = RecordBatch::try_new_with_options( |
| agg_ctx.output_schema.clone(), |
| agg_columns, |
| &RecordBatchOptions::new().with_row_count(Some(1)), |
| )?; |
| exec_ctx.baseline_metrics().record_output(1); |
| sender.send(batch).await; |
| log::info!("aggregate exec (no grouping) outputting one record"); |
| Ok(()) |
| })) |
| } |
| |
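| /// Sorted aggregation: assumes the input is sorted on the grouping keys, so rows |
| /// with equal keys arrive as contiguous runs. Keys and accumulators are collected |
| /// into a staging area that is flushed as an output batch once it reaches the |
| /// configured batch size. |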
| fn execute_agg_sorted( |
| input: SendableRecordBatchStream, |
| exec_ctx: Arc<ExecutionContext>, |
| agg_ctx: Arc<AggContext>, |
| ) -> Result<SendableRecordBatchStream> { |
| let batch_size = batch_size(); |
| |
| // start processing input batches |
| let mut coalesced = exec_ctx.coalesce_with_default_batch_size(input); |
| |
| Ok(exec_ctx |
| .clone() |
| .output_with_sender("Agg", move |sender| async move { |
| let elapsed_compute = exec_ctx.baseline_metrics().elapsed_compute().clone(); |
| sender.exclude_time(&elapsed_compute); |
| let _timer = elapsed_compute.timer(); |
| |
| let mut staging_keys: Vec<OwnedKey> = vec![]; |
| let mut staging_acc_table = agg_ctx.create_acc_table(0); |
| let mut acc_indices = vec![]; |
| |
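| // converts the staged keys and accumulators into an output batch, then resets staging |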
| macro_rules! flush_staging { |
| () => {{ |
| let batch = agg_ctx.convert_records_to_batch( |
| &staging_keys, |
| &mut staging_acc_table, |
| IdxSelection::Range(0, staging_keys.len()), |
| )?; |
| let num_rows = batch.num_rows(); |
| staging_keys.clear(); |
| staging_acc_table.resize(0); |
| exec_ctx.baseline_metrics().record_output(num_rows); |
| sender.send(batch).await; |
| }}; |
| } |
| |
| while let Some(batch) = elapsed_compute |
| .exclude_timer_async(coalesced.next()) |
| .await |
| .transpose()? |
| { |
| // compute grouping rows |
| let grouping_rows = agg_ctx.create_grouping_rows(&batch)?; |
| |
| // split the batch into runs of equal grouping keys and update the staging table |
| let mut batch_range_start = 0; |
| let mut batch_range_end = 0; |
| while batch_range_end < batch.num_rows() { |
| let grouping_row = &grouping_rows.row(batch_range_end); |
| let same_key = |
| matches!(staging_keys.last(), Some(k) if k == grouping_row.as_ref()); |
| if !same_key { |
| if staging_keys.len() >= batch_size { |
| agg_ctx.update_batch_slice_to_acc_table( |
| &batch, |
| batch_range_start, |
| batch_range_end, |
| &mut staging_acc_table, |
| IdxSelection::Indices(&acc_indices), |
| )?; |
| acc_indices.clear(); |
| batch_range_start = batch_range_end; |
| flush_staging!(); |
| } |
| staging_keys.push(OwnedKey::from(grouping_row.as_ref())); |
| } |
| acc_indices.push(staging_keys.len() - 1); |
| batch_range_end += 1; |
| } |
| |
| agg_ctx.update_batch_slice_to_acc_table( |
| &batch, |
| batch_range_start, |
| batch_range_end, |
| &mut staging_acc_table, |
| IdxSelection::Indices(&acc_indices), |
| )?; |
| acc_indices.clear(); |
| } |
| |
| if !staging_keys.is_empty() { |
| flush_staging!(); |
| } |
| Ok(()) |
| })) |
| } |
| |
| #[cfg(test)] |
| mod test { |
| use std::sync::Arc; |
| |
| use arrow::{ |
| array::Int32Array, |
| datatypes::{DataType, Field, Schema}, |
| record_batch::RecordBatch, |
| }; |
| use datafusion::{ |
| assert_batches_sorted_eq, |
| common::{Result, ScalarValue}, |
| physical_expr::{expressions as phys_expr, expressions::Column}, |
| physical_plan::{common, memory::MemoryExec, ExecutionPlan}, |
| prelude::SessionContext, |
| }; |
| |
| use crate::{ |
| agg::{ |
| agg::create_agg, |
| AggExecMode::HashAgg, |
| AggExpr, AggFunction, |
| AggMode::{Final, Partial}, |
| GroupingExpr, |
| }, |
| agg_exec::AggExec, |
| memmgr::MemManager, |
| }; |
| |
| fn build_table_i32( |
| a: (&str, &Vec<i32>), |
| b: (&str, &Vec<i32>), |
| c: (&str, &Vec<i32>), |
| d: (&str, &Vec<i32>), |
| e: (&str, &Vec<i32>), |
| f: (&str, &Vec<i32>), |
| g: (&str, &Vec<i32>), |
| h: (&str, &Vec<i32>), |
| ) -> RecordBatch { |
| let schema = Schema::new(vec![ |
| Field::new(a.0, DataType::Int32, false), |
| Field::new(b.0, DataType::Int32, false), |
| Field::new(c.0, DataType::Int32, false), |
| Field::new(d.0, DataType::Int32, false), |
| Field::new(e.0, DataType::Int32, false), |
| Field::new(f.0, DataType::Int32, false), |
| Field::new(g.0, DataType::Int32, false), |
| Field::new(h.0, DataType::Int32, false), |
| ]); |
| |
| RecordBatch::try_new( |
| Arc::new(schema), |
| vec![ |
| Arc::new(Int32Array::from(a.1.clone())), |
| Arc::new(Int32Array::from(b.1.clone())), |
| Arc::new(Int32Array::from(c.1.clone())), |
| Arc::new(Int32Array::from(d.1.clone())), |
| Arc::new(Int32Array::from(e.1.clone())), |
| Arc::new(Int32Array::from(f.1.clone())), |
| Arc::new(Int32Array::from(g.1.clone())), |
| Arc::new(Int32Array::from(h.1.clone())), |
| ], |
| ) |
| .unwrap() |
| } |
| |
| fn build_table( |
| a: (&str, &Vec<i32>), |
| b: (&str, &Vec<i32>), |
| c: (&str, &Vec<i32>), |
| d: (&str, &Vec<i32>), |
| e: (&str, &Vec<i32>), |
| f: (&str, &Vec<i32>), |
| g: (&str, &Vec<i32>), |
| h: (&str, &Vec<i32>), |
| ) -> Arc<dyn ExecutionPlan> { |
| let batch = build_table_i32(a, b, c, d, e, f, g, h); |
| let schema = batch.schema(); |
| Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) |
| } |
| |
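| // end-to-end partial + final hash aggregation over one in-memory batch, grouped |
| // by column c, covering sum/avg/max/min/count/collect_list/collect_set/first |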
| #[tokio::test(flavor = "multi_thread", worker_threads = 1)] |
| async fn test_agg() -> Result<()> { |
| MemManager::init(10000); |
| |
| let input = build_table( |
| ("a", &vec![2, 9, 3, 1, 0, 4, 6]), |
| ("b", &vec![1, 0, 0, 3, 5, 6, 3]), |
| ("c", &vec![7, 8, 7, 8, 9, 2, 5]), |
| ("d", &vec![-7, 86, 71, 83, 90, -2, 5]), |
| ("e", &vec![-7, 86, 71, 83, 90, -2, 5]), |
| ("f", &vec![0, 1, 2, 3, 4, 5, 6]), |
| ("g", &vec![6, 3, 6, 3, 1, 5, 4]), |
| ("h", &vec![6, 3, 6, 3, 1, 5, 4]), |
| ); |
| |
| let agg_expr_sum = create_agg( |
| AggFunction::Sum, |
| &[phys_expr::col("a", &input.schema())?], |
| &input.schema(), |
| DataType::Int64, |
| )?; |
| |
| let agg_expr_avg = create_agg( |
| AggFunction::Avg, |
| &[phys_expr::col("b", &input.schema())?], |
| &input.schema(), |
| DataType::Float64, |
| )?; |
| |
| let agg_expr_max = create_agg( |
| AggFunction::Max, |
| &[phys_expr::col("d", &input.schema())?], |
| &input.schema(), |
| DataType::Int32, |
| )?; |
| |
| let agg_expr_min = create_agg( |
| AggFunction::Min, |
| &[phys_expr::col("e", &input.schema())?], |
| &input.schema(), |
| DataType::Int32, |
| )?; |
| |
| let agg_expr_count = create_agg( |
| AggFunction::Count, |
| &[phys_expr::col("f", &input.schema())?], |
| &input.schema(), |
| DataType::Int64, |
| )?; |
| |
| let agg_expr_collectlist = create_agg( |
| AggFunction::CollectList, |
| &[phys_expr::col("g", &input.schema())?], |
| &input.schema(), |
| DataType::new_list(DataType::Int32, false), |
| )?; |
| |
| let agg_expr_collectset = create_agg( |
| AggFunction::CollectSet, |
| &[phys_expr::col("h", &input.schema())?], |
| &input.schema(), |
| DataType::new_list(DataType::Int32, false), |
| )?; |
| |
| let agg_expr_collectlist_nil = create_agg( |
| AggFunction::CollectList, |
| &[Arc::new(phys_expr::Literal::new(ScalarValue::Utf8(None)))], |
| &input.schema(), |
| DataType::new_list(DataType::Utf8, false), |
| )?; |
| |
| let agg_expr_collectset_nil = create_agg( |
| AggFunction::CollectSet, |
| &[Arc::new(phys_expr::Literal::new(ScalarValue::Utf8(None)))], |
| &input.schema(), |
| DataType::new_list(DataType::Utf8, false), |
| )?; |
| |
| let agg_expr_firstign = create_agg( |
| AggFunction::FirstIgnoresNull, |
| &[phys_expr::col("h", &input.schema())?], |
| &input.schema(), |
| DataType::Int32, |
| )?; |
| |
| let aggs_agg_expr = vec![ |
| AggExpr { |
| field_name: "agg_expr_sum".to_string(), |
| mode: Partial, |
| agg: agg_expr_sum, |
| }, |
| AggExpr { |
| field_name: "agg_expr_avg".to_string(), |
| mode: Partial, |
| agg: agg_expr_avg, |
| }, |
| AggExpr { |
| field_name: "agg_expr_max".to_string(), |
| mode: Partial, |
| agg: agg_expr_max, |
| }, |
| AggExpr { |
| field_name: "agg_expr_min".to_string(), |
| mode: Partial, |
| agg: agg_expr_min, |
| }, |
| AggExpr { |
| field_name: "agg_expr_count".to_string(), |
| mode: Partial, |
| agg: agg_expr_count, |
| }, |
| AggExpr { |
| field_name: "agg_expr_collectlist".to_string(), |
| mode: Partial, |
| agg: agg_expr_collectlist, |
| }, |
| AggExpr { |
| field_name: "agg_expr_collectset".to_string(), |
| mode: Partial, |
| agg: agg_expr_collectset, |
| }, |
| AggExpr { |
| field_name: "agg_expr_collectlist_nil".to_string(), |
| mode: Partial, |
| agg: agg_expr_collectlist_nil, |
| }, |
| AggExpr { |
| field_name: "agg_expr_collectset_nil".to_string(), |
| mode: Partial, |
| agg: agg_expr_collectset_nil, |
| }, |
| AggExpr { |
| field_name: "agg_agg_firstign".to_string(), |
| mode: Partial, |
| agg: agg_expr_firstign, |
| }, |
| ]; |
| |
| let agg_exec_partial = AggExec::try_new( |
| HashAgg, |
| vec![GroupingExpr { |
| field_name: "c".to_string(), |
| expr: Arc::new(Column::new("c", 2)), |
| }], |
| aggs_agg_expr.clone(), |
| false, |
| input, |
| )?; |
| |
| let agg_exec_final = AggExec::try_new( |
| HashAgg, |
| vec![GroupingExpr { |
| field_name: "c".to_string(), |
| expr: Arc::new(Column::new("c", 0)), |
| }], |
| aggs_agg_expr |
| .into_iter() |
| .map(|mut agg| { |
| agg.agg = agg |
| .agg |
| .with_new_exprs(vec![Arc::new(phys_expr::Literal::new( |
| ScalarValue::Null, |
| ))])?; |
| agg.mode = Final; |
| Ok(agg) |
| }) |
| .collect::<Result<_>>()?, |
| false, |
| Arc::new(agg_exec_partial), |
| )?; |
| |
| let session_ctx = SessionContext::new(); |
| let task_ctx = session_ctx.task_ctx(); |
| let output_final = agg_exec_final.execute(0, task_ctx)?; |
| let batches = common::collect(output_final).await?; |
| let expected = vec![ |
| "+---+--------------+--------------+--------------+--------------+----------------+----------------------+---------------------+--------------------------+-------------------------+------------------+", |
| "| c | agg_expr_sum | agg_expr_avg | agg_expr_max | agg_expr_min | agg_expr_count | agg_expr_collectlist | agg_expr_collectset | agg_expr_collectlist_nil | agg_expr_collectset_nil | agg_agg_firstign |", |
| "+---+--------------+--------------+--------------+--------------+----------------+----------------------+---------------------+--------------------------+-------------------------+------------------+", |
| "| 2 | 4 | 6.0 | -2 | -2 | 1 | [5] | [5] | [] | [] | 5 |", |
| "| 5 | 6 | 3.0 | 5 | 5 | 1 | [4] | [4] | [] | [] | 4 |", |
| "| 7 | 5 | 0.5 | 71 | -7 | 2 | [6, 6] | [6] | [] | [] | 6 |", |
| "| 8 | 10 | 1.5 | 86 | 83 | 2 | [3, 3] | [3] | [] | [] | 3 |", |
| "| 9 | 0 | 5.0 | 90 | 90 | 1 | [1] | [1] | [] | [] | 1 |", |
| "+---+--------------+--------------+--------------+--------------+----------------+----------------------+---------------------+--------------------------+-------------------------+------------------+", |
| ]; |
| assert_batches_sorted_eq!(expected, &batches); |
| Ok(()) |
| } |
| } |
| |
| #[cfg(test)] |
| mod fuzztest { |
| use std::{collections::HashMap, sync::Arc}; |
| |
| use arrow::{ |
| array::{Array, ArrayRef, AsArray, Float64Builder, Int64Builder}, |
| compute::concat_batches, |
| datatypes::{DataType, Float64Type, Int64Type}, |
| record_batch::RecordBatch, |
| }; |
| use datafusion::{ |
| common::Result, |
| physical_expr::expressions as phys_expr, |
| physical_plan::memory::MemoryExec, |
| prelude::{SessionConfig, SessionContext}, |
| }; |
| |
| use crate::{ |
| agg::{ |
| count::AggCount, |
| sum::AggSum, |
| AggExecMode::HashAgg, |
| AggExpr, |
| AggMode::{Final, Partial}, |
| GroupingExpr, |
| }, |
| agg_exec::AggExec, |
| memmgr::MemManager, |
| }; |
| |
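| // generates random key/value batches (with occasional nulls), aggregates sum and |
| // count under a small memory budget to force spilling, and verifies the results |
| // against maps built while generating the data |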
| #[tokio::test(flavor = "multi_thread", worker_threads = 1)] |
| async fn fuzztest() -> Result<()> { |
| MemManager::init(1000); // small memory config to trigger spill |
| let session_ctx = |
| SessionContext::new_with_config(SessionConfig::new().with_batch_size(10000)); |
| let task_ctx = session_ctx.task_ctx(); |
| |
| let mut verify_sum_map: HashMap<i64, f64> = HashMap::new(); |
| let mut verify_cnt_map: HashMap<i64, i64> = HashMap::new(); |
| let mut batches = vec![]; |
| for _batch_id in 0..100 { |
| let mut key_builder = Int64Builder::new(); |
| let mut val_builder = Float64Builder::new(); |
| |
| for _ in 0..10000 { |
| // will trigger spill |
| let key = (rand::random::<u32>() % 1_000_000) as i64; |
| let val = (rand::random::<u32>() % 1_000_000) as f64; |
| let test_null = rand::random::<u32>() % 1000 == 0; |
| |
| key_builder.append_value(key); |
| if test_null { |
| val_builder.append_null(); |
| continue; |
| } |
| val_builder.append_value(val); |
| verify_sum_map |
| .entry(key) |
| .and_modify(|v| *v += val) |
| .or_insert(val); |
| verify_cnt_map |
| .entry(key) |
| .and_modify(|v| *v += 1) |
| .or_insert(1); |
| } |
| let key_col: ArrayRef = Arc::new(key_builder.finish()); |
| let val_col: ArrayRef = Arc::new(val_builder.finish()); |
| let batch = RecordBatch::try_from_iter_with_nullable(vec![ |
| ("key", key_col, false), |
| ("val", val_col, true), |
| ])?; |
| batches.push(batch); |
| } |
| |
| let schema = batches[0].schema(); |
| let input = Arc::new(MemoryExec::try_new( |
| &[batches.clone()], |
| schema.clone(), |
| None, |
| )?); |
| let partial_agg = Arc::new(AggExec::try_new( |
| HashAgg, |
| vec![GroupingExpr { |
| field_name: format!("key"), |
| expr: phys_expr::col("key", &schema)?, |
| }], |
| vec![ |
| AggExpr { |
| field_name: "sum".to_string(), |
| mode: Partial, |
| agg: Arc::new(AggSum::try_new( |
| phys_expr::col("val", &schema)?, |
| DataType::Float64, |
| )?), |
| }, |
| AggExpr { |
| field_name: "cnt".to_string(), |
| mode: Partial, |
| agg: Arc::new(AggCount::try_new( |
| vec![phys_expr::col("val", &schema)?], |
| DataType::Int64, |
| )?), |
| }, |
| ], |
| true, |
| input, |
| )?); |
| let final_agg = Arc::new(AggExec::try_new( |
| HashAgg, |
| vec![GroupingExpr { |
| field_name: format!("key"), |
| expr: phys_expr::col("key", &schema)?, |
| }], |
| vec![ |
| AggExpr { |
| field_name: "sum".to_string(), |
| mode: Final, |
| agg: Arc::new(AggSum::try_new( |
| phys_expr::col("val", &schema)?, |
| DataType::Float64, |
| )?), |
| }, |
| AggExpr { |
| field_name: "cnt".to_string(), |
| mode: Final, |
| agg: Arc::new(AggCount::try_new( |
| vec![phys_expr::col("val", &schema)?], |
| DataType::Int64, |
| )?), |
| }, |
| ], |
| false, |
| partial_agg, |
| )?); |
| |
| let output = datafusion::physical_plan::collect(final_agg, task_ctx.clone()).await?; |
| let output_batch = concat_batches(&output[0].schema(), &output)?; |
| |
| let key_col = output_batch.column(0).as_primitive::<Int64Type>(); |
| let sum_col = output_batch.column(1).as_primitive::<Float64Type>(); |
| let cnt_col = output_batch.column(2).as_primitive::<Int64Type>(); |
| for i in 0..key_col.len() { |
| assert!(key_col.is_valid(i)); |
| assert!(cnt_col.is_valid(i)); |
| if sum_col.is_valid(i) { |
| let key = key_col.value(i); |
| let val = sum_col.value(i); |
| let cnt = cnt_col.value(i); |
| assert_eq!( |
| verify_sum_map[&key] as i64, val as i64, |
| "key={key}, sum not matched" |
| ); |
| assert_eq!(verify_cnt_map[&key], cnt, "key={key}, cnt not matched"); |
| } else { |
| let cnt = cnt_col.value(i); |
| assert_eq!(cnt, 0); |
| } |
| } |
| Ok(()) |
| } |
| } |