blob: 4cdeb029943e1f44aaf58877b61528bd3ab67de5 [file]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::{
any::Any,
fmt::{Debug, Formatter},
sync::Arc,
};
use arrow::datatypes::SchemaRef;
use datafusion::{
common::Result,
execution::context::TaskContext,
physical_expr::EquivalenceProperties,
physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
SendableRecordBatchStream, Statistics,
execution_plan::{Boundedness, EmissionType},
metrics::ExecutionPlanMetricsSet,
},
};
use futures::StreamExt;
use once_cell::sync::OnceCell;
use crate::common::execution_context::ExecutionContext;
#[derive(Debug)]
pub struct LimitExec {
input: Arc<dyn ExecutionPlan>,
limit: usize,
offset: usize,
pub metrics: ExecutionPlanMetricsSet,
props: OnceCell<PlanProperties>,
}
impl LimitExec {
pub fn new(input: Arc<dyn ExecutionPlan>, limit: usize, offset: usize) -> Self {
Self {
input,
limit,
offset,
metrics: ExecutionPlanMetricsSet::new(),
props: OnceCell::new(),
}
}
}
impl DisplayAs for LimitExec {
fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
write!(f, "LimitExec(limit={},offset={})", self.limit, self.offset)
}
}
impl ExecutionPlan for LimitExec {
fn name(&self) -> &str {
"LimitExec"
}
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
self.input.schema()
}
fn properties(&self) -> &PlanProperties {
self.props.get_or_init(|| {
PlanProperties::new(
EquivalenceProperties::new(self.schema()),
self.input.output_partitioning().clone(),
EmissionType::Both,
Boundedness::Bounded,
)
})
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(Self::new(
children[0].clone(),
self.limit,
self.offset,
)))
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
let exec_ctx = ExecutionContext::new(context, partition, self.schema(), &self.metrics);
let input = exec_ctx.execute_with_input_stats(&self.input)?;
if self.offset == 0 {
execute_limit(input, self.limit, exec_ctx)
} else {
execute_limit_with_offset(input, self.limit, self.offset, exec_ctx)
}
}
fn statistics(&self) -> Result<Statistics> {
Statistics::with_fetch(
self.input.partition_statistics(None)?,
self.schema(),
Some(self.limit),
self.offset,
1,
)
}
}
fn execute_limit(
mut input: SendableRecordBatchStream,
limit: usize,
exec_ctx: Arc<ExecutionContext>,
) -> Result<SendableRecordBatchStream> {
Ok(exec_ctx
.clone()
.output_with_sender("Limit", move |sender| async move {
let mut remaining = limit;
while remaining > 0
&& let Some(mut batch) = input.next().await.transpose()?
{
if remaining < batch.num_rows() {
batch = batch.slice(0, remaining);
remaining = 0;
} else {
remaining -= batch.num_rows();
}
exec_ctx.baseline_metrics().record_output(batch.num_rows());
sender.send(batch).await;
}
Ok(())
}))
}
fn execute_limit_with_offset(
mut input: SendableRecordBatchStream,
limit: usize,
offset: usize,
exec_ctx: Arc<ExecutionContext>,
) -> Result<SendableRecordBatchStream> {
Ok(exec_ctx
.clone()
.output_with_sender("Limit", move |sender| async move {
let mut skip = offset;
let mut remaining = limit - skip;
while remaining > 0
&& let Some(mut batch) = input.next().await.transpose()?
{
if skip > 0 {
let rows = batch.num_rows();
if skip >= rows {
skip -= rows;
continue;
}
batch = batch.slice(skip, rows - skip);
skip = 0;
}
if remaining < batch.num_rows() {
batch = batch.slice(0, remaining);
remaining = 0;
} else {
remaining -= batch.num_rows();
}
exec_ctx.baseline_metrics().record_output(batch.num_rows());
sender.send(batch).await;
}
Ok(())
}))
}
#[cfg(test)]
mod test {
#![allow(deprecated)]
use std::sync::Arc;
use arrow::{
array::Int32Array,
datatypes::{DataType, Field, Schema},
record_batch::RecordBatch,
};
use auron_memmgr::MemManager;
use datafusion::{
assert_batches_eq,
common::{Result, stats::Precision},
physical_plan::{ExecutionPlan, common, test::TestMemoryExec},
prelude::SessionContext,
};
use crate::limit_exec::LimitExec;
fn build_table_i32(
a: (&str, &Vec<i32>),
b: (&str, &Vec<i32>),
c: (&str, &Vec<i32>),
) -> Result<RecordBatch> {
let schema = Schema::new(vec![
Field::new(a.0, DataType::Int32, false),
Field::new(b.0, DataType::Int32, false),
Field::new(c.0, DataType::Int32, false),
]);
let batch = RecordBatch::try_new(
Arc::new(schema),
vec![
Arc::new(Int32Array::from(a.1.clone())),
Arc::new(Int32Array::from(b.1.clone())),
Arc::new(Int32Array::from(c.1.clone())),
],
)?;
Ok(batch)
}
fn build_table(
a: (&str, &Vec<i32>),
b: (&str, &Vec<i32>),
c: (&str, &Vec<i32>),
) -> Result<Arc<dyn ExecutionPlan>> {
let batch = build_table_i32(a, b, c)?;
let schema = batch.schema();
Ok(Arc::new(TestMemoryExec::try_new(
&[vec![batch]],
schema,
None,
)?))
}
#[tokio::test]
async fn test_limit_exec() -> Result<()> {
MemManager::init(10000);
let input = build_table(
("a", &vec![9, 8, 7, 6, 5, 4, 3, 2, 1, 0]),
("b", &vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
("c", &vec![5, 6, 7, 8, 9, 0, 1, 2, 3, 4]),
)?;
let limit_exec = LimitExec::new(input, 2, 0);
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let output = limit_exec.execute(0, task_ctx)?;
let batches = common::collect(output).await?;
let row_count = limit_exec.statistics()?.num_rows;
let expected = vec![
"+---+---+---+",
"| a | b | c |",
"+---+---+---+",
"| 9 | 0 | 5 |",
"| 8 | 1 | 6 |",
"+---+---+---+",
];
assert_batches_eq!(expected, &batches);
assert_eq!(row_count, Precision::Exact(2));
Ok(())
}
#[tokio::test]
async fn test_limit_with_offset() -> Result<()> {
let input = build_table(
("a", &vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
("b", &vec![9, 8, 7, 6, 5, 4, 3, 2, 1, 0]),
("c", &vec![5, 6, 7, 8, 9, 0, 1, 2, 3, 4]),
)?;
let limit_exec = LimitExec::new(input, 7, 5);
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let output = limit_exec.execute(0, task_ctx)?;
let batches = common::collect(output).await?;
let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum();
let expected = vec![
"+---+---+---+",
"| a | b | c |",
"+---+---+---+",
"| 5 | 4 | 0 |",
"| 6 | 3 | 1 |",
"+---+---+---+",
];
assert_batches_eq!(expected, &batches);
assert_eq!(row_count, 2);
Ok(())
}
}