blob: a0fedce32525c23a0d7389f363c9f4a35ef39347 [file] [log] [blame]
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering::Relaxed},
Arc,
},
};
use arrow::{array::RecordBatch, buffer::NullBuffer, datatypes::Schema, row::Rows};
use async_trait::async_trait;
use bitvec::{bitvec, prelude::BitVec};
use datafusion::{common::Result, physical_plan::metrics::Time};
use crate::{
broadcast_join_exec::Joiner,
common::{
batch_selection::{interleave_batches, take_batch_opt},
output::WrappedRecordBatchSender,
},
joins::{
bhj::{
full_join::ProbeSide::{L, R},
ProbeSide,
},
join_hash_map::JoinHashMap,
Idx, JoinParams,
},
};
/// Compile-time configuration selecting one flavour of `FullJoiner`.
///
/// Values of this type are used as the const generic argument of
/// `FullJoiner`, which is why it derives `ConstParamTy`.
#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)]
pub struct JoinerParams {
/// Which input (left or right) is the probed side; decides the column order
/// of the output batch.
probe_side: ProbeSide,
/// When true, probed rows that find no match are still emitted.
probe_side_outer: bool,
/// When true, matched build rows are tracked during `join` and the rows that
/// were never matched are emitted by `finish`.
build_side_outer: bool,
}
impl JoinerParams {
const fn new(probe_side: ProbeSide, probe_side_outer: bool, build_side_outer: bool) -> Self {
Self {
probe_side,
probe_side_outer,
build_side_outer,
}
}
}
// Positional arguments are (probe_side, probe_side_outer, build_side_outer).
// The INNER/LEFT/RIGHT/OUTER suffix names the join type. For LEFT and RIGHT the
// single outer flag is placed on the probe role when the preserved input is the
// probed one, and on the build role otherwise.
const LEFT_PROBED_INNER: JoinerParams = JoinerParams::new(L, false, false);
const LEFT_PROBED_LEFT: JoinerParams = JoinerParams::new(L, true, false);
const LEFT_PROBED_RIGHT: JoinerParams = JoinerParams::new(L, false, true);
const LEFT_PROBED_OUTER: JoinerParams = JoinerParams::new(L, true, true);
const RIGHT_PROBED_INNER: JoinerParams = JoinerParams::new(R, false, false);
const RIGHT_PROBED_LEFT: JoinerParams = JoinerParams::new(R, false, true);
const RIGHT_PROBED_RIGHT: JoinerParams = JoinerParams::new(R, true, false);
const RIGHT_PROBED_OUTER: JoinerParams = JoinerParams::new(R, true, true);
// Public names for the eight `FullJoiner` instantiations: one per combination
// of probed side (left/right) and join type (inner/left/right/full outer).
pub type LeftProbedInnerJoiner = FullJoiner<LEFT_PROBED_INNER>;
pub type LeftProbedLeftJoiner = FullJoiner<LEFT_PROBED_LEFT>;
pub type LeftProbedRightJoiner = FullJoiner<LEFT_PROBED_RIGHT>;
pub type LeftProbedFullOuterJoiner = FullJoiner<LEFT_PROBED_OUTER>;
pub type RightProbedInnerJoiner = FullJoiner<RIGHT_PROBED_INNER>;
pub type RightProbedLeftJoiner = FullJoiner<RIGHT_PROBED_LEFT>;
pub type RightProbedRightJoiner = FullJoiner<RIGHT_PROBED_RIGHT>;
pub type RightProbedFullOuterJoiner = FullJoiner<RIGHT_PROBED_OUTER>;
/// Hash joiner whose behaviour (probed side, which sides are outer) is fixed at
/// compile time by the const parameter `P`.
pub struct FullJoiner<const P: JoinerParams> {
/// Join configuration; this type reads the output schema, the left/right
/// schemas and the batch size from it.
join_params: JoinParams,
/// Destination for every output batch that is produced.
output_sender: Arc<WrappedRecordBatchSender>,
/// Hash map over the build side, also holding the build-side data batches.
map: JoinHashMap,
/// One bit vector per build-side data batch, one bit per row; a bit is set
/// once that build row has been matched (only written when
/// `P.build_side_outer` is true).
map_joined: Vec<BitVec>,
/// Accumulated time spent sending output batches.
send_output_time: Time,
/// Running total of rows contained in the output batches produced so far.
output_rows: AtomicUsize,
}
impl<const P: JoinerParams> FullJoiner<P> {
    /// Creates a joiner on top of an already populated hash map.
    ///
    /// For every build-side data batch held by `map`, an all-zero bit vector
    /// with one bit per row is allocated; `join` uses it to remember which
    /// build rows were matched.
    pub fn new(
        join_params: JoinParams,
        map: JoinHashMap,
        output_sender: Arc<WrappedRecordBatchSender>,
    ) -> Self {
        let map_joined = map
            .data_batches()
            .iter()
            .map(|batch| bitvec![0; batch.num_rows()])
            .collect();
        Self {
            join_params,
            output_sender,
            map,
            map_joined,
            send_output_time: Time::default(),
            output_rows: AtomicUsize::new(0),
        }
    }

