// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::{
    collections::{HashMap, HashSet},
    sync::Arc,
};
use arrow::record_batch::RecordBatch;
use datafusion::{
    common::{
        tree_node::{Transformed, TreeNode},
        Result,
    },
    execution::{SendableRecordBatchStream, TaskContext},
    physical_expr::{expressions::Column, utils::collect_columns, PhysicalExprRef},
    physical_plan::{stream::RecordBatchStreamAdapter, ExecutionPlan},
};
use futures::StreamExt;
use itertools::Itertools;
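
/// Extension trait for executing an [`ExecutionPlan`] while reading only a
/// subset of its output columns.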
pub trait ExecuteWithColumnPruning {
    /// Executes the given partition, returning a stream that contains only
    /// the columns selected by `projection` (indices into this plan's schema).
    fn execute_projected(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
        projection: &[usize],
    ) -> Result<SendableRecordBatchStream>;
}
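
// Blanket implementation for any plan: a projection that keeps every column
// in place falls through to a plain `execute`; anything else wraps the child
// stream and projects each batch down to the requested columns.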
impl ExecuteWithColumnPruning for dyn ExecutionPlan {
    fn execute_projected(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
        projection: &[usize],
    ) -> Result<SendableRecordBatchStream> {
        // Fast path: the projection keeps every column in its original
        // position, so the underlying stream can be returned unchanged. The
        // length check matters: a prefix projection such as [0, 1] over a
        // wider schema still needs the per-batch projection below.
        if projection.len() == self.schema().fields().len()
            && projection.iter().enumerate().all(|(i, &proj)| proj == i)
        {
            return self.execute(partition, context);
        }
        let projection = projection.to_vec();
        let schema = Arc::new(self.schema().project(&projection)?);
        let stream =
            self.execute(partition, context)?
                .map(move |batch_result: Result<RecordBatch>| {
                    batch_result.and_then(|batch| Ok(batch.project(&projection)?))
                });
        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
    }
}
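
/// Rewrites `exprs` to refer to a column-pruned input, returning the rewritten
/// expressions together with the sorted list of input column indices they
/// require.
///
/// Executing the child with the returned projection (for example via
/// [`ExecuteWithColumnPruning::execute_projected`]) yields batches whose
/// column order matches the rewritten expressions.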
pub fn prune_columns(exprs: &[PhysicalExprRef]) -> Result<(Vec<PhysicalExprRef>, Vec<usize>)> {
    // Collect every column referenced by the expressions, deduplicated and
    // sorted by its index in the original input schema.
    let required_columns: Vec<Column> = exprs
        .iter()
        .flat_map(|expr| collect_columns(expr).into_iter())
        .collect::<HashSet<Column>>()
        .into_iter()
        .sorted_unstable_by_key(|column| column.index())
        .collect();

    // Map each original column index to its position in the pruned schema.
    let required_columns_mapping: HashMap<usize, usize> = required_columns
        .iter()
        .enumerate()
        .map(|(to, from)| (from.index(), to))
        .collect();

    // Rewrite every column reference to use its pruned index; non-column
    // nodes are left untouched.
    let mapped_exprs: Vec<PhysicalExprRef> = exprs
        .iter()
        .map(|expr| {
            expr.clone()
                .transform_down(&|node: PhysicalExprRef| {
                    if let Some(column) = node.as_any().downcast_ref::<Column>() {
                        let mapped_idx = required_columns_mapping[&column.index()];
                        let mapped: PhysicalExprRef =
                            Arc::new(Column::new(column.name(), mapped_idx));
                        Ok(Transformed::yes(mapped))
                    } else {
                        Ok(Transformed::no(node))
                    }
                })
                .map(|r| r.data)
        })
        .collect::<Result<_>>()?;

    Ok((
        mapped_exprs,
        required_columns.into_iter().map(|c| c.index()).collect(),
    ))
}
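
// A minimal sketch of how the pieces compose, assuming a recent DataFusion
// API (`BinaryExpr`, `Operator`, and `collect_columns` as used above); the
// expression `a + c` and the column names are illustrative only.
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::logical_expr::Operator;
    use datafusion::physical_expr::expressions::BinaryExpr;

    #[test]
    fn prune_columns_remaps_column_indices() -> Result<()> {
        // `a + c` references columns 0 and 2 of some wider input schema.
        let expr: PhysicalExprRef = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Plus,
            Arc::new(Column::new("c", 2)),
        ));

        let (mapped_exprs, projection) = prune_columns(&[expr])?;

        // Only columns 0 and 2 of the input are required...
        assert_eq!(projection, vec![0, 2]);

        // ...and inside the rewritten expression, `c` now points at index 1,
        // matching its position in the pruned batches.
        let mapped_columns = collect_columns(&mapped_exprs[0]);
        assert!(mapped_columns.contains(&Column::new("a", 0)));
        assert!(mapped_columns.contains(&Column::new("c", 1)));
        Ok(())
    }
}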