blob: cc7b5d4bb358946ac3cc83146a7c9cb58af9c3ba [file]
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
any::Any,
fmt::{Debug, Formatter},
future::Future,
pin::Pin,
sync::{Arc, Weak},
time::{Duration, Instant},
};
use arrow::{
array::RecordBatch,
compute::SortOptions,
datatypes::{DataType, SchemaRef},
};
use async_trait::async_trait;
use datafusion::{
common::{JoinSide, Result, Statistics},
execution::context::TaskContext,
physical_expr::{PhysicalExprRef, PhysicalSortExpr},
physical_plan::{
joins::utils::JoinOn,
metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, Time},
stream::RecordBatchStreamAdapter,
DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream,
},
};
use datafusion_ext_commons::{
batch_size, df_execution_err, streams::coalesce_stream::CoalesceInput,
};
use futures::{StreamExt, TryStreamExt};
use hashbrown::HashMap;
use once_cell::sync::OnceCell;
use parking_lot::Mutex;
use crate::{
common::{
batch_statisitcs::{stat_input, InputBatchStatistics},
output::{TaskOutputter, WrappedRecordBatchSender},
},
joins::{
bhj::{
full_join::{
LProbedFullOuterJoiner, LProbedInnerJoiner, LProbedLeftJoiner, LProbedRightJoiner,
RProbedFullOuterJoiner, RProbedInnerJoiner, RProbedLeftJoiner, RProbedRightJoiner,
},
semi_join::{
LProbedExistenceJoiner, LProbedLeftAntiJoiner, LProbedLeftSemiJoiner,
LProbedRightAntiJoiner, LProbedRightSemiJoiner, RProbedExistenceJoiner,
RProbedLeftAntiJoiner, RProbedLeftSemiJoiner, RProbedRightAntiJoiner,
RProbedRightSemiJoiner,
},
},
join_hash_map::JoinHashMap,
join_utils::{JoinType, JoinType::*},
JoinParams,
},
};
/// Physical plan node for a broadcast hash join.
///
/// The child named by `broadcast_side` is collected into a join hash map and
/// the other child is streamed and probed against it.
#[derive(Debug)]
pub struct BroadcastJoinExec {
    left: Arc<dyn ExecutionPlan>,
    right: Arc<dyn ExecutionPlan>,
    /// pairs of (left key, right key) expressions to join on
    on: JoinOn,
    join_type: JoinType,
    /// which child is built into the hash map (the other child is probed)
    broadcast_side: JoinSide,
    /// output schema of the join
    schema: SchemaRef,
    /// when set, the built hash map is looked up in / stored into a
    /// process-wide cache keyed by this id
    cached_build_hash_map_id: Option<String>,
    metrics: ExecutionPlanMetricsSet,
}
impl BroadcastJoinExec {
    /// Creates a broadcast join plan node.
    ///
    /// `broadcast_side` names the child that is collected into a hash map; the
    /// other child is streamed and probed against it. When
    /// `cached_build_hash_map_id` is `Some`, the built hash map is looked up in
    /// (or stored into) a process-wide cache keyed by that id.
    pub fn try_new(
        schema: SchemaRef,
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        on: JoinOn,
        join_type: JoinType,
        broadcast_side: JoinSide,
        cached_build_hash_map_id: Option<String>,
    ) -> Result<Self> {
        Ok(Self {
            left,
            right,
            on,
            join_type,
            broadcast_side,
            schema,
            cached_build_hash_map_id,
            metrics: ExecutionPlanMetricsSet::new(),
        })
    }

