blob: b52e77f00eb6b041d599f0d6a29e9842366134c8 [file] [log] [blame]
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{any::Any, fmt::Formatter, sync::Arc};
use arrow::datatypes::SchemaRef;
use datafusion::{
common::{JoinType, Result, Statistics},
execution::{SendableRecordBatchStream, TaskContext},
physical_expr::{Partitioning, PhysicalSortExpr},
physical_plan::{
joins::{
utils::{build_join_schema, check_join_is_valid, JoinFilter},
NestedLoopJoinExec,
},
memory::MemoryExec,
metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet},
stream::RecordBatchStreamAdapter,
DisplayAs, DisplayFormatType, ExecutionPlan,
},
};
use datafusion_ext_commons::batch_size;
use futures::{stream::once, StreamExt, TryStreamExt};
use parking_lot::Mutex;
use crate::broadcast_join_exec::RecordBatchStreamsWrapperExec;
/// Physical plan node for a nested-loop join in which one input is fully
/// buffered in memory and the other input is streamed.
///
/// Which input is buffered depends on the join type (see
/// `left_is_build_side`); the join itself is carried out by DataFusion's
/// `NestedLoopJoinExec` inside `execute_join`.
#[derive(Debug)]
pub struct BroadcastNestedLoopJoinExec {
/// Left input plan.
left: Arc<dyn ExecutionPlan>,
/// Right input plan.
right: Arc<dyn ExecutionPlan>,
/// Kind of join to perform.
join_type: JoinType,
/// Optional join filter, passed through to `NestedLoopJoinExec`.
filter: Option<JoinFilter>,
/// Output schema, derived in `try_new` with `build_join_schema`.
schema: SchemaRef,
/// Metrics registry for this node; output row counts are recorded per partition.
metrics: ExecutionPlanMetricsSet,
}
impl BroadcastNestedLoopJoinExec {
    /// Builds a broadcast nested-loop join over `left` and `right`.
    ///
    /// The output schema is computed once here from both input schemas and
    /// `join_type`, and a fresh metrics set is attached to the new node.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `check_join_is_valid` when the two input
    /// schemas cannot be joined.
    pub fn try_new(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        join_type: JoinType,
        filter: Option<JoinFilter>,
    ) -> Result<Self> {
        let (lhs_schema, rhs_schema) = (left.schema(), right.schema());

        // A nested-loop join has no equi-join keys, hence the empty `on` list.
        check_join_is_valid(&lhs_schema, &rhs_schema, &[])?;

        // Only the joined schema is needed; the column index mapping is dropped.
        let (joined_schema, _) = build_join_schema(&lhs_schema, &rhs_schema, &join_type);

        Ok(Self {
            schema: Arc::new(joined_schema),
            metrics: ExecutionPlanMetricsSet::new(),
            join_type,
            filter,
            left,
            right,
        })
    }
}
impl DisplayAs for BroadcastNestedLoopJoinExec {
    /// Writes the operator name; the same text is produced for every
    /// display format.
    fn fmt_as(&self, _format_type: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
        f.write_str("BroadcastNestedLoopJoin")
    }
}
impl ExecutionPlan for BroadcastNestedLoopJoinExec {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
fn output_partitioning(&self) -> Partitioning {
if left_is_build_side(self.join_type) {
self.right.output_partitioning()
} else {
self.left.output_partitioning()
}
}
fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
None
}
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
vec![self.left.clone(), self.right.clone()]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(Self::try_new(
children[0].clone(),
children[1].clone(),
self.join_type,
self.filter.clone(),
)?))
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
let joined = Box::pin(RecordBatchStreamAdapter::new(
self.schema(),
once(execute_join(
partition,
context,
self.left.clone(),
self.right.clone(),
self.join_type,
self.filter.clone(),
self.metrics.clone(),
))
.try_flatten(),
));
Ok(joined)
}
fn metrics(&self) -> Option<MetricsSet> {
Some(self.metrics.clone_inner())
}
fn statistics(&self) -> Result<Statistics> {
todo!()
}
}
/// Runs the nested-loop join for one partition and returns the joined stream.
///
/// Steps:
/// 1. The build ("inner") side, chosen by `left_is_build_side`, is executed
///    and collected completely into memory, then exposed via a `MemoryExec`.
/// 2. The other ("outer") side is executed lazily; each of its batches is
///    sliced into smaller chunks so that the estimated size of one
///    chunk-times-inner-batch product stays near the configured targets.
/// 3. Both sides are handed to DataFusion's `NestedLoopJoinExec`, keeping the
///    original left/right orientation.
/// 4. Every successfully produced batch has its row count recorded in the
///    baseline metrics of `metrics` for `partition`.
///
/// # Errors
///
/// Propagates errors from executing either input, from collecting the inner
/// side, and from constructing or executing the wrapped plans. Errors raised
/// while streaming the outer side are forwarded through the returned stream.
async fn execute_join(
    partition: usize,
    context: Arc<TaskContext>,
    left: Arc<dyn ExecutionPlan>,
    right: Arc<dyn ExecutionPlan>,
    join_type: JoinType,
    filter: Option<JoinFilter>,
    metrics: ExecutionPlanMetricsSet,
) -> Result<SendableRecordBatchStream> {
    let build_left = left_is_build_side(join_type);
    let (inner_plan, outer_plan) = if build_left {
        (left, right)
    } else {
        (right, left)
    };

    // inner side: buffer every batch before the outer side is started
    let inner_stream = inner_plan.execute(partition, context.clone())?;
    let inner_schema = inner_stream.schema();
    let inner_batches: Vec<_> = inner_stream.try_collect().await?;

    let inner_batch_max_num_rows = inner_batches
        .iter()
        .map(|batch| batch.num_rows())
        .max()
        .unwrap_or(0);
    let inner_batch_max_mem_size = inner_batches
        .iter()
        .map(|batch| batch.get_array_memory_size())
        .max()
        .unwrap_or(0);
    let target_output_num_rows = batch_size();
    let target_output_mem_size = 1 << 26; // 64MB

    let inner_exec: Arc<dyn ExecutionPlan> =
        Arc::new(MemoryExec::try_new(&[inner_batches], inner_schema, None)?);

    // outer side: streamed, with each batch split into bounded chunks
    let outer_schema = outer_plan.schema();
    let outer_partitioning = outer_plan.output_partitioning();
    let outer_stream = outer_plan.execute(partition, context.clone())?;

    let chunked_outer_stream = Box::pin(RecordBatchStreamAdapter::new(
        outer_schema.clone(),
        outer_stream.flat_map(move |batch_result| {
            let chunks = match batch_result {
                Ok(batch) => {
                    let num_rows = batch.num_rows();
                    let chunk_len = outer_chunk_len(
                        num_rows,
                        batch.get_array_memory_size(),
                        inner_batch_max_num_rows,
                        inner_batch_max_mem_size,
                        target_output_num_rows,
                        target_output_mem_size,
                    );
                    // An empty batch yields no chunks at all.
                    (0..num_rows)
                        .step_by(chunk_len)
                        .map(|beg| Ok(batch.slice(beg, chunk_len.min(num_rows - beg))))
                        .collect::<Vec<_>>()
                }
                Err(err) => vec![Err(err)],
            };
            futures::stream::iter(chunks)
        }),
    ));
    let outer_exec: Arc<dyn ExecutionPlan> = Arc::new(RecordBatchStreamsWrapperExec {
        schema: outer_schema,
        stream: Mutex::new(Some(chunked_outer_stream)),
        output_partitioning: outer_partitioning,
    });

    // join with datafusion's builtin NestedLoopJoinExec, restoring the
    // original left/right positions of the two inputs
    let (nlj_left, nlj_right) = if build_left {
        (inner_exec, outer_exec)
    } else {
        (outer_exec, inner_exec)
    };
    let nlj = NestedLoopJoinExec::try_new(nlj_left, nlj_right, filter, &join_type)?;
    let joined = nlj.execute(partition, context)?;

    let baseline_metrics = BaselineMetrics::new(&metrics, partition);
    let output_stream = Box::pin(RecordBatchStreamAdapter::new(
        joined.schema(),
        joined.inspect_ok(move |batch| baseline_metrics.record_output(batch.num_rows())),
    ));
    Ok(output_stream)
}

