// blob: 5a4324e11c62a90c8f7b103b8217ce2abb03b89d [file] [log] [blame]
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{cmp::Ordering, pin::Pin, sync::Arc};
use arrow::array::{ArrayRef, RecordBatch, RecordBatchOptions};
use async_trait::async_trait;
use datafusion::{common::Result, physical_plan::metrics::Time};
use datafusion_ext_commons::suggested_output_batch_mem_size;
use crate::{
common::{batch_selection::interleave_batches, output::WrappedRecordBatchSender},
compare_cursor, cur_forward,
joins::{Idx, JoinParams, StreamCursors},
sort_merge_join_exec::Joiner,
};
/// Joiner that emits every row of the left stream (`curs.0`) together with
/// one trailing boolean column recording whether a row with an equal key was
/// found in the right stream (`curs.1`).
pub struct ExistenceJoiner {
/// Join configuration; this joiner reads `batch_size` and `output_schema`.
join_params: JoinParams,
/// Destination for every flushed output batch.
output_sender: Arc<WrappedRecordBatchSender>,
/// Left-cursor positions of rows accumulated since the last flush.
indices: Vec<Idx>,
/// Parallel to `indices`: `true` for rows pushed by the equal-key path,
/// `false` for rows pushed by the left-is-smaller path.
exists: Vec<bool>,
/// Accumulates the time spent awaiting `output_sender.send`.
send_output_time: Time,
/// Total number of rows sent to `output_sender` so far.
output_rows: usize,
}
impl ExistenceJoiner {
    /// Creates a joiner with empty pending buffers and zeroed metrics.
    ///
    /// `join_params` supplies the batch size and output schema used when
    /// flushing; `output_sender` receives every flushed batch.
    pub fn new(join_params: JoinParams, output_sender: Arc<WrappedRecordBatchSender>) -> Self {
        Self {
            join_params,
            output_sender,
            indices: Vec::new(),
            exists: Vec::new(),
            send_output_time: Time::new(),
            output_rows: 0,
        }
    }

    /// Decides whether the pending rows should be emitted now.
    ///
    /// Returns `true` when either
    /// * at least `join_params.batch_size` rows are pending, or
    /// * the two cursors together buffer at least 6 batches whose combined
    ///   memory size exceeds `suggested_output_batch_mem_size()`, and the
    ///   oldest pending row lives in a left batch older than the one the left
    ///   cursor currently points at.
    fn should_flush(&self, curs: &StreamCursors) -> bool {
        if self.indices.len() >= self.join_params.batch_size {
            return true;
        }

        // Evaluated in the same order as a short-circuiting `&&` chain.
        let buffered_batches = curs.0.num_buffered_batches() + curs.1.num_buffered_batches();
        if buffered_batches < 6 {
            return false;
        }
        let buffered_mem = curs.0.mem_size() + curs.1.mem_size();
        if buffered_mem <= suggested_output_batch_mem_size() {
            return false;
        }

        self.indices
            .first()
            .map_or(false, |oldest| oldest.0 < curs.0.cur_idx.0)
    }

    /// Builds one output batch from the pending rows and sends it.
    ///
    /// The pending left rows are gathered from `curs.0.batches`, the `exists`
    /// flags are appended as the last column, and the batch is sent through
    /// `output_sender` unless it is empty. Both pending buffers are left
    /// empty afterwards.
    ///
    /// # Errors
    ///
    /// Propagates failures from `interleave_batches` and from building the
    /// `RecordBatch`.
    async fn flush(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
        // Take both buffers together so they can never get out of step.
        let indices = std::mem::take(&mut self.indices);
        let exists = std::mem::take(&mut self.exists);
        let num_rows = indices.len();

        let left_rows =
            interleave_batches(curs.0.batch_schema.clone(), &curs.0.batches, &indices)?;
        let exists_col: ArrayRef = Arc::new(arrow::array::BooleanArray::from(exists));

        let mut columns = left_rows.columns().to_vec();
        columns.push(exists_col);

        // The explicit row count keeps the batch well-formed even when the
        // left side contributes no columns.
        let output_batch = RecordBatch::try_new_with_options(
            self.join_params.output_schema.clone(),
            columns,
            &RecordBatchOptions::new().with_row_count(Some(num_rows)),
        )?;

        if output_batch.num_rows() > 0 {
            self.output_rows += output_batch.num_rows();
            let timer = self.send_output_time.timer();
            self.output_sender.send(Ok(output_batch), None).await;
            drop(timer);
        }
        Ok(())
    }