    /// Materialises one output batch from parallel index lists and sends it.
    ///
    /// * `probed_batch` - batch the probe-side columns are taken from.
    /// * `probe_indices` - row to take from `probed_batch` for each output row
    ///   (a `None` entry presumably yields a null row; confirm against
    ///   `take_batch_opt`).
    /// * `build_indices` - `(batch, row)` position in the build-side data for
    ///   each output row.
    ///
    /// Columns are laid out left-then-right: probed columns come first when
    /// the probe side is left, build columns first otherwise. The row count is
    /// added to `output_rows` and the send is timed with `send_output_time`.
    ///
    /// # Errors
    /// Propagates failures from taking/interleaving rows and from assembling
    /// the output `RecordBatch`.
    async fn flush(
        &self,
        probed_batch: RecordBatch,
        probe_indices: Vec<Option<u32>>,
        build_indices: Vec<Idx>,
    ) -> Result<()> {
        // `probed_batch` is owned and not used again, so it is moved instead
        // of cloned.
        let pcols = take_batch_opt(probed_batch, probe_indices)?;
        let bcols = interleave_batches(
            self.map.data_schema(),
            self.map.data_batches(),
            &build_indices,
        )?;
        let output_columns = match P.probe_side {
            L => [pcols.columns().to_vec(), bcols.columns().to_vec()].concat(),
            R => [bcols.columns().to_vec(), pcols.columns().to_vec()].concat(),
        };
        let output_batch =
            RecordBatch::try_new(self.join_params.output_schema.clone(), output_columns)?;
        self.output_rows.fetch_add(output_batch.num_rows(), Relaxed);

        // Only the send itself is timed.
        let timer = self.send_output_time.timer();
        self.output_sender.send(Ok(output_batch), None).await;
        drop(timer);
        Ok(())
    }
}
#[async_trait]
impl<const P: JoinerParams> Joiner for FullJoiner<P> {
    /// Probes the hash map with every row of `probed_batch` and emits the
    /// joined rows.
    ///
    /// * `probed_batch` - the batch being probed.
    /// * `probed_key` - row-encoded join keys, one per row of `probed_batch`.
    /// * `probed_key_has_null` - optional null mask; rows flagged as null are
    ///   never looked up in the map.
    ///
    /// Each match appends a (probe row, build position) pair. When the build
    /// side is outer, matched build rows are marked in `map_joined`. When the
    /// probe side is outer, unmatched probe rows are emitted as well. Pending
    /// pairs are flushed whenever at least `batch_size` of them have
    /// accumulated after processing a probe row, and once more at the end.
    ///
    /// # Errors
    /// Propagates any error returned while flushing an output batch.
    async fn join(
        mut self: Pin<&mut Self>,
        probed_batch: RecordBatch,
        probed_key: Rows,
        probed_key_has_null: Option<NullBuffer>,
    ) -> Result<()> {
        // Reborrow as a plain `&mut Self` so that `map` (read while searching)
        // and `map_joined` (written for each match) are borrowed as disjoint
        // fields. This replaces a `&T` -> `&mut T` pointer cast, which is
        // undefined behaviour.
        let this = &mut *self;
        let mut probe_indices: Vec<Option<u32>> = vec![];
        let mut build_indices: Vec<Idx> = vec![];

        for (row_idx, row) in probed_key.iter().enumerate() {
            let key_is_null = probed_key_has_null
                .as_ref()
                .is_some_and(|nulls| nulls.is_null(row_idx));

            let mut found = false;
            if !key_is_null {
                for idx in this.map.search(&row) {
                    found = true;
                    probe_indices.push(Some(row_idx as u32));
                    build_indices.push(idx);
                    if P.build_side_outer {
                        // Remember that this build row has been matched.
                        this.map_joined[idx.0].set(idx.1, true);
                    }
                }
            }

            if P.probe_side_outer && !found {
                probe_indices.push(Some(row_idx as u32));
                // NOTE(review): `Idx::default()` presumably addresses a
                // reserved all-null build row - confirm against `JoinHashMap`.
                build_indices.push(Idx::default());
            }

            if probe_indices.len() >= this.join_params.batch_size {
                let pending_probe = std::mem::take(&mut probe_indices);
                let pending_build = std::mem::take(&mut build_indices);
                this.flush(probed_batch.clone(), pending_probe, pending_build)
                    .await?;
            }
        }

        if !probe_indices.is_empty() {
            // Last use of `probed_batch`: move it instead of cloning.
            this.flush(probed_batch, probe_indices, build_indices)
                .await?;
        }
        Ok(())
    }