/// Computes how many rows of an outer batch go into one chunk.
///
/// The size of joining the whole outer batch with the largest inner batch is
/// estimated both in rows and in bytes; the batch is then divided into as
/// many chunks as the less demanding of the two targets requires. The result
/// is always at least 1 so it can be used as a `step_by` stride.
///
/// Saturating arithmetic keeps the estimates from overflowing on very large
/// inputs, and a zero target is treated as 1 to avoid a division by zero.
fn outer_chunk_len(
    batch_num_rows: usize,
    batch_mem_size: usize,
    inner_max_num_rows: usize,
    inner_max_mem_size: usize,
    target_num_rows: usize,
    target_mem_size: usize,
) -> usize {
    let output_num_rows = batch_num_rows.saturating_mul(inner_max_num_rows);
    let output_mem_size = batch_num_rows
        .saturating_mul(inner_max_mem_size)
        .saturating_add(batch_mem_size.saturating_mul(inner_max_num_rows));

    let chunks_for_rows = (output_num_rows / target_num_rows.max(1)).max(1);
    let chunks_for_mem = (output_mem_size / target_mem_size.max(1)).max(1);

    // NOTE(review): taking the minimum means a batch is only split as far as
    // the looser target demands, so the memory target alone is not enforced
    // when the row-based count is smaller. Confirm this is intended before
    // switching to `max`, which could shrink chunks down to single rows.
    let chunk_count = chunks_for_rows.min(chunks_for_mem);

    (batch_num_rows / chunk_count).max(1)
}
/// Tells whether the left input is the build (fully buffered) side for
/// `join_type`.
///
/// This is the case for `Right`, `RightSemi`, `RightAnti` and `Full` joins;
/// for every other join type the right input is the build side.
fn left_is_build_side(join_type: JoinType) -> bool {
    const LEFT_BUILD_JOIN_TYPES: [JoinType; 4] = [
        JoinType::Right,
        JoinType::RightSemi,
        JoinType::RightAnti,
        JoinType::Full,
    ];
    LEFT_BUILD_JOIN_TYPES.contains(&join_type)
}