    /// Handles a left row with no match: records it with `exists = false`,
    /// advances the left cursor and flushes if `should_flush` says so.
    async fn join_less(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
        self.indices.push(curs.0.cur_idx);
        self.exists.push(false);
        cur_forward!(curs.0);

        if self.should_flush(curs) {
            self.as_mut().flush(curs).await?;
        }

        // Keep the oldest still-pending left row reserved, or the current
        // position when nothing is pending.
        let reserved = self.indices.first().copied().unwrap_or(curs.0.cur_idx);
        curs.0.set_min_reserved_idx(reserved);
        Ok(())
    }

    /// Handles a right row that sorts before the current left row by skipping
    /// it.
    async fn join_greater(self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
        cur_forward!(curs.1);

        // `self.indices` holds positions in the LEFT stream only, so it must
        // not be used to reserve batches of the right stream: batch numbering
        // of the two streams is independent. Right rows are never emitted by
        // this joiner, so only the current right position has to stay
        // reserved.
        let reserved = curs.1.cur_idx;
        curs.1.set_min_reserved_idx(reserved);
        Ok(())
    }

    /// Handles equal keys: records every left row sharing the current key
    /// with `exists = true`, then skips every right row sharing that key.
    async fn join_equal(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
        let mut lidx = curs.0.cur_idx;
        let mut ridx = curs.1.cur_idx;

        // Emit all left rows carrying the same key.
        loop {
            self.indices.push(lidx);
            self.exists.push(true);
            cur_forward!(curs.0);

            if self.should_flush(curs) {
                self.as_mut().flush(curs).await?;
            }

            // `lidx` is still needed for the key comparison below, so fall
            // back to it when nothing is pending after a flush.
            let reserved = self.indices.first().copied().unwrap_or(lidx);
            curs.0.set_min_reserved_idx(reserved);

            let same_key = !curs.0.finished && curs.0.key(curs.0.cur_idx) == curs.0.key(lidx);
            if !same_key {
                break;
            }
            lidx = curs.0.cur_idx;
        }

        // Skip all right rows carrying the same key.
        loop {
            cur_forward!(curs.1);
            // `ridx` is the previous equal row; it is read by the comparison
            // below, so it is the position kept reserved.
            curs.1.set_min_reserved_idx(ridx);

            let same_key = !curs.1.finished && curs.1.key(curs.1.cur_idx) == curs.1.key(ridx);
            if !same_key {
                break;
            }
            ridx = curs.1.cur_idx;
        }
        Ok(())
    }
}
#[async_trait]
impl Joiner for ExistenceJoiner {
    /// Drives the merge over both cursors until the left stream is exhausted.
    ///
    /// While both streams have rows, the cursor comparison selects the
    /// less / greater / equal handler. Once either stream finishes, every
    /// remaining left row is handled as unmatched, and any rows still pending
    /// are flushed.
    async fn join(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
        // Merge phase: both sides still have rows.
        loop {
            if curs.0.finished || curs.1.finished {
                break;
            }
            match compare_cursor!(curs) {
                Ordering::Equal => self.as_mut().join_equal(curs).await?,
                Ordering::Less => self.as_mut().join_less(curs).await?,
                Ordering::Greater => self.as_mut().join_greater(curs).await?,
            }
        }

        // Drain phase: the remaining left rows go through the unmatched path.
        while !curs.0.finished {
            self.as_mut().join_less(curs).await?;
        }

        // Final flush of whatever is still pending.
        if self.indices.is_empty() {
            return Ok(());
        }
        self.flush(curs).await
    }

    /// Returns the accumulated value of the send-output timer.
    fn total_send_output_time(&self) -> usize {
        self.send_output_time.value()
    }

    /// Returns the number of rows sent downstream so far.
    fn num_output_rows(&self) -> usize {
        self.output_rows
    }
}