    /// Emits the build rows that were never matched, when the build side is
    /// outer; otherwise does nothing.
    ///
    /// The probed side of these rows is taken from an empty batch whose
    /// schema is the probed schema with every field made nullable, using
    /// `None` probe indices. Only rows for which `map.inserted` returns true
    /// are considered. One output batch is flushed per build-side data batch
    /// that has at least one such row.
    ///
    /// # Errors
    /// Propagates any error returned while flushing an output batch.
    async fn finish(mut self: Pin<&mut Self>) -> Result<()> {
        if P.build_side_outer {
            let probed_schema = match P.probe_side {
                L => self.join_params.left_schema.clone(),
                R => self.join_params.right_schema.clone(),
            };
            let probed_empty_batch = RecordBatch::new_empty(Arc::new(Schema::new(
                probed_schema
                    .fields()
                    .iter()
                    .map(|field| field.as_ref().clone().with_nullable(true))
                    .collect::<Vec<_>>(),
            )));

            // Take ownership of the match bitmaps; they are not needed after
            // this pass.
            let map_joined = std::mem::take(&mut self.map_joined);
            for (batch_idx, batch_joined) in map_joined.into_iter().enumerate() {
                let mut probe_indices: Vec<Option<u32>> = Vec::with_capacity(batch_joined.len());
                let mut build_indices: Vec<Idx> = Vec::with_capacity(batch_joined.len());
                for (row_idx, joined) in batch_joined.into_iter().enumerate() {
                    if !joined && self.map.inserted(batch_idx, row_idx) {
                        probe_indices.push(None);
                        build_indices.push((batch_idx, row_idx));
                    }
                }
                if !probe_indices.is_empty() {
                    self.as_mut()
                        .flush(probed_empty_batch.clone(), probe_indices, build_indices)
                        .await?;
                }
            }
        }
        Ok(())
    }

    /// Returns the accumulated value of the output-send timer.
    fn total_send_output_time(&self) -> usize {
        self.send_output_time.value()
    }

    /// Returns the total number of rows in the output batches produced so far.
    fn num_output_rows(&self) -> usize {
        self.output_rows.load(Relaxed)
    }
}