// blob: b123d0b897b876d9a47eba9bf53d31f36e011d9e [file] [log] [blame]
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{any::Any, fmt::Formatter, sync::Arc};
use arrow::{
array::{Array, ArrayRef},
datatypes::SchemaRef,
error::ArrowError,
record_batch::{RecordBatch, RecordBatchOptions},
};
use datafusion::{
common::{Result, Statistics},
execution::context::TaskContext,
physical_expr::PhysicalSortExpr,
physical_plan::{
metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet},
stream::RecordBatchStreamAdapter,
DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr,
SendableRecordBatchStream,
},
};
use datafusion_ext_commons::{cast::cast, streams::coalesce_stream::CoalesceInput};
use futures::{stream::once, StreamExt, TryFutureExt, TryStreamExt};
use crate::{
common::output::TaskOutputter,
window::{window_context::WindowContext, WindowExpr, WindowFunctionProcessor},
};
/// Physical plan node that evaluates window expressions over its input.
///
/// Each output batch carries the input columns followed by one column per
/// window expression, laid out according to the output schema held by the
/// shared [`WindowContext`].
#[derive(Debug)]
pub struct WindowExec {
/// Child plan that produces the batches to be processed.
input: Arc<dyn ExecutionPlan>,
/// Shared window state built from the input schema, the window expressions,
/// the partition spec and the order spec.
context: Arc<WindowContext>,
/// Metrics registry for this node; per-partition baseline metrics are
/// created from it on `execute`.
metrics: ExecutionPlanMetricsSet,
}
impl WindowExec {
    /// Creates a window node on top of `input`.
    ///
    /// The input schema together with `window_exprs`, `partition_spec` and
    /// `order_spec` is handed to [`WindowContext::try_new`]; the resulting
    /// context is shared (via `Arc`) with every stream this node executes.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`WindowContext::try_new`].
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        window_exprs: Vec<WindowExpr>,
        partition_spec: Vec<Arc<dyn PhysicalExpr>>,
        order_spec: Vec<PhysicalSortExpr>,
    ) -> Result<Self> {
        let input_schema = input.schema();
        let window_context =
            WindowContext::try_new(input_schema, window_exprs, partition_spec, order_spec)?;

        Ok(Self {
            input,
            context: Arc::new(window_context),
            metrics: ExecutionPlanMetricsSet::new(),
        })
    }
}
impl DisplayAs for WindowExec {
    /// Renders this node as the fixed label `Window`; the requested display
    /// format is ignored.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
        f.write_str("Window")
    }
}
impl ExecutionPlan for WindowExec {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
self.context.output_schema.clone()
}
fn output_partitioning(&self) -> Partitioning {
self.input.output_partitioning()
}
fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
self.input.output_ordering()
}
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
vec![self.input.clone()]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(Self::try_new(
children[0].clone(),
self.context.window_exprs.clone(),
self.context.partition_spec.clone(),
self.context.order_spec.clone(),
)?))
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
// at this moment only supports ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
let input = self.input.execute(partition, context.clone())?;
let coalesced = context.coalesce_with_default_batch_size(
input,
&BaselineMetrics::new(&self.metrics, partition),
)?;
let stream = execute_window(
coalesced,
context.clone(),
self.context.clone(),
BaselineMetrics::new(&self.metrics, partition),
)
.map_err(|e| ArrowError::ExternalError(Box::new(e)));
Ok(Box::pin(RecordBatchStreamAdapter::new(
self.schema(),
once(stream).try_flatten(),
)))
}
fn metrics(&self) -> Option<MetricsSet> {
Some(self.metrics.clone_inner())
}
fn statistics(&self) -> Result<Statistics> {
todo!()
}
}
/// Drives window evaluation for one partition.
///
/// A processor is created for every window expression in `context`. Each
/// input batch is then passed through all processors, the resulting window
/// columns are appended to the input columns, and any column whose type
/// differs from the output schema is cast to the declared type before the
/// batch is sent downstream.
///
/// # Errors
///
/// Returns an error if a processor cannot be created. Errors raised while
/// reading input, computing window columns, casting or assembling a batch
/// are returned from the closure handed to `output_with_sender`.
async fn execute_window(
    mut input: SendableRecordBatchStream,
    task_context: Arc<TaskContext>,
    context: Arc<WindowContext>,
    metrics: BaselineMetrics,
) -> Result<SendableRecordBatchStream> {
    // One processor per window expression; the first failure aborts setup.
    let mut processors: Vec<Box<dyn WindowFunctionProcessor>> =
        Vec::with_capacity(context.window_exprs.len());
    for window_expr in context.window_exprs.iter() {
        processors.push(window_expr.create_processor(&context)?);
    }

    // Process the input batch by batch, emitting one output batch for each.
    let output_schema = context.output_schema.clone();
    task_context.output_with_sender("Window", output_schema, |sender| async move {
        // An empty partition spec selects the processors' partition-less path.
        let has_partitions = !context.partition_spec.is_empty();

        while let Some(batch) = input.try_next().await? {
            let compute_time = metrics.elapsed_compute().clone();
            let mut timer = compute_time.timer();

            // Evaluate every window function against the current batch.
            let mut window_cols: Vec<ArrayRef> = Vec::with_capacity(processors.len());
            for processor in processors.iter_mut() {
                let column = if has_partitions {
                    processor.process_batch(&context, &batch)?
                } else {
                    processor.process_batch_without_partitions(&context, &batch)?
                };
                window_cols.push(column);
            }

            // Input columns first, then window columns, each brought to the
            // type declared by the output schema.
            let all_columns = batch.columns().iter().chain(&window_cols);
            let mut outputs: Vec<ArrayRef> = Vec::new();
            for (array, field) in all_columns.zip(context.output_schema.fields()) {
                let declared_type = field.data_type();
                if array.data_type() == declared_type {
                    outputs.push(array.clone());
                } else {
                    outputs.push(cast(&array, declared_type)?);
                }
            }

            // The row count is set explicitly from the input batch.
            let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
            let output_batch = RecordBatch::try_new_with_options(
                context.output_schema.clone(),
                outputs,
                &options,
            )?;

            metrics.record_output(output_batch.num_rows());
            // NOTE(review): the timer is handed to the sender, presumably so
            // time spent waiting on the consumer is not counted as compute
            // time. Confirm against `TaskOutputter`.
            sender.send(Ok(output_batch), Some(&mut timer)).await;
        }
        Ok(())
    })
}
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::{array::*, datatypes::*, record_batch::RecordBatch};
    use datafusion::{
        assert_batches_eq,
        physical_expr::{expressions::Column, PhysicalSortExpr},
        physical_plan::{memory::MemoryExec, ExecutionPlan},
        prelude::SessionContext,
    };

    use crate::{
        agg::AggFunction,
        window::{WindowExpr, WindowFunction, WindowRankType},
        window_exec::WindowExec,
    };

    /// Builds a single batch with three non-nullable `Int32` columns, each
    /// given as a `(name, values)` pair.
    fn build_table_i32(
        a: (&str, &Vec<i32>),
        b: (&str, &Vec<i32>),
        c: (&str, &Vec<i32>),
    ) -> RecordBatch {
        let schema = Schema::new(vec![
            Field::new(a.0, DataType::Int32, false),
            Field::new(b.0, DataType::Int32, false),
            Field::new(c.0, DataType::Int32, false),
        ]);
        RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(Int32Array::from(a.1.clone())),
                Arc::new(Int32Array::from(b.1.clone())),
                Arc::new(Int32Array::from(c.1.clone())),
            ],
        )
        .unwrap()
    }

    /// Wraps the batch from [`build_table_i32`] in a one-partition
    /// `MemoryExec`.
    fn build_table(
        a: (&str, &Vec<i32>),
        b: (&str, &Vec<i32>),
        c: (&str, &Vec<i32>),
    ) -> Arc<dyn ExecutionPlan> {
        let batch = build_table_i32(a, b, c);
        let schema = batch.schema();
        Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
    }

    /// Window expressions shared by both scenarios: row_number, rank,
    /// dense_rank and a sum over `b1`.
    fn rank_and_sum_exprs() -> Vec<WindowExpr> {
        vec![
            WindowExpr::new(
                WindowFunction::RankLike(WindowRankType::RowNumber),
                vec![],
                Arc::new(Field::new("b1_row_number", DataType::Int32, false)),
            ),
            WindowExpr::new(
                WindowFunction::RankLike(WindowRankType::Rank),
                vec![],
                Arc::new(Field::new("b1_rank", DataType::Int32, false)),
            ),
            WindowExpr::new(
                WindowFunction::RankLike(WindowRankType::DenseRank),
                vec![],
                Arc::new(Field::new("b1_dense_rank", DataType::Int32, false)),
            ),
            WindowExpr::new(
                WindowFunction::Agg(AggFunction::Sum),
                vec![Arc::new(Column::new("b1", 1))],
                Arc::new(Field::new("b1_sum", DataType::Int64, false)),
            ),
        ]
    }

    /// Order spec shared by both scenarios: `b1` with default sort options.
    fn order_by_b1() -> Vec<PhysicalSortExpr> {
        vec![PhysicalSortExpr {
            expr: Arc::new(Column::new("b1", 1)),
            options: Default::default(),
        }]
    }

    #[tokio::test]
    async fn test_window() -> Result<(), Box<dyn std::error::Error>> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();

        // Window partitioned by a1 and ordered by b1.
        let input = build_table(
            ("a1", &vec![1, 1, 1, 1, 2, 3, 3]),
            ("b1", &vec![1, 2, 2, 3, 4, 1, 1]),
            ("c1", &vec![0, 0, 0, 0, 0, 0, 0]),
        );
        let window = Arc::new(WindowExec::try_new(
            input,
            rank_and_sum_exprs(),
            vec![Arc::new(Column::new("a1", 0))],
            order_by_b1(),
        )?);
        let stream = window.execute(0, task_ctx.clone())?;
        let batches = datafusion::physical_plan::common::collect(stream).await?;
        // Cells are padded to the column width, as in Arrow's pretty-printed
        // tables, so every row is as wide as the border lines.
        let expected = vec![
            "+----+----+----+---------------+---------+---------------+--------+",
            "| a1 | b1 | c1 | b1_row_number | b1_rank | b1_dense_rank | b1_sum |",
            "+----+----+----+---------------+---------+---------------+--------+",
            "| 1  | 1  | 0  | 1             | 1       | 1             | 1      |",
            "| 1  | 2  | 0  | 2             | 2       | 2             | 3      |",
            "| 1  | 2  | 0  | 3             | 2       | 2             | 5      |",
            "| 1  | 3  | 0  | 4             | 4       | 3             | 8      |",
            "| 2  | 4  | 0  | 1             | 1       | 1             | 4      |",
            "| 3  | 1  | 0  | 1             | 1       | 1             | 1      |",
            "| 3  | 1  | 0  | 2             | 1       | 1             | 2      |",
            "+----+----+----+---------------+---------+---------------+--------+",
        ];
        assert_batches_eq!(expected, &batches);

        // Window without a PARTITION BY clause: the whole input is one group.
        let input = build_table(
            ("a1", &vec![1, 3, 3, 1, 1, 1, 2]),
            ("b1", &vec![1, 1, 1, 2, 2, 3, 4]),
            ("c1", &vec![0, 0, 0, 0, 0, 0, 0]),
        );
        let window = Arc::new(WindowExec::try_new(
            input,
            rank_and_sum_exprs(),
            vec![],
            order_by_b1(),
        )?);
        let stream = window.execute(0, task_ctx.clone())?;
        let batches = datafusion::physical_plan::common::collect(stream).await?;
        let expected = vec![
            "+----+----+----+---------------+---------+---------------+--------+",
            "| a1 | b1 | c1 | b1_row_number | b1_rank | b1_dense_rank | b1_sum |",
            "+----+----+----+---------------+---------+---------------+--------+",
            "| 1  | 1  | 0  | 1             | 1       | 1             | 1      |",
            "| 3  | 1  | 0  | 2             | 1       | 1             | 2      |",
            "| 3  | 1  | 0  | 3             | 1       | 1             | 3      |",
            "| 1  | 2  | 0  | 4             | 4       | 2             | 5      |",
            "| 1  | 2  | 0  | 5             | 4       | 2             | 7      |",
            "| 1  | 3  | 0  | 6             | 6       | 3             | 10     |",
            "| 2  | 4  | 0  | 7             | 7       | 4             | 14     |",
            "+----+----+----+---------------+---------+---------------+--------+",
        ];
        assert_batches_eq!(expected, &batches);
        Ok(())
    }
}