    /// Builds the [`JoinParams`] handed to the joiners of this plan.
    ///
    /// # Errors
    /// Returns an error if a key expression's data type cannot be resolved
    /// against its child's schema, or if the left and right key of any pair
    /// have different data types.
    fn create_join_params(&self) -> Result<JoinParams> {
        let left_schema = self.left.schema();
        let right_schema = self.right.schema();
        let (left_keys, right_keys): (Vec<PhysicalExprRef>, Vec<PhysicalExprRef>) =
            self.on.iter().cloned().unzip();
        let key_data_types: Vec<DataType> = self
            .on
            .iter()
            .map(|(left_key, right_key)| {
                Ok({
                    let left_dt = left_key.data_type(&left_schema)?;
                    let right_dt = right_key.data_type(&right_schema)?;
                    if left_dt != right_dt {
                        df_execution_err!(
                            "join key data type differs {left_dt:?} <-> {right_dt:?}"
                        )?;
                    }
                    left_dt
                })
            })
            .collect::<Result<_>>()?;

        // use smaller batch size and coalesce batches at the end, to avoid buffer
        // overflowing. the divisor is floor(log10(batch_size)), clamped to at
        // least 1: batch sizes below 10 would otherwise divide by zero, and a
        // zero batch size would panic inside `ilog10`.
        let batch_size = batch_size();
        let divisor = batch_size.checked_ilog10().unwrap_or(0).max(1) as usize;
        let sub_batch_size = (batch_size / divisor).max(1);

        Ok(JoinParams {
            join_type: self.join_type,
            left_schema,
            right_schema,
            output_schema: self.schema(),
            left_keys,
            right_keys,
            batch_size: sub_batch_size,
            sort_options: vec![SortOptions::default(); self.on.len()],
            key_data_types,
        })
    }
}
impl ExecutionPlan for BroadcastJoinExec {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
fn output_partitioning(&self) -> Partitioning {
match self.broadcast_side {
JoinSide::Left => self.right.output_partitioning(),
JoinSide::Right => self.left.output_partitioning(),
}
}
fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
None
}
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
vec![self.left.clone(), self.right.clone()]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(Self::try_new(
self.schema.clone(),
children[0].clone(),
children[1].clone(),
self.on.iter().cloned().collect(),
self.join_type,
self.broadcast_side,
None,
)?))
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
let metrics = Arc::new(BaselineMetrics::new(&self.metrics, partition));
let join_params = self.create_join_params()?;
let left = self.left.execute(partition, context.clone())?;
let right = self.right.execute(partition, context.clone())?;
let broadcast_side = self.broadcast_side;
let cached_build_hash_map_id = self.cached_build_hash_map_id.clone();
let output_schema = self.schema();
// stat probed side
let input_batch_stat =
InputBatchStatistics::from_metrics_set_and_blaze_conf(&self.metrics, partition)?;
let (left, right) = match broadcast_side {
JoinSide::Left => (left, stat_input(input_batch_stat, right)?),
JoinSide::Right => (stat_input(input_batch_stat, left)?, right),
};
let metrics_cloned = metrics.clone();
let context_cloned = context.clone();
let output_stream = Box::pin(RecordBatchStreamAdapter::new(
output_schema.clone(),
futures::stream::once(async move {
context_cloned.output_with_sender("BroadcastJoin", output_schema, move |sender| {
execute_join(
left,
right,
join_params,
broadcast_side,
cached_build_hash_map_id,
metrics_cloned,
sender,
)
})
})
.try_flatten(),
));
Ok(context.coalesce_with_default_batch_size(output_stream, &metrics)?)
}
fn metrics(&self) -> Option<MetricsSet> {
Some(self.metrics.clone_inner())
}
fn statistics(&self) -> Result<Statistics> {
unimplemented!()
}
}
impl DisplayAs for BroadcastJoinExec {
    /// Writes the fixed label `BroadcastJoin`; the requested display format
    /// does not affect the output.
    fn fmt_as(&self, _format_type: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
        f.write_str("BroadcastJoin")
    }
}
/// Runs a broadcast hash join for a single partition and streams the result
/// into `sender`.
///
/// The broadcast side (`broadcast_side`) is collected into a [`JoinHashMap`]
/// while the first batch of the probed side is fetched concurrently. Every
/// probed batch, including the prefetched one, is then handed to a [`Joiner`].
/// The joiner is chosen by which side is probed and by
/// `join_params.join_type`, and `finish` is called on it at the end.
///
/// Time spent polling inputs and sending output is excluded from the
/// `elapsed_compute` metric recorded in `metrics`.
///
/// # Errors
/// Propagates errors from either input stream, from building the hash map,
/// and from the joiner.
async fn execute_join(
    left: SendableRecordBatchStream,
    right: SendableRecordBatchStream,
    join_params: JoinParams,
    broadcast_side: JoinSide,
    cached_build_hash_map_id: Option<String>,
    metrics: Arc<BaselineMetrics>,
    sender: Arc<WrappedRecordBatchSender>,
) -> Result<()> {
    let start_time = Instant::now();
    let poll_time = Time::new();
    let (mut probed, mut joiner): (_, Pin<Box<dyn Joiner + Send>>) = match broadcast_side {
        JoinSide::Left => {
            let right_schema = right.schema();
            let mut right_peeked = Box::pin(right.peekable());
            let (_, lmap_result) = futures::join!(
                // fetch two sides asynchronously
                async {
                    let timer = poll_time.timer();
                    right_peeked.as_mut().peek().await;
                    drop(timer);
                },
                collect_join_hash_map(
                    cached_build_hash_map_id,
                    left,
                    &join_params.left_keys,
                    poll_time.clone(),
                ),
            );
            let lmap = lmap_result?;
            (
                Box::pin(RecordBatchStreamAdapter::new(right_schema, right_peeked)),
                match join_params.join_type {
                    Inner => Box::pin(RProbedInnerJoiner::new(join_params, lmap, sender)),
                    Left => Box::pin(RProbedLeftJoiner::new(join_params, lmap, sender)),
                    Right => Box::pin(RProbedRightJoiner::new(join_params, lmap, sender)),
                    Full => Box::pin(RProbedFullOuterJoiner::new(join_params, lmap, sender)),
                    LeftSemi => Box::pin(RProbedLeftSemiJoiner::new(join_params, lmap, sender)),
                    LeftAnti => Box::pin(RProbedLeftAntiJoiner::new(join_params, lmap, sender)),
                    RightSemi => Box::pin(RProbedRightSemiJoiner::new(join_params, lmap, sender)),
                    RightAnti => Box::pin(RProbedRightAntiJoiner::new(join_params, lmap, sender)),
                    Existence => Box::pin(RProbedExistenceJoiner::new(join_params, lmap, sender)),
                },
            )
        }
        JoinSide::Right => {
            let left_schema = left.schema();
            let mut left_peeked = Box::pin(left.peekable());
            let (_, rmap_result) = futures::join!(
                // fetch two sides asynchronously
                async {
                    let timer = poll_time.timer();
                    left_peeked.as_mut().peek().await;
                    drop(timer);
                },
                collect_join_hash_map(
                    cached_build_hash_map_id,
                    right,
                    &join_params.right_keys,
                    poll_time.clone(),
                ),
            );
            let rmap = rmap_result?;
            (
                Box::pin(RecordBatchStreamAdapter::new(left_schema, left_peeked)),
                match join_params.join_type {
                    Inner => Box::pin(LProbedInnerJoiner::new(join_params, rmap, sender)),
                    Left => Box::pin(LProbedLeftJoiner::new(join_params, rmap, sender)),
                    Right => Box::pin(LProbedRightJoiner::new(join_params, rmap, sender)),
                    Full => Box::pin(LProbedFullOuterJoiner::new(join_params, rmap, sender)),
                    LeftSemi => Box::pin(LProbedLeftSemiJoiner::new(join_params, rmap, sender)),
                    LeftAnti => Box::pin(LProbedLeftAntiJoiner::new(join_params, rmap, sender)),
                    RightSemi => Box::pin(LProbedRightSemiJoiner::new(join_params, rmap, sender)),
                    RightAnti => Box::pin(LProbedRightAntiJoiner::new(join_params, rmap, sender)),
                    Existence => Box::pin(LProbedExistenceJoiner::new(join_params, rmap, sender)),
                },
            )
        }
    };

    while let Some(batch) = {
        let timer = poll_time.timer();
        let batch = probed.next().await.transpose()?;
        drop(timer);
        batch
    } {
        joiner.as_mut().join(batch).await?;
    }
    joiner.as_mut().finish().await?;
    metrics.record_output(joiner.num_output_rows());

    // discount poll input and send output batch time. both futures inside
    // `futures::join!` above add to `poll_time` while running concurrently, so
    // the excluded time can exceed the wall-clock elapsed time; saturate
    // instead of underflowing (which panics in debug builds and wraps to a
    // huge value in release builds).
    let excluded_time_ns = (poll_time.value() + joiner.total_send_output_time()) as u64;
    let join_time_ns = (start_time.elapsed().as_nanos() as u64).saturating_sub(excluded_time_ns);
    metrics
        .elapsed_compute()
        .add_duration(Duration::from_nanos(join_time_ns));
    Ok(())
}
/// Builds the join hash map for the broadcast side.
///
/// If `cached_build_hash_map_id` is `None`, the map is built from `input`
/// directly. Otherwise the build goes through `get_cached_join_hash_map`, so
/// that callers using the same id can reuse one live instance.
async fn collect_join_hash_map(
    cached_build_hash_map_id: Option<String>,
    input: SendableRecordBatchStream,
    key_exprs: &[PhysicalExprRef],
    poll_time: Time,
) -> Result<Arc<JoinHashMap>> {
    let Some(cache_key) = cached_build_hash_map_id else {
        let built = collect_join_hash_map_without_caching(input, key_exprs, poll_time).await?;
        return Ok(Arc::new(built));
    };

    let shared = get_cached_join_hash_map(&cache_key, || async {
        collect_join_hash_map_without_caching(input, key_exprs, poll_time).await
    })
    .await?;
    Ok(shared)
}
/// Drains `input` and converts its contents into a [`JoinHashMap`].
///
/// An empty stream yields an empty map over `input`'s schema. A single batch
/// is converted with `JoinHashMap::try_from_hash_map_batch`. More than one
/// batch is an execution error. Time spent waiting on `input` is added to
/// `poll_time`.
async fn collect_join_hash_map_without_caching(
    mut input: SendableRecordBatchStream,
    key_exprs: &[PhysicalExprRef],
    poll_time: Time,
) -> Result<JoinHashMap> {
    let mut batches = Vec::new();
    loop {
        // the named guard keeps the timer running until the poll completes
        let polled = {
            let _timer = poll_time.timer();
            input.next().await.transpose()?
        };
        let Some(batch) = polled else {
            break;
        };
        batches.push(batch);
    }

    let n = batches.len();
    if n > 1 {
        return df_execution_err!("expect zero or one hash map batch, got {n}");
    }
    match batches.pop() {
        Some(hash_map_batch) => Ok(JoinHashMap::try_from_hash_map_batch(
            hash_map_batch,
            key_exprs,
        )?),
        None => Ok(JoinHashMap::try_new_empty(input.schema(), key_exprs)?),
    }
}
/// Probe-side driver of a broadcast join.
///
/// `execute_join` creates one implementation per combination of probed side
/// and join type, constructed with the join params, the broadcast hash map,
/// and the output sender. It then feeds the joiner the probed batches one at a
/// time.
#[async_trait]
pub trait Joiner {
    /// Joins one batch from the probed side against the broadcast hash map.
    async fn join(self: Pin<&mut Self>, probed_batch: RecordBatch) -> Result<()>;
    /// Called once after the probed side is exhausted. Presumably flushes any
    /// remaining output (e.g. unmatched build rows for outer/anti joins).
    /// TODO: confirm.
    async fn finish(self: Pin<&mut Self>) -> Result<()>;
    /// Total time spent sending output. `execute_join` treats this as
    /// nanoseconds and subtracts it from the compute time.
    fn total_send_output_time(&self) -> usize;
    /// Number of output rows produced. `execute_join` reads it after `finish`
    /// to record the output metric.
    fn num_output_rows(&self) -> usize;
}
async fn get_cached_join_hash_map<Fut: Future<Output = Result<JoinHashMap>> + Send>(
cached_id: &str,
init: impl FnOnce() -> Fut,
) -> Result<Arc<JoinHashMap>> {
type Slot = Arc<tokio::sync::Mutex<Weak<JoinHashMap>>>;
static CACHED_JOIN_HASH_MAP: OnceCell<Arc<Mutex<HashMap<String, Slot>>>> = OnceCell::new();
// TODO: remove expired keys from cached join hash map
let cached_join_hash_map = CACHED_JOIN_HASH_MAP.get_or_init(|| Arc::default());
let slot = cached_join_hash_map
.lock()
.entry(cached_id.to_string())
.or_default()
.clone();
let mut slot = slot.lock().await;
if let Some(cached) = slot.upgrade() {
Ok(cached)
} else {
let new = Arc::new(init().await?);
*slot = Arc::downgrade(&new);
Ok(new)
}
}