Support native broadcast hash join (BHJ) in Blaze

This patch reimplements BroadcastJoinExec as a real broadcast hash join instead of
delegating to DataFusion's HashJoinExec (with a sort-merge-join fallback for large
broadcasts):

* A new BroadcastJoinBuildHashMapExec operator evaluates the join keys on the
  broadcast side and repacks its input into join-hash-map batches: the data columns
  plus a trailing binary "~KEY" column (see the new joins::join_hash_map module).
* BroadcastJoinExec now takes an explicit output schema and a broadcast_side,
  rebuilds the JoinHashMap from those batches, and probes it with the streamed side
  through dedicated joiners (inner/left/right/full outer, semi/anti and existence)
  in the new joins::bhj module.
* common::join_utils moves to joins::join_utils, RecordBatchStreamsWrapperExec moves
  into broadcast_nested_loop_join_exec, and the protobuf plan gains
  BroadcastJoinBuildHashMapExecNode while BroadcastJoinExecNode now carries the
  output schema and broadcast_side instead of a join_filter.
* A bitvec dependency is added to track which build-side rows have already been
  matched in outer/anti/existence joins.
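
The JoinHashMap produced by the build operator is stored as ordinary RecordBatches
(occupied slots carry a value in the trailing "~KEY" column) and is addressed by open
addressing: slot 0 is reserved, a key's hash picks its home slot, collisions are
resolved by linear probing, and a slot index maps to (batch_idx, row_idx) by dividing
by the batch size. The sketch below is not part of this patch; it only illustrates the
addressing scheme, using std's DefaultHasher in place of GxHasher and plain Vecs in
place of Arrow batches:

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::Hasher;

// Hash a binary join key. The patch uses GxHasher with a fixed seed; DefaultHasher
// is used here only to keep the sketch dependency-free.
fn hash_key(key: &[u8]) -> usize {
    let mut h = DefaultHasher::new();
    h.write(key);
    h.finish() as usize
}

// Insert every build-side row into an open-addressed table. Slot 0 is reserved
// (mirroring build_join_hash_map) and collisions are resolved by linear probing.
fn build(keys: &[&[u8]], num_slots: usize) -> Vec<Option<usize>> {
    let mut slots: Vec<Option<usize>> = vec![None; num_slots];
    for (row, key) in keys.iter().enumerate() {
        let mut slot = 1 + hash_key(key) % (num_slots - 1);
        while slots[slot].is_some() {
            slot = 1 + (slot + 1) % (num_slots - 1); // linear probing
        }
        slots[slot] = Some(row);
    }
    slots
}

// Probe a key: walk the chain from its home slot until an empty slot terminates it,
// collecting every row whose key matches (duplicate keys yield multiple rows).
fn search(slots: &[Option<usize>], keys: &[&[u8]], key: &[u8]) -> Vec<usize> {
    let num_slots = slots.len();
    let mut slot = 1 + hash_key(key) % (num_slots - 1);
    let mut matches = vec![];
    while let Some(row) = slots[slot] {
        if keys[row] == key {
            matches.push(row);
        }
        slot = 1 + (slot + 1) % (num_slots - 1);
    }
    matches
}

fn main() {
    let keys: Vec<&[u8]> = vec![b"a".as_slice(), b"b".as_slice(), b"a".as_slice()];
    // same generous sizing as build_join_hash_map: (rows + 16) * 5 + 1 slots
    let slots = build(&keys, (keys.len() + 16) * 5 + 1);
    assert_eq!(search(&slots, &keys, b"a"), vec![0, 2]); // duplicate keys -> both rows
    assert!(search(&slots, &keys, b"zz").is_empty());
}
```

The real implementation performs the same walk over the per-batch BinaryArray key
columns, which is why a single probe key can yield several (batch_idx, row_idx)
matches.
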
diff --git a/Cargo.lock b/Cargo.lock
index ef23afd..51ae230 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -923,6 +923,7 @@
"arrow",
"async-trait",
"base64 0.22.1",
+ "bitvec",
"blaze-jni-bridge",
"byteorder",
"bytes",
diff --git a/native-engine/blaze-serde/proto/blaze.proto b/native-engine/blaze-serde/proto/blaze.proto
index c1d2c58..97e3bc3 100644
--- a/native-engine/blaze-serde/proto/blaze.proto
+++ b/native-engine/blaze-serde/proto/blaze.proto
@@ -35,19 +35,20 @@
FilterExecNode filter = 8;
UnionExecNode union = 9;
SortMergeJoinExecNode sort_merge_join = 10;
- BroadcastJoinExecNode broadcast_join = 11;
- RenameColumnsExecNode rename_columns = 12;
- EmptyPartitionsExecNode empty_partitions = 13;
- AggExecNode agg = 14;
- LimitExecNode limit = 15;
- FFIReaderExecNode ffi_reader = 16;
- CoalesceBatchesExecNode coalesce_batches = 17;
- ExpandExecNode expand = 18;
- RssShuffleWriterExecNode rss_shuffle_writer= 19;
- WindowExecNode window = 20;
- GenerateExecNode generate = 21;
- ParquetSinkExecNode parquet_sink = 22;
- BroadcastNestedLoopJoinExecNode broadcast_nested_loop_join = 23;
+ BroadcastJoinBuildHashMapExecNode broadcast_join_build_hash_map = 11;
+ BroadcastJoinExecNode broadcast_join = 12;
+ RenameColumnsExecNode rename_columns = 13;
+ EmptyPartitionsExecNode empty_partitions = 14;
+ AggExecNode agg = 15;
+ LimitExecNode limit = 16;
+ FFIReaderExecNode ffi_reader = 17;
+ CoalesceBatchesExecNode coalesce_batches = 18;
+ ExpandExecNode expand = 19;
+ RssShuffleWriterExecNode rss_shuffle_writer= 20;
+ WindowExecNode window = 21;
+ GenerateExecNode generate = 22;
+ ParquetSinkExecNode parquet_sink = 23;
+ BroadcastNestedLoopJoinExecNode broadcast_nested_loop_join = 24;
}
}
@@ -407,12 +408,18 @@
JoinFilter join_filter = 7;
}
+message BroadcastJoinBuildHashMapExecNode {
+ PhysicalPlanNode input = 1;
+  repeated PhysicalExprNode keys = 2;
+}
+
message BroadcastJoinExecNode {
- PhysicalPlanNode left = 1;
- PhysicalPlanNode right = 2;
- repeated JoinOn on = 3;
- JoinType join_type = 4;
- JoinFilter join_filter = 5;
+ Schema schema = 1;
+ PhysicalPlanNode left = 2;
+ PhysicalPlanNode right = 3;
+ repeated JoinOn on = 4;
+ JoinType join_type = 5;
+ JoinSide broadcast_side = 6;
}
message BroadcastNestedLoopJoinExecNode {
diff --git a/native-engine/blaze-serde/src/from_proto.rs b/native-engine/blaze-serde/src/from_proto.rs
index 5d78ca2..784f175 100644
--- a/native-engine/blaze-serde/src/from_proto.rs
+++ b/native-engine/blaze-serde/src/from_proto.rs
@@ -61,9 +61,9 @@
use datafusion_ext_plans::{
agg::{create_agg, AggExecMode, AggExpr, AggFunction, AggMode, GroupingExpr},
agg_exec::AggExec,
+ broadcast_join_build_hash_map_exec::BroadcastJoinBuildHashMapExec,
broadcast_join_exec::BroadcastJoinExec,
broadcast_nested_loop_join_exec::BroadcastNestedLoopJoinExec,
- common::join_utils::JoinType,
debug_exec::DebugExec,
empty_partitions_exec::EmptyPartitionsExec,
expand_exec::ExpandExec,
@@ -73,6 +73,7 @@
generate_exec::GenerateExec,
ipc_reader_exec::IpcReaderExec,
ipc_writer_exec::IpcWriterExec,
+ joins::join_utils::JoinType,
limit_exec::LimitExec,
parquet_exec::ParquetExec,
parquet_sink_exec::ParquetSinkExec,
@@ -284,7 +285,7 @@
self
))
})?;
- if let protobuf::physical_expr_node::ExprType::Sort(sort_expr) = expr {
+ if let ExprType::Sort(sort_expr) = expr {
let expr = sort_expr
.expr
.as_ref()
@@ -320,7 +321,22 @@
sort.fetch_limit.as_ref().map(|limit| limit.limit as usize),
)))
}
+ PhysicalPlanType::BroadcastJoinBuildHashMap(bhm) => {
+ let input: Arc<dyn ExecutionPlan> = convert_box_required!(bhm.input)?;
+ let keys = bhm
+ .keys
+ .iter()
+ .map(|expr| {
+ Ok(bind(
+ try_parse_physical_expr(expr, &input.schema())?,
+ &input.schema(),
+ )?)
+ })
+ .collect::<Result<Vec<Arc<dyn PhysicalExpr>>, Self::Error>>()?;
+ Ok(Arc::new(BroadcastJoinBuildHashMapExec::new(input, keys)))
+ }
PhysicalPlanType::BroadcastJoin(broadcast_join) => {
+ let schema = Arc::new(convert_required!(broadcast_join.schema)?);
let left: Arc<dyn ExecutionPlan> = convert_box_required!(broadcast_join.left)?;
let right: Arc<dyn ExecutionPlan> = convert_box_required!(broadcast_join.right)?;
let on: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = broadcast_join
@@ -339,44 +355,21 @@
let join_type = protobuf::JoinType::try_from(broadcast_join.join_type)
.expect("invalid JoinType");
- let join_filter = broadcast_join
- .join_filter
- .as_ref()
- .map(|f| {
- let schema = Arc::new(convert_required!(f.schema)?);
- let expression = try_parse_physical_expr_required(&f.expression, &schema)?;
- let column_indices = f
- .column_indices
- .iter()
- .map(|i| {
- let side =
- protobuf::JoinSide::try_from(i.side).expect("invalid JoinSide");
- Ok(ColumnIndex {
- index: i.index as usize,
- side: side.into(),
- })
- })
- .collect::<Result<Vec<_>, PlanSerDeError>>()?;
- Ok(JoinFilter::new(
- bind(expression, &schema)?,
- column_indices,
- schema.as_ref().clone(),
- ))
- })
- .map_or(Ok(None), |v: Result<_, PlanSerDeError>| v.map(Some))?;
+ let broadcast_side = protobuf::JoinSide::try_from(broadcast_join.broadcast_side)
+ .expect("invalid BroadcastSide");
- let blaze_join_type: JoinType = join_type
- .try_into()
- .map_err(|_| proto_error("invalid JoinType"))?;
Ok(Arc::new(BroadcastJoinExec::try_new(
+ schema,
left,
right,
on,
- blaze_join_type
+ join_type
.try_into()
.map_err(|_| proto_error("invalid JoinType"))?,
- join_filter,
+ broadcast_side
+ .try_into()
+ .map_err(|_| proto_error("invalid BroadcastSide"))?,
)?))
}
PhysicalPlanType::BroadcastNestedLoopJoin(bnlj) => {
diff --git a/native-engine/blaze-serde/src/lib.rs b/native-engine/blaze-serde/src/lib.rs
index 040616b..56cd4a6 100644
--- a/native-engine/blaze-serde/src/lib.rs
+++ b/native-engine/blaze-serde/src/lib.rs
@@ -16,7 +16,7 @@
use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, Schema, TimeUnit};
use datafusion::{common::JoinSide, logical_expr::Operator, scalar::ScalarValue};
-use datafusion_ext_plans::{agg::AggFunction, common::join_utils::JoinType};
+use datafusion_ext_plans::{agg::AggFunction, joins::join_utils::JoinType};
use crate::error::PlanSerDeError;
diff --git a/native-engine/datafusion-ext-plans/Cargo.toml b/native-engine/datafusion-ext-plans/Cargo.toml
index a233274..82fad59 100644
--- a/native-engine/datafusion-ext-plans/Cargo.toml
+++ b/native-engine/datafusion-ext-plans/Cargo.toml
@@ -11,6 +11,7 @@
arrow = { workspace = true }
async-trait = "0.1.80"
base64 = "0.22.1"
+bitvec = "1.0.1"
byteorder = "1.5.0"
bytes = "1.6.0"
blaze-jni-bridge = { workspace = true }
diff --git a/native-engine/datafusion-ext-plans/src/broadcast_join_build_hash_map_exec.rs b/native-engine/datafusion-ext-plans/src/broadcast_join_build_hash_map_exec.rs
new file mode 100644
index 0000000..ae9ac36
--- /dev/null
+++ b/native-engine/datafusion-ext-plans/src/broadcast_join_build_hash_map_exec.rs
@@ -0,0 +1,172 @@
+// Copyright 2022 The Blaze Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::{
+ any::Any,
+ fmt::{Debug, Formatter},
+ sync::Arc,
+};
+
+use arrow::{
+ datatypes::SchemaRef,
+ row::{RowConverter, SortField},
+};
+use datafusion::{
+ common::Result,
+ error::DataFusionError,
+ execution::{SendableRecordBatchStream, TaskContext},
+ physical_expr::{Partitioning, PhysicalExpr, PhysicalSortExpr},
+ physical_plan::{
+ metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet},
+ stream::RecordBatchStreamAdapter,
+ DisplayAs, DisplayFormatType, ExecutionPlan,
+ },
+};
+use futures::{stream::once, TryStreamExt};
+
+use crate::{
+ common::output::{NextBatchWithTimer, TaskOutputter},
+ joins::join_hash_map::{build_join_hash_map, join_hash_map_schema},
+};
+
+pub struct BroadcastJoinBuildHashMapExec {
+ input: Arc<dyn ExecutionPlan>,
+ keys: Vec<Arc<dyn PhysicalExpr>>,
+ metrics: ExecutionPlanMetricsSet,
+}
+
+impl BroadcastJoinBuildHashMapExec {
+ pub fn new(input: Arc<dyn ExecutionPlan>, keys: Vec<Arc<dyn PhysicalExpr>>) -> Self {
+ Self {
+ input,
+ keys,
+ metrics: ExecutionPlanMetricsSet::new(),
+ }
+ }
+}
+
+impl Debug for BroadcastJoinBuildHashMapExec {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ write!(f, "BroadcastJoinBuildHashMap [{:?}]", self.keys)
+ }
+}
+
+impl DisplayAs for BroadcastJoinBuildHashMapExec {
+ fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
+ write!(f, "BroadcastJoinBuildHashMapExec [{:?}]", self.keys)
+ }
+}
+
+impl ExecutionPlan for BroadcastJoinBuildHashMapExec {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn schema(&self) -> SchemaRef {
+ join_hash_map_schema(&self.input.schema())
+ }
+
+ fn output_partitioning(&self) -> Partitioning {
+ Partitioning::UnknownPartitioning(self.input.output_partitioning().partition_count())
+ }
+
+ fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
+ None
+ }
+
+ fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+ vec![self.input.clone()]
+ }
+
+ fn with_new_children(
+ self: Arc<Self>,
+ children: Vec<Arc<dyn ExecutionPlan>>,
+ ) -> Result<Arc<dyn ExecutionPlan>> {
+ Ok(Arc::new(Self::new(children[0].clone(), self.keys.clone())))
+ }
+
+ fn execute(
+ &self,
+ partition: usize,
+ context: Arc<TaskContext>,
+ ) -> Result<SendableRecordBatchStream> {
+ let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
+ let input = self.input.execute(partition, context.clone())?;
+ Ok(Box::pin(RecordBatchStreamAdapter::new(
+ self.schema(),
+ once(execute_build_hash_map(
+ context,
+ input,
+ self.keys.clone(),
+ baseline_metrics,
+ ))
+ .try_flatten(),
+ )))
+ }
+
+ fn metrics(&self) -> Option<MetricsSet> {
+ Some(self.metrics.clone_inner())
+ }
+}
+
+async fn execute_build_hash_map(
+ context: Arc<TaskContext>,
+ mut input: SendableRecordBatchStream,
+ keys: Vec<Arc<dyn PhysicalExpr>>,
+ metrics: BaselineMetrics,
+) -> Result<SendableRecordBatchStream> {
+ let elapsed_compute = metrics.elapsed_compute().clone();
+ let mut timer = elapsed_compute.timer();
+
+ let mut data_batches = vec![];
+ let data_schema = input.schema();
+
+ // collect all input batches
+ while let Some(batch) = input.next_batch(Some(&mut timer)).await? {
+ data_batches.push(batch);
+ }
+
+ // evaluate keys
+ let key_row_converter = RowConverter::new(
+ keys.iter()
+ .map(|key| Ok(SortField::new(key.data_type(&data_schema)?)))
+ .collect::<Result<_>>()?,
+ )?;
+ let keys = data_batches
+ .iter()
+ .map(|batch| {
+ let key_columns = keys
+ .iter()
+ .map(|key| {
+ Ok::<_, DataFusionError>(key.evaluate(batch)?.into_array(batch.num_rows()))?
+ })
+ .collect::<Result<Vec<_>>>()?;
+ Ok(key_row_converter.convert_columns(&key_columns)?)
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ // build hash map
+ let hash_map_schema = join_hash_map_schema(&data_schema);
+ let hash_map = build_join_hash_map(&data_batches, data_schema, &keys)?;
+ drop(timer);
+
+ // output hash map batches as stream
+ context.output_with_sender("BuildHashMap", hash_map_schema, move |sender| async move {
+ let mut timer = elapsed_compute.timer();
+ for batch in hash_map.into_hash_map_batches() {
+ sender.send(Ok(batch), Some(&mut timer)).await;
+ }
+ Ok(())
+ })
+}
diff --git a/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs b/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs
index 931cf3b..37d0d33 100644
--- a/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs
+++ b/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs
@@ -15,90 +15,127 @@
use std::{
any::Any,
fmt::{Debug, Formatter},
+ pin::Pin,
sync::Arc,
- task::Poll,
- time::Duration,
+ time::{Duration, Instant},
};
-use arrow::{compute::SortOptions, datatypes::SchemaRef, record_batch::RecordBatch};
-use blaze_jni_bridge::{
- conf,
- conf::{BooleanConf, IntConf},
+use arrow::{
+ array::{Array, ArrayRef, RecordBatch},
+ buffer::NullBuffer,
+ compute::SortOptions,
+ datatypes::{DataType, SchemaRef},
+ row::{RowConverter, Rows, SortField},
};
+use async_trait::async_trait;
use datafusion::{
- common::{Result, Statistics},
+ common::{JoinSide, Result, Statistics},
execution::context::TaskContext,
- logical_expr::JoinType,
- physical_expr::PhysicalSortExpr,
+ physical_expr::{PhysicalExprRef, PhysicalSortExpr},
physical_plan::{
- expressions::Column,
- joins::{
- utils::{build_join_schema, check_join_is_valid, JoinFilter, JoinOn},
- HashJoinExec, PartitionMode,
- },
- memory::MemoryStream,
- metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet},
+ joins::utils::JoinOn,
+ metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, Time},
stream::RecordBatchStreamAdapter,
DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream,
},
};
-use datafusion_ext_commons::{df_execution_err, downcast_any};
-use futures::{stream::once, StreamExt, TryStreamExt};
-use parking_lot::Mutex;
+use datafusion_ext_commons::{
+ batch_size, df_execution_err, streams::coalesce_stream::CoalesceInput,
+};
+use futures::{StreamExt, TryStreamExt};
-use crate::{sort_exec::SortExec, sort_merge_join_exec::SortMergeJoinExec};
+use crate::{
+ common::output::{TaskOutputter, WrappedRecordBatchSender},
+ joins::{
+ bhj::{
+ existence_join::{LeftProbedExistenceJoiner, RightProbedExistenceJoiner},
+ full_join::{
+ LeftProbedFullOuterJoiner, LeftProbedInnerJoiner, LeftProbedLeftJoiner,
+ LeftProbedRightJoiner, RightProbedFullOuterJoiner, RightProbedInnerJoiner,
+ RightProbedLeftJoiner, RightProbedRightJoiner,
+ },
+ semi_join::{
+ LeftProbedLeftAntiJoiner, LeftProbedLeftSemiJoiner, LeftProbedRightAntiJoiner,
+ LeftProbedRightSemiJoiner, RightProbedLeftAntiJoiner, RightProbedLeftSemiJoiner,
+ RightProbedRightAntiJoiner, RightProbedRightSemiJoiner,
+ },
+ },
+ join_hash_map::JoinHashMap,
+ join_utils::{JoinType, JoinType::*},
+ JoinParams,
+ },
+};
#[derive(Debug)]
pub struct BroadcastJoinExec {
- /// Left sorted joining execution plan
left: Arc<dyn ExecutionPlan>,
- /// Right sorting joining execution plan
right: Arc<dyn ExecutionPlan>,
- /// Set of common columns used to join on
on: JoinOn,
- /// How the join is performed
join_type: JoinType,
- /// Optional filter before outputting
- join_filter: Option<JoinFilter>,
- /// The schema once the join is applied
+ broadcast_side: JoinSide,
schema: SchemaRef,
- /// Execution metrics
metrics: ExecutionPlanMetricsSet,
}
impl BroadcastJoinExec {
pub fn try_new(
+ schema: SchemaRef,
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
on: JoinOn,
join_type: JoinType,
- join_filter: Option<JoinFilter>,
+ broadcast_side: JoinSide,
) -> Result<Self> {
- if matches!(
- join_type,
- JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightSemi | JoinType::RightAnti,
- ) {
- if join_filter.is_some() {
- df_execution_err!("Semi/Anti join with filter is not supported yet")?;
- }
- }
-
- let left_schema = left.schema();
- let right_schema = right.schema();
-
- check_join_is_valid(&left_schema, &right_schema, &on)?;
- let schema = Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0);
-
Ok(Self {
left,
right,
on,
join_type,
- join_filter,
+ broadcast_side,
schema,
metrics: ExecutionPlanMetricsSet::new(),
})
}
+
+ fn create_join_params(&self) -> Result<JoinParams> {
+ let left_schema = self.left.schema();
+ let right_schema = self.right.schema();
+ let (left_keys, right_keys): (Vec<PhysicalExprRef>, Vec<PhysicalExprRef>) =
+ self.on.iter().cloned().unzip();
+ let key_data_types: Vec<DataType> = self
+ .on
+ .iter()
+ .map(|(left_key, right_key)| {
+ Ok({
+ let left_dt = left_key.data_type(&left_schema)?;
+ let right_dt = right_key.data_type(&right_schema)?;
+ if left_dt != right_dt {
+ df_execution_err!(
+ "join key data type differs {left_dt:?} <-> {right_dt:?}"
+ )?;
+ }
+ left_dt
+ })
+ })
+ .collect::<Result<_>>()?;
+
+ // use smaller batch size and coalesce batches at the end, to avoid buffer
+ // overflowing
+ let batch_size = batch_size();
+ let sub_batch_size = batch_size / batch_size.ilog2() as usize;
+
+ Ok(JoinParams {
+ join_type: self.join_type,
+ left_schema,
+ right_schema,
+ output_schema: self.schema(),
+ left_keys,
+ right_keys,
+ batch_size: sub_batch_size,
+ sort_options: vec![SortOptions::default(); self.on.len()],
+ key_data_types,
+ })
+ }
}
impl ExecutionPlan for BroadcastJoinExec {
@@ -127,11 +164,12 @@
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(Self::try_new(
+ self.schema.clone(),
children[0].clone(),
children[1].clone(),
self.on.iter().cloned().collect(),
self.join_type,
- self.join_filter.clone(),
+ self.broadcast_side,
)?))
}
@@ -140,22 +178,32 @@
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
- let stream = execute_broadcast_join(
- self.schema(),
- self.left.clone(),
- self.right.clone(),
- partition,
- context,
- self.on.clone(),
- self.join_type,
- self.join_filter.clone(),
- BaselineMetrics::new(&self.metrics, partition),
- );
+ let metrics = Arc::new(BaselineMetrics::new(&self.metrics, partition));
+ let join_params = self.create_join_params()?;
+ let left = self.left.execute(partition, context.clone())?;
+ let right = self.right.execute(partition, context.clone())?;
+ let broadcast_side = self.broadcast_side;
+ let output_schema = self.schema();
- Ok(Box::pin(RecordBatchStreamAdapter::new(
- self.schema(),
- once(stream).try_flatten(),
- )))
+ let metrics_cloned = metrics.clone();
+ let context_cloned = context.clone();
+ let output_stream = Box::pin(RecordBatchStreamAdapter::new(
+ output_schema.clone(),
+ futures::stream::once(async move {
+ context_cloned.output_with_sender("BroadcastJoin", output_schema, move |sender| {
+ execute_join(
+ left,
+ right,
+ join_params,
+ broadcast_side,
+ metrics_cloned,
+ sender,
+ )
+ })
+ })
+ .try_flatten(),
+ ));
+ Ok(context.coalesce_with_default_batch_size(output_stream, &metrics)?)
}
fn metrics(&self) -> Option<MetricsSet> {
@@ -173,222 +221,133 @@
}
}
-async fn execute_broadcast_join(
- schema: SchemaRef,
- left: Arc<dyn ExecutionPlan>,
- right: Arc<dyn ExecutionPlan>,
- partition: usize,
- context: Arc<TaskContext>,
- on: JoinOn,
- join_type: JoinType,
- join_filter: Option<JoinFilter>,
- metrics: BaselineMetrics,
-) -> Result<SendableRecordBatchStream> {
- let enabled_fallback_to_smj = conf::BHJ_FALLBACKS_TO_SMJ_ENABLE.value()?;
- let bhj_num_rows_limit = conf::BHJ_FALLBACKS_TO_SMJ_ROWS_THRESHOLD.value()? as usize;
- let bhj_mem_size_limit = conf::BHJ_FALLBACKS_TO_SMJ_MEM_THRESHOLD.value()? as usize;
+async fn execute_join(
+ left: SendableRecordBatchStream,
+ right: SendableRecordBatchStream,
+ join_params: JoinParams,
+ broadcast_side: JoinSide,
+ metrics: Arc<BaselineMetrics>,
+ sender: Arc<WrappedRecordBatchSender>,
+) -> Result<()> {
+ let start_time = Instant::now();
+ let excluded_time_ns = 0;
+ let poll_time = Time::new();
- // if broadcasted size is small enough, use hash join
- // otherwise use sort-merge join
- #[derive(Debug)]
- enum JoinMode {
- Hash,
- SortMerge,
- }
- let mut join_mode = JoinMode::Hash;
-
- let left_schema = left.schema();
- let mut left = left;
-
- if enabled_fallback_to_smj {
- let mut left_stream = left.execute(0, context.clone())?.fuse();
- let mut left_cached: Vec<RecordBatch> = vec![];
- let mut left_num_rows = 0;
- let mut left_mem_size = 0;
-
- // read and cache batches from broadcasted side until reached limits
- while let Some(batch) = left_stream.next().await.transpose()? {
- left_num_rows += batch.num_rows();
- left_mem_size += batch.get_array_memory_size();
- left_cached.push(batch);
- if left_num_rows > bhj_num_rows_limit || left_mem_size > bhj_mem_size_limit {
- join_mode = JoinMode::SortMerge;
- break;
- }
+ let (mut probed, keys, mut joiner): (_, _, Pin<Box<dyn Joiner + Send>>) = match broadcast_side {
+ JoinSide::Left => {
+ let lmap = collect_join_hash_map(left, poll_time.clone()).await?;
+ (
+ right,
+ join_params.right_keys.clone(),
+ match join_params.join_type {
+ Inner => Box::pin(RightProbedInnerJoiner::new(join_params, lmap, sender)),
+ Left => Box::pin(RightProbedLeftJoiner::new(join_params, lmap, sender)),
+ Right => Box::pin(RightProbedRightJoiner::new(join_params, lmap, sender)),
+ Full => Box::pin(RightProbedFullOuterJoiner::new(join_params, lmap, sender)),
+ LeftSemi => Box::pin(RightProbedLeftSemiJoiner::new(join_params, lmap, sender)),
+ LeftAnti => Box::pin(RightProbedLeftAntiJoiner::new(join_params, lmap, sender)),
+ RightSemi => {
+ Box::pin(RightProbedRightSemiJoiner::new(join_params, lmap, sender))
+ }
+ RightAnti => {
+ Box::pin(RightProbedRightAntiJoiner::new(join_params, lmap, sender))
+ }
+ Existence => {
+ Box::pin(RightProbedExistenceJoiner::new(join_params, lmap, sender))
+ }
+ },
+ )
}
-
- // convert left cached and rest batches into execution plan
- let left_cached_stream: SendableRecordBatchStream = Box::pin(MemoryStream::try_new(
- left_cached,
- left_schema.clone(),
- None,
- )?);
- let left_rest_stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new(
- left_schema.clone(),
- left_stream,
- ));
- let left_stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new(
- left_schema.clone(),
- left_cached_stream.chain(left_rest_stream),
- ));
- left = Arc::new(RecordBatchStreamsWrapperExec {
- schema: left_schema.clone(),
- stream: Mutex::new(Some(left_stream)),
- output_partitioning: right.output_partitioning(),
- });
- }
-
- match join_mode {
- JoinMode::Hash => {
- let join = Arc::new(HashJoinExec::try_new(
- left.clone(),
- right.clone(),
- on,
- join_filter,
- &join_type,
- PartitionMode::CollectLeft,
- false,
- )?);
- log::info!("BroadcastJoin is using hash join mode: {:?}", &join);
-
- let join_schema = join.schema();
- let completed = join
- .execute(partition, context)?
- .chain(futures::stream::poll_fn(move |_| {
- // update metrics
- let join_metrics = join.metrics().unwrap();
- metrics.record_output(join_metrics.output_rows().unwrap_or(0));
- metrics.elapsed_compute().add_duration(Duration::from_nanos(
- [
- join_metrics
- .sum_by_name("build_time")
- .map(|v| v.as_usize() as u64),
- join_metrics
- .sum_by_name("join_time")
- .map(|v| v.as_usize() as u64),
- ]
- .into_iter()
- .flatten()
- .sum(),
- ));
- Poll::Ready(None)
- }));
- Ok(Box::pin(RecordBatchStreamAdapter::new(
- join_schema,
- completed,
- )))
+ JoinSide::Right => {
+ let rmap = collect_join_hash_map(right, poll_time.clone()).await?;
+ (
+ left,
+ join_params.left_keys.clone(),
+ match join_params.join_type {
+ Inner => Box::pin(LeftProbedInnerJoiner::new(join_params, rmap, sender)),
+ Left => Box::pin(LeftProbedLeftJoiner::new(join_params, rmap, sender)),
+ Right => Box::pin(LeftProbedRightJoiner::new(join_params, rmap, sender)),
+ Full => Box::pin(LeftProbedFullOuterJoiner::new(join_params, rmap, sender)),
+ LeftSemi => Box::pin(LeftProbedLeftSemiJoiner::new(join_params, rmap, sender)),
+ LeftAnti => Box::pin(LeftProbedLeftAntiJoiner::new(join_params, rmap, sender)),
+ RightSemi => {
+ Box::pin(LeftProbedRightSemiJoiner::new(join_params, rmap, sender))
+ }
+ RightAnti => {
+ Box::pin(LeftProbedRightAntiJoiner::new(join_params, rmap, sender))
+ }
+ Existence => {
+ Box::pin(LeftProbedExistenceJoiner::new(join_params, rmap, sender))
+ }
+ },
+ )
}
- JoinMode::SortMerge => {
- let sort_exprs: Vec<PhysicalSortExpr> = on
- .iter()
- .map(|(_col_left, col_right)| PhysicalSortExpr {
- expr: Arc::new(Column::new(
- "",
- downcast_any!(col_right, Column)
- .expect("requires column")
- .index(),
- )),
- options: Default::default(),
- })
- .collect();
+ };
- let right_sorted = Arc::new(SortExec::new(right, sort_exprs.clone(), None));
- let join = Arc::new(SortMergeJoinExec::try_new(
- schema,
- left.clone(),
- right_sorted.clone(),
- on,
- join_type.try_into()?,
- vec![SortOptions::default(); sort_exprs.len()],
- )?);
- log::info!("BroadcastJoin is using sort-merge join mode: {:?}", &join);
+ let probed_schema = probed.schema();
+ let key_converter = Box::pin(RowConverter::new(
+ keys.iter()
+ .map(|k| Ok(SortField::new(k.data_type(&probed_schema)?.clone())))
+ .collect::<Result<_>>()?,
+ )?);
- let join_schema = join.schema();
- let completed = join
- .execute(partition, context)?
- .chain(futures::stream::poll_fn(move |_| {
- // update metrics
- let right_sorted_metrics = right_sorted.metrics().unwrap();
- let join_metrics = join.metrics().unwrap();
- metrics.record_output(join_metrics.output_rows().unwrap_or(0));
- metrics.elapsed_compute().add_duration(Duration::from_nanos(
- [
- right_sorted_metrics.elapsed_compute(),
- join_metrics.elapsed_compute(),
- ]
- .into_iter()
- .flatten()
- .sum::<usize>() as u64,
- ));
- Poll::Ready(None)
- }));
- Ok(Box::pin(RecordBatchStreamAdapter::new(
- join_schema,
- completed,
- )))
- }
+ while let Some(batch) = {
+ let timer = poll_time.timer();
+ let batch = probed.next().await.transpose()?;
+ drop(timer);
+ batch
+ } {
+ let key_columns: Vec<ArrayRef> = keys
+ .iter()
+ .map(|key| Ok(key.evaluate(&batch)?.into_array(batch.num_rows())?))
+ .collect::<Result<_>>()?;
+ let key_rows = key_converter.convert_columns(&key_columns)?;
+ let key_has_nulls = key_columns
+ .iter()
+ .map(|c| c.nulls().cloned())
+ .reduce(|lhs, rhs| NullBuffer::union(lhs.as_ref(), rhs.as_ref()))
+ .unwrap_or(None);
+ joiner.as_mut().join(batch, key_rows, key_has_nulls).await?;
}
+ joiner.as_mut().finish().await?;
+ metrics.record_output(joiner.num_output_rows());
+
+ // discount poll input and send output batch time
+ let mut join_time_ns = (Instant::now() - start_time).as_nanos() as u64;
+    join_time_ns -= (excluded_time_ns + poll_time.value() + joiner.total_send_output_time()) as u64;
+ metrics
+ .elapsed_compute()
+ .add_duration(Duration::from_nanos(join_time_ns));
+ Ok(())
}
-pub struct RecordBatchStreamsWrapperExec {
- pub schema: SchemaRef,
- pub stream: Mutex<Option<SendableRecordBatchStream>>,
- pub output_partitioning: Partitioning,
+async fn collect_join_hash_map(
+ mut input: SendableRecordBatchStream,
+ poll_time: Time,
+) -> Result<JoinHashMap> {
+ let mut hash_map_batches = vec![];
+ while let Some(batch) = {
+ let timer = poll_time.timer();
+ let batch = input.next().await.transpose()?;
+ drop(timer);
+ batch
+ } {
+ hash_map_batches.push(batch);
+ }
+ Ok(JoinHashMap::try_new(hash_map_batches)?)
}
-impl Debug for RecordBatchStreamsWrapperExec {
- fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
- write!(f, "RecordBatchStreamsWrapper")
- }
-}
+#[async_trait]
+pub trait Joiner {
+ async fn join(
+ self: Pin<&mut Self>,
+ probed_batch: RecordBatch,
+ probed_key: Rows,
+ probed_key_has_null: Option<NullBuffer>,
+ ) -> Result<()>;
-impl DisplayAs for RecordBatchStreamsWrapperExec {
- fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
- write!(f, "RecordBatchStreamsWrapper")
- }
-}
+ async fn finish(self: Pin<&mut Self>) -> Result<()>;
-impl ExecutionPlan for RecordBatchStreamsWrapperExec {
- fn as_any(&self) -> &dyn Any {
- self
- }
-
- fn schema(&self) -> SchemaRef {
- self.schema.clone()
- }
-
- fn output_partitioning(&self) -> Partitioning {
- self.output_partitioning.clone()
- }
-
- fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
- None
- }
-
- fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
- vec![]
- }
-
- fn with_new_children(
- self: Arc<Self>,
- _: Vec<Arc<dyn ExecutionPlan>>,
- ) -> Result<Arc<dyn ExecutionPlan>> {
- unimplemented!()
- }
-
- fn execute(
- &self,
- _partition: usize,
- _context: Arc<TaskContext>,
- ) -> Result<SendableRecordBatchStream> {
- let stream = std::mem::take(&mut *self.stream.lock());
- Ok(Box::pin(RecordBatchStreamAdapter::new(
- self.schema.clone(),
- Box::pin(futures::stream::iter(stream).flatten()),
- )))
- }
-
- fn statistics(&self) -> Result<Statistics> {
- unimplemented!()
- }
+ fn total_send_output_time(&self) -> usize;
+ fn num_output_rows(&self) -> usize;
}
diff --git a/native-engine/datafusion-ext-plans/src/broadcast_nested_loop_join_exec.rs b/native-engine/datafusion-ext-plans/src/broadcast_nested_loop_join_exec.rs
index b52e77f..7ddd741 100644
--- a/native-engine/datafusion-ext-plans/src/broadcast_nested_loop_join_exec.rs
+++ b/native-engine/datafusion-ext-plans/src/broadcast_nested_loop_join_exec.rs
@@ -34,8 +34,6 @@
use futures::{stream::once, StreamExt, TryStreamExt};
use parking_lot::Mutex;
-use crate::broadcast_join_exec::RecordBatchStreamsWrapperExec;
-
#[derive(Debug)]
pub struct BroadcastNestedLoopJoinExec {
left: Arc<dyn ExecutionPlan>,
@@ -250,3 +248,66 @@
JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | JoinType::Full
)
}
+
+struct RecordBatchStreamsWrapperExec {
+ pub schema: SchemaRef,
+ pub stream: Mutex<Option<SendableRecordBatchStream>>,
+ pub output_partitioning: Partitioning,
+}
+
+impl std::fmt::Debug for RecordBatchStreamsWrapperExec {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ write!(f, "RecordBatchStreamsWrapper")
+ }
+}
+
+impl DisplayAs for RecordBatchStreamsWrapperExec {
+ fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
+ write!(f, "RecordBatchStreamsWrapper")
+ }
+}
+
+impl ExecutionPlan for RecordBatchStreamsWrapperExec {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn schema(&self) -> SchemaRef {
+ self.schema.clone()
+ }
+
+ fn output_partitioning(&self) -> Partitioning {
+ self.output_partitioning.clone()
+ }
+
+ fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
+ None
+ }
+
+ fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+ vec![]
+ }
+
+ fn with_new_children(
+ self: Arc<Self>,
+ _: Vec<Arc<dyn ExecutionPlan>>,
+ ) -> Result<Arc<dyn ExecutionPlan>> {
+ unimplemented!()
+ }
+
+ fn execute(
+ &self,
+ _partition: usize,
+ _context: Arc<TaskContext>,
+ ) -> Result<SendableRecordBatchStream> {
+ let stream = std::mem::take(&mut *self.stream.lock());
+ Ok(Box::pin(RecordBatchStreamAdapter::new(
+ self.schema.clone(),
+ Box::pin(futures::stream::iter(stream).flatten()),
+ )))
+ }
+
+ fn statistics(&self) -> Result<Statistics> {
+ unimplemented!()
+ }
+}
diff --git a/native-engine/datafusion-ext-plans/src/common/mod.rs b/native-engine/datafusion-ext-plans/src/common/mod.rs
index c311a55..41436dd 100644
--- a/native-engine/datafusion-ext-plans/src/common/mod.rs
+++ b/native-engine/datafusion-ext-plans/src/common/mod.rs
@@ -17,5 +17,4 @@
pub mod cached_exprs_evaluator;
pub mod column_pruning;
pub mod ipc_compression;
-pub mod join_utils;
pub mod output;
diff --git a/native-engine/datafusion-ext-plans/src/joins/bhj/existence_join.rs b/native-engine/datafusion-ext-plans/src/joins/bhj/existence_join.rs
new file mode 100644
index 0000000..8053ee2
--- /dev/null
+++ b/native-engine/datafusion-ext-plans/src/joins/bhj/existence_join.rs
@@ -0,0 +1,199 @@
+// Copyright 2022 The Blaze Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::{
+ pin::Pin,
+ sync::{
+ atomic::{AtomicUsize, Ordering::Relaxed},
+ Arc,
+ },
+};
+
+use arrow::{
+ array::{BooleanArray, RecordBatch},
+ buffer::NullBuffer,
+ datatypes::Schema,
+ row::Rows,
+};
+use async_trait::async_trait;
+use bitvec::{bitvec, prelude::BitVec};
+use datafusion::{common::Result, physical_plan::metrics::Time};
+
+use crate::{
+ broadcast_join_exec::Joiner,
+ common::{batch_selection::interleave_batches, output::WrappedRecordBatchSender},
+ joins::{
+ bhj::{
+ ProbeSide,
+ ProbeSide::{L, R},
+ },
+ join_hash_map::JoinHashMap,
+ Idx, JoinParams,
+ },
+};
+
+#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)]
+pub struct JoinerParams {
+ probe_side: ProbeSide,
+}
+
+impl JoinerParams {
+ const fn new(probe_side: ProbeSide) -> Self {
+ Self { probe_side }
+ }
+}
+
+const LEFT_PROBED: JoinerParams = JoinerParams::new(L);
+const RIGHT_PROBED: JoinerParams = JoinerParams::new(R);
+
+pub type LeftProbedExistenceJoiner = ExistenceJoiner<LEFT_PROBED>;
+pub type RightProbedExistenceJoiner = ExistenceJoiner<RIGHT_PROBED>;
+
+pub struct ExistenceJoiner<const P: JoinerParams> {
+ join_params: JoinParams,
+ output_sender: Arc<WrappedRecordBatchSender>,
+ map: JoinHashMap,
+ map_joined: Vec<BitVec>,
+ send_output_time: Time,
+ output_rows: AtomicUsize,
+}
+
+impl<const P: JoinerParams> ExistenceJoiner<P> {
+ pub fn new(
+ join_params: JoinParams,
+ map: JoinHashMap,
+ output_sender: Arc<WrappedRecordBatchSender>,
+ ) -> Self {
+ let map_joined = map
+ .data_batches()
+ .iter()
+ .map(|batch| bitvec![0; batch.num_rows()])
+ .collect();
+
+ Self {
+ join_params,
+ output_sender,
+ map,
+ map_joined,
+ send_output_time: Time::new(),
+ output_rows: AtomicUsize::new(0),
+ }
+ }
+
+ async fn flush(
+ &self,
+ probed_batch: RecordBatch,
+ build_indices: Vec<Idx>,
+ exists: Vec<bool>,
+ ) -> Result<()> {
+ let cols = match P.probe_side {
+ L => probed_batch,
+ R => interleave_batches(
+ self.map.data_schema(),
+ self.map.data_batches(),
+ &build_indices,
+ )?,
+ };
+ let exists_col = Arc::new(BooleanArray::from(exists));
+
+ let output_batch = RecordBatch::try_new(
+ self.join_params.output_schema.clone(),
+ [cols.columns().to_vec(), vec![exists_col]].concat(),
+ )?;
+ self.output_rows.fetch_add(output_batch.num_rows(), Relaxed);
+
+ let timer = self.send_output_time.timer();
+ self.output_sender.send(Ok(output_batch), None).await;
+ drop(timer);
+ Ok(())
+ }
+}
+
+#[async_trait]
+impl<const P: JoinerParams> Joiner for ExistenceJoiner<P> {
+ async fn join(
+ mut self: Pin<&mut Self>,
+ probed_batch: RecordBatch,
+ probed_key: Rows,
+ probed_key_has_null: Option<NullBuffer>,
+ ) -> Result<()> {
+ let mut exists = vec![];
+
+ for (row_idx, key) in probed_key.iter().enumerate() {
+ if !probed_key_has_null
+ .as_ref()
+ .map(|nb| nb.is_null(row_idx))
+ .unwrap_or(false)
+ {
+ match P.probe_side {
+ L => {
+ exists.push(self.map.search(key).next().is_some());
+ if exists.len() >= self.join_params.batch_size {
+ let exists = std::mem::take(&mut exists);
+ self.as_mut()
+ .flush(probed_batch.clone(), vec![], exists)
+ .await?;
+ }
+ }
+ R => {
+ for idx in self.map.search(key) {
+ // safety: bypass mutability checker
+ let map_joined = unsafe {
+ &mut *(&self.map_joined[idx.0] as *const BitVec as *mut BitVec)
+ };
+ map_joined.set(idx.1, true);
+ }
+ }
+ }
+ }
+ }
+ if !exists.is_empty() {
+ self.flush(probed_batch.clone(), vec![], exists).await?;
+ }
+ Ok(())
+ }
+
+ async fn finish(mut self: Pin<&mut Self>) -> Result<()> {
+ if P.probe_side == R {
+ let probed_empty_batch = RecordBatch::new_empty(Arc::new(Schema::empty()));
+ let map_joined = std::mem::take(&mut self.map_joined);
+
+ for (batch_idx, batch_joined) in map_joined.into_iter().enumerate() {
+ let mut build_indices: Vec<Idx> = Vec::with_capacity(batch_joined.len());
+ let mut exists = vec![];
+
+ for (row_idx, joined) in batch_joined.into_iter().enumerate() {
+ if self.map.inserted(batch_idx, row_idx) {
+ build_indices.push((batch_idx, row_idx));
+ exists.push(joined);
+ }
+ }
+ if !build_indices.is_empty() {
+ self.as_mut()
+ .flush(probed_empty_batch.clone(), build_indices, exists)
+ .await?;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn total_send_output_time(&self) -> usize {
+ self.send_output_time.value()
+ }
+
+ fn num_output_rows(&self) -> usize {
+ self.output_rows.load(Relaxed)
+ }
+}
diff --git a/native-engine/datafusion-ext-plans/src/joins/bhj/full_join.rs b/native-engine/datafusion-ext-plans/src/joins/bhj/full_join.rs
new file mode 100644
index 0000000..a0fedce
--- /dev/null
+++ b/native-engine/datafusion-ext-plans/src/joins/bhj/full_join.rs
@@ -0,0 +1,233 @@
+// Copyright 2022 The Blaze Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::{
+ pin::Pin,
+ sync::{
+ atomic::{AtomicUsize, Ordering::Relaxed},
+ Arc,
+ },
+};
+
+use arrow::{array::RecordBatch, buffer::NullBuffer, datatypes::Schema, row::Rows};
+use async_trait::async_trait;
+use bitvec::{bitvec, prelude::BitVec};
+use datafusion::{common::Result, physical_plan::metrics::Time};
+
+use crate::{
+ broadcast_join_exec::Joiner,
+ common::{
+ batch_selection::{interleave_batches, take_batch_opt},
+ output::WrappedRecordBatchSender,
+ },
+ joins::{
+ bhj::{
+ full_join::ProbeSide::{L, R},
+ ProbeSide,
+ },
+ join_hash_map::JoinHashMap,
+ Idx, JoinParams,
+ },
+};
+
+#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)]
+pub struct JoinerParams {
+ probe_side: ProbeSide,
+ probe_side_outer: bool,
+ build_side_outer: bool,
+}
+
+impl JoinerParams {
+ const fn new(probe_side: ProbeSide, probe_side_outer: bool, build_side_outer: bool) -> Self {
+ Self {
+ probe_side,
+ probe_side_outer,
+ build_side_outer,
+ }
+ }
+}
+
+const LEFT_PROBED_INNER: JoinerParams = JoinerParams::new(L, false, false);
+const LEFT_PROBED_LEFT: JoinerParams = JoinerParams::new(L, true, false);
+const LEFT_PROBED_RIGHT: JoinerParams = JoinerParams::new(L, false, true);
+const LEFT_PROBED_OUTER: JoinerParams = JoinerParams::new(L, true, true);
+
+const RIGHT_PROBED_INNER: JoinerParams = JoinerParams::new(R, false, false);
+const RIGHT_PROBED_LEFT: JoinerParams = JoinerParams::new(R, false, true);
+const RIGHT_PROBED_RIGHT: JoinerParams = JoinerParams::new(R, true, false);
+const RIGHT_PROBED_OUTER: JoinerParams = JoinerParams::new(R, true, true);
+
+pub type LeftProbedInnerJoiner = FullJoiner<LEFT_PROBED_INNER>;
+pub type LeftProbedLeftJoiner = FullJoiner<LEFT_PROBED_LEFT>;
+pub type LeftProbedRightJoiner = FullJoiner<LEFT_PROBED_RIGHT>;
+pub type LeftProbedFullOuterJoiner = FullJoiner<LEFT_PROBED_OUTER>;
+pub type RightProbedInnerJoiner = FullJoiner<RIGHT_PROBED_INNER>;
+pub type RightProbedLeftJoiner = FullJoiner<RIGHT_PROBED_LEFT>;
+pub type RightProbedRightJoiner = FullJoiner<RIGHT_PROBED_RIGHT>;
+pub type RightProbedFullOuterJoiner = FullJoiner<RIGHT_PROBED_OUTER>;
+
+pub struct FullJoiner<const P: JoinerParams> {
+ join_params: JoinParams,
+ output_sender: Arc<WrappedRecordBatchSender>,
+ map: JoinHashMap,
+ map_joined: Vec<BitVec>,
+ send_output_time: Time,
+ output_rows: AtomicUsize,
+}
+
+impl<const P: JoinerParams> FullJoiner<P> {
+ pub fn new(
+ join_params: JoinParams,
+ map: JoinHashMap,
+ output_sender: Arc<WrappedRecordBatchSender>,
+ ) -> Self {
+ let map_joined = map
+ .data_batches()
+ .iter()
+ .map(|batch| bitvec![0; batch.num_rows()])
+ .collect();
+
+ Self {
+ join_params,
+ output_sender,
+ map,
+ map_joined,
+ send_output_time: Time::default(),
+ output_rows: AtomicUsize::new(0),
+ }
+ }
+
+ async fn flush(
+ &self,
+ probed_batch: RecordBatch,
+ probe_indices: Vec<Option<u32>>,
+ build_indices: Vec<Idx>,
+ ) -> Result<()> {
+ let pcols = take_batch_opt(probed_batch.clone(), probe_indices)?;
+ let bcols = interleave_batches(
+ self.map.data_schema(),
+ self.map.data_batches(),
+ &build_indices,
+ )?;
+ let output_batch = RecordBatch::try_new(
+ self.join_params.output_schema.clone(),
+ match P.probe_side {
+ L => [pcols.columns().to_vec(), bcols.columns().to_vec()].concat(),
+ R => [bcols.columns().to_vec(), pcols.columns().to_vec()].concat(),
+ },
+ )?;
+ self.output_rows.fetch_add(output_batch.num_rows(), Relaxed);
+
+ let timer = self.send_output_time.timer();
+ self.output_sender.send(Ok(output_batch), None).await;
+ drop(timer);
+ Ok(())
+ }
+}
+
+#[async_trait]
+impl<const P: JoinerParams> Joiner for FullJoiner<P> {
+ async fn join(
+ mut self: Pin<&mut Self>,
+ probed_batch: RecordBatch,
+ probed_key: Rows,
+ probed_key_has_null: Option<NullBuffer>,
+ ) -> Result<()> {
+ let mut probe_indices: Vec<Option<u32>> = vec![];
+ let mut build_indices: Vec<Idx> = vec![];
+
+ for (row_idx, row) in probed_key.iter().enumerate() {
+ let mut found = false;
+ if !probed_key_has_null
+ .as_ref()
+ .map(|nb| nb.is_null(row_idx))
+ .unwrap_or(false)
+ {
+ for idx in self.map.search(&row) {
+ found = true;
+ probe_indices.push(Some(row_idx as u32));
+ build_indices.push(idx);
+
+ if P.build_side_outer {
+ // safety: bypass mutability checker
+ let map_joined = unsafe {
+ &mut *(&self.map_joined[idx.0] as *const BitVec as *mut BitVec)
+ };
+ map_joined.set(idx.1, true);
+ }
+ }
+ }
+ if P.probe_side_outer && !found {
+ probe_indices.push(Some(row_idx as u32));
+ build_indices.push(Idx::default());
+ }
+
+ if probe_indices.len() >= self.join_params.batch_size {
+ let probe_indices = std::mem::take(&mut probe_indices);
+ let build_indices = std::mem::take(&mut build_indices);
+ self.as_mut()
+ .flush(probed_batch.clone(), probe_indices, build_indices)
+ .await?;
+ }
+ }
+ if !probe_indices.is_empty() {
+ self.flush(probed_batch.clone(), probe_indices, build_indices)
+ .await?;
+ }
+ Ok(())
+ }
+
+ async fn finish(mut self: Pin<&mut Self>) -> Result<()> {
+ if P.build_side_outer {
+ let probed_schema = match P.probe_side {
+ L => self.join_params.left_schema.clone(),
+ R => self.join_params.right_schema.clone(),
+ };
+ let probed_empty_batch = RecordBatch::new_empty(Arc::new(Schema::new(
+ probed_schema
+ .fields()
+ .iter()
+ .map(|field| field.as_ref().clone().with_nullable(true))
+ .collect::<Vec<_>>(),
+ )));
+ let map_joined = std::mem::take(&mut self.map_joined);
+
+ for (batch_idx, batch_joined) in map_joined.into_iter().enumerate() {
+ let mut probe_indices: Vec<Option<u32>> = Vec::with_capacity(batch_joined.len());
+ let mut build_indices: Vec<Idx> = Vec::with_capacity(batch_joined.len());
+
+ for (row_idx, joined) in batch_joined.into_iter().enumerate() {
+ if !joined && self.map.inserted(batch_idx, row_idx) {
+ probe_indices.push(None);
+ build_indices.push((batch_idx, row_idx));
+ }
+ }
+ if !probe_indices.is_empty() {
+ self.as_mut()
+ .flush(probed_empty_batch.clone(), probe_indices, build_indices)
+ .await?;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn total_send_output_time(&self) -> usize {
+ self.send_output_time.value()
+ }
+
+ fn num_output_rows(&self) -> usize {
+ self.output_rows.load(Relaxed)
+ }
+}
diff --git a/native-engine/datafusion-ext-plans/src/smj/mod.rs b/native-engine/datafusion-ext-plans/src/joins/bhj/mod.rs
similarity index 86%
copy from native-engine/datafusion-ext-plans/src/smj/mod.rs
copy to native-engine/datafusion-ext-plans/src/joins/bhj/mod.rs
index 6ee6586..2adcd64 100644
--- a/native-engine/datafusion-ext-plans/src/smj/mod.rs
+++ b/native-engine/datafusion-ext-plans/src/joins/bhj/mod.rs
@@ -15,4 +15,9 @@
pub mod existence_join;
pub mod full_join;
pub mod semi_join;
-pub mod stream_cursor;
+
+#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)]
+pub enum ProbeSide {
+ L,
+ R,
+}
diff --git a/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs b/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs
new file mode 100644
index 0000000..764f4b1
--- /dev/null
+++ b/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs
@@ -0,0 +1,230 @@
+// Copyright 2022 The Blaze Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::{
+ pin::Pin,
+ sync::{
+ atomic::{AtomicUsize, Ordering::Relaxed},
+ Arc,
+ },
+};
+
+use arrow::{
+ array::{RecordBatch, RecordBatchOptions},
+ buffer::NullBuffer,
+ datatypes::Schema,
+ row::Rows,
+};
+use async_trait::async_trait;
+use bitvec::{bitvec, prelude::BitVec};
+use datafusion::{common::Result, physical_plan::metrics::Time};
+
+use crate::{
+ broadcast_join_exec::Joiner,
+ common::{
+ batch_selection::{interleave_batches, take_batch},
+ output::WrappedRecordBatchSender,
+ },
+ joins::{
+ bhj::{
+ semi_join::ProbeSide::{L, R},
+ ProbeSide,
+ },
+ join_hash_map::JoinHashMap,
+ Idx, JoinParams,
+ },
+};
+
+#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)]
+pub struct JoinerParams {
+ probe_side: ProbeSide,
+ probe_is_join_side: bool,
+ semi: bool,
+}
+
+impl JoinerParams {
+ const fn new(probe_side: ProbeSide, probe_is_join_side: bool, semi: bool) -> Self {
+ Self {
+ probe_side,
+ probe_is_join_side,
+ semi,
+ }
+ }
+}
+
+const LEFT_PROBED_LEFT_SEMI: JoinerParams = JoinerParams::new(L, true, true);
+const LEFT_PROBED_LEFT_ANTI: JoinerParams = JoinerParams::new(L, true, false);
+const LEFT_PROBED_RIGHT_SEMI: JoinerParams = JoinerParams::new(L, false, true);
+const LEFT_PROBED_RIGHT_ANTI: JoinerParams = JoinerParams::new(L, false, false);
+const RIGHT_PROBED_LEFT_SEMI: JoinerParams = JoinerParams::new(R, false, true);
+const RIGHT_PROBED_LEFT_ANTI: JoinerParams = JoinerParams::new(R, false, false);
+const RIGHT_PROBED_RIGHT_SEMI: JoinerParams = JoinerParams::new(R, true, true);
+const RIGHT_PROBED_RIGHT_ANTI: JoinerParams = JoinerParams::new(R, true, false);
+
+pub type LeftProbedLeftSemiJoiner = SemiJoiner<LEFT_PROBED_LEFT_SEMI>;
+pub type LeftProbedLeftAntiJoiner = SemiJoiner<LEFT_PROBED_LEFT_ANTI>;
+pub type LeftProbedRightSemiJoiner = SemiJoiner<LEFT_PROBED_RIGHT_SEMI>;
+pub type LeftProbedRightAntiJoiner = SemiJoiner<LEFT_PROBED_RIGHT_ANTI>;
+pub type RightProbedLeftSemiJoiner = SemiJoiner<RIGHT_PROBED_LEFT_SEMI>;
+pub type RightProbedLeftAntiJoiner = SemiJoiner<RIGHT_PROBED_LEFT_ANTI>;
+pub type RightProbedRightSemiJoiner = SemiJoiner<RIGHT_PROBED_RIGHT_SEMI>;
+pub type RightProbedRightAntiJoiner = SemiJoiner<RIGHT_PROBED_RIGHT_ANTI>;
+
+pub struct SemiJoiner<const P: JoinerParams> {
+ join_params: JoinParams,
+ output_sender: Arc<WrappedRecordBatchSender>,
+ map_joined: Vec<BitVec>,
+ map: JoinHashMap,
+ send_output_time: Time,
+ output_rows: AtomicUsize,
+}
+
+impl<const P: JoinerParams> SemiJoiner<P> {
+ pub fn new(
+ join_params: JoinParams,
+ map: JoinHashMap,
+ output_sender: Arc<WrappedRecordBatchSender>,
+ ) -> Self {
+ let map_joined = map
+ .data_batches()
+ .iter()
+ .map(|batch| bitvec![0; batch.num_rows()])
+ .collect();
+
+ Self {
+ join_params,
+ output_sender,
+ map,
+ map_joined,
+ send_output_time: Time::new(),
+ output_rows: AtomicUsize::new(0),
+ }
+ }
+
+ async fn flush(
+ &self,
+ probed_batch: RecordBatch,
+ probe_indices: Vec<u32>,
+ build_indices: Vec<Idx>,
+ ) -> Result<()> {
+ let num_rows;
+ let cols = if P.probe_is_join_side {
+ num_rows = probe_indices.len();
+ take_batch(probed_batch.clone(), probe_indices)?
+ } else {
+ num_rows = build_indices.len();
+ interleave_batches(
+ self.map.data_schema(),
+ self.map.data_batches(),
+ &build_indices,
+ )?
+ };
+ let output_batch = RecordBatch::try_new_with_options(
+ self.join_params.output_schema.clone(),
+ cols.columns().to_vec(),
+ &RecordBatchOptions::new().with_row_count(Some(num_rows)),
+ )?;
+ self.output_rows.fetch_add(output_batch.num_rows(), Relaxed);
+
+ let timer = self.send_output_time.timer();
+ self.output_sender.send(Ok(output_batch), None).await;
+ drop(timer);
+ Ok(())
+ }
+}
+
+#[async_trait]
+impl<const P: JoinerParams> Joiner for SemiJoiner<P> {
+ async fn join(
+ mut self: Pin<&mut Self>,
+ probed_batch: RecordBatch,
+ probed_key: Rows,
+ probed_key_has_null: Option<NullBuffer>,
+ ) -> Result<()> {
+ let mut probe_indices: Vec<u32> = vec![];
+
+ for (row_idx, row) in probed_key.iter().enumerate() {
+ let mut found = false;
+ if !probed_key_has_null
+ .as_ref()
+ .map(|nb| nb.is_null(row_idx))
+ .unwrap_or(false)
+ {
+ for idx in self.map.search(&row) {
+ found = true;
+ if P.probe_is_join_side {
+ if P.semi {
+ probe_indices.push(row_idx as u32);
+ }
+ break;
+ } else {
+ // safety: bypass mutability checker
+ let map_joined = unsafe {
+ &mut *(&self.map_joined[idx.0] as *const BitVec as *mut BitVec)
+ };
+ map_joined.set(idx.1, true);
+ }
+ }
+ }
+ if P.probe_is_join_side && !P.semi && !found {
+ probe_indices.push(row_idx as u32);
+ }
+
+ if probe_indices.len() >= self.join_params.batch_size {
+ let probe_indices = std::mem::take(&mut probe_indices);
+ self.as_mut()
+ .flush(probed_batch.clone(), probe_indices, vec![])
+ .await?;
+ }
+ }
+ if !probe_indices.is_empty() {
+ self.flush(probed_batch.clone(), probe_indices, vec![])
+ .await?;
+ }
+ Ok(())
+ }
+
+ async fn finish(mut self: Pin<&mut Self>) -> Result<()> {
+ if !P.probe_is_join_side {
+ let probed_empty_batch = RecordBatch::new_empty(Arc::new(Schema::empty()));
+ let map_joined = std::mem::take(&mut self.map_joined);
+
+ for (batch_idx, batch_joined) in map_joined.into_iter().enumerate() {
+ let mut build_indices: Vec<Idx> = Vec::with_capacity(batch_joined.len());
+
+ for (row_idx, joined) in batch_joined.into_iter().enumerate() {
+ if P.semi && joined
+ || !P.semi && !joined && self.map.inserted(batch_idx, row_idx)
+ {
+ build_indices.push((batch_idx, row_idx));
+ }
+ }
+ if !build_indices.is_empty() {
+ self.as_mut()
+ .flush(probed_empty_batch.clone(), vec![], build_indices)
+ .await?;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn total_send_output_time(&self) -> usize {
+ self.send_output_time.value()
+ }
+
+ fn num_output_rows(&self) -> usize {
+ self.output_rows.load(Relaxed)
+ }
+}
diff --git a/native-engine/datafusion-ext-plans/src/joins/join_hash_map.rs b/native-engine/datafusion-ext-plans/src/joins/join_hash_map.rs
new file mode 100644
index 0000000..5657db1
--- /dev/null
+++ b/native-engine/datafusion-ext-plans/src/joins/join_hash_map.rs
@@ -0,0 +1,362 @@
+// Copyright 2022 The Blaze Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::{hash::Hasher, sync::Arc};
+
+use arrow::{
+ array::{
+ make_array, Array, ArrayRef, AsArray, BinaryArray, BinaryBuilder, MutableArrayData,
+ RecordBatch, RecordBatchOptions,
+ },
+ compute::filter_record_batch,
+ datatypes::{DataType, Field, FieldRef, Schema, SchemaRef},
+ row::Rows,
+};
+use datafusion::common::Result;
+use datafusion_ext_commons::{batch_size, df_execution_err};
+use gxhash::GxHasher;
+use itertools::Itertools;
+use num::Integer;
+use once_cell::sync::OnceCell;
+
+use crate::joins::Idx;
+
+pub struct JoinHashMap {
+ batches: Vec<RecordBatch>, // data batch + key column
+ keys: Vec<BinaryArray>,
+ data_batches: Vec<RecordBatch>, // batches excluding the last key column
+ batch_size: usize, // all batches except the last one must have the same size
+ num_entries: usize,
+}
+
+impl JoinHashMap {
+ pub fn try_new(hash_map_batches: Vec<RecordBatch>) -> Result<Self> {
+ if hash_map_batches.is_empty() {
+ return df_execution_err!("JoinHashMap should have at least one entry");
+ }
+ let batch_size = hash_map_batches[0].num_rows();
+
+ if hash_map_batches[..hash_map_batches.len() - 1]
+ .iter()
+ .any(|batch| batch.num_rows() != batch_size)
+ {
+            return df_execution_err!("JoinHashMap expects batch size {batch_size}");
+ }
+ Ok(Self {
+ num_entries: hash_map_batches.iter().map(|batch| batch.num_rows()).sum(),
+ keys: join_keys_from_hash_map_batches(&hash_map_batches),
+ data_batches: hash_map_batches
+ .iter()
+ .cloned()
+ .map(|mut batch| {
+ batch.remove_column(batch.num_columns() - 1);
+ batch
+ })
+ .collect(),
+ batches: hash_map_batches,
+ batch_size,
+ })
+ }
+
+ pub fn data_schema(&self) -> SchemaRef {
+ self.data_batches()[0].schema()
+ }
+
+ pub fn data_batches(&self) -> &[RecordBatch] {
+ &self.data_batches
+ }
+
+ pub fn hash_map_batches(&self) -> &[RecordBatch] {
+ &self.batches
+ }
+
+ pub fn into_hash_map_batches(self) -> Vec<RecordBatch> {
+ self.batches
+ }
+
+ pub fn inserted(&self, batch_idx: usize, row_idx: usize) -> bool {
+ self.keys[batch_idx].is_valid(row_idx)
+ }
+
+ pub fn search<'a, K: AsRef<[u8]> + 'a>(&'a self, key: K) -> impl Iterator<Item = Idx> + 'a {
+ struct EntryIterator<'a, K: AsRef<[u8]> + 'a> {
+ key: K,
+ join_hash_map_keys: &'a [BinaryArray],
+ batch_size: usize,
+ num_entries: usize,
+ entry: usize,
+ }
+ impl<'a, K: AsRef<[u8]> + 'a> Iterator for EntryIterator<'a, K> {
+ type Item = Idx;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ loop {
+ // get current idx and advance to next entry
+ let idx = self.entry.div_rem(&self.batch_size);
+ self.entry = 1 + ((self.entry + 1) % (self.num_entries - 1));
+
+ // check current entry
+ let keys = &self.join_hash_map_keys[idx.0];
+ if keys.is_null(idx.1) {
+ return None;
+ }
+ if keys.value(idx.1) == self.key.as_ref() {
+ return Some(idx);
+ }
+ }
+ }
+ }
+
+ let entry = 1 + ((join_hash(key.as_ref()) as usize) % (self.num_entries - 1));
+ EntryIterator {
+ key,
+ join_hash_map_keys: &self.keys,
+ batch_size: self.batch_size,
+ num_entries: self.num_entries,
+ entry,
+ }
+ }
+
+ pub fn iter_values(&self) -> impl Iterator<Item = RecordBatch> {
+ self.batches
+ .clone()
+ .into_iter()
+ .filter_map(|mut batch| {
+ let keys = batch.remove_column(batch.num_columns() - 1);
+ let filtered = arrow::compute::is_not_null(&keys)
+ .and_then(|valid_array| Ok(filter_record_batch(&batch, &valid_array)?))
+ .expect("error filtering record batch");
+ let non_empty = filtered.num_rows() > 0;
+ non_empty.then(|| filtered)
+ })
+ .filter(|batch| batch.num_rows() > 0)
+ }
+}
+
+pub fn build_join_hash_map(
+ batches: &[RecordBatch],
+ batch_schema: SchemaRef,
+ keys: &[Rows],
+) -> Result<JoinHashMap> {
+ let batch_size = batch_size();
+ let num_valid_entries = batches.iter().map(|batch| batch.num_rows()).sum::<usize>();
+ let num_entries = (num_valid_entries + 16) * 5 + 1;
+
+ // build idx_map
+ let mut idx_map = vec![None; num_entries];
+ for (batch_idx, keys) in keys.iter().enumerate() {
+ for (row_idx, key) in keys.iter().enumerate() {
+ let hash = join_hash(key) as usize;
+ let mut entry = 1 + (hash % (num_entries - 1)); // 0 is reserved
+
+ // find an empty slot, then insert the key index
+ while idx_map[entry].is_some() {
+ entry = 1 + ((entry + 1) % (num_entries - 1)); // 0 is reserved
+ }
+ idx_map[entry] = Some((batch_idx, row_idx));
+ }
+ }
+
+ // build hash map batches
+ let hash_map_schema = join_hash_map_schema(&batch_schema);
+ let mut hash_map_batches = vec![];
+ let mut batch_array_datas = vec![vec![]; batch_schema.fields().len()];
+ for batch in batches {
+ for (col_idx, col) in batch.columns().iter().enumerate() {
+ batch_array_datas[col_idx].push(col.to_data());
+ }
+ }
+ for batch_entries in idx_map.iter().chunks(batch_size).into_iter() {
+ let mut keys_builder = BinaryBuilder::with_capacity(batch_size, 0);
+ let mut data_cols_builder: Vec<MutableArrayData> = batch_array_datas
+ .iter()
+ .map(|col_array_datas| {
+ MutableArrayData::new(col_array_datas.iter().collect(), true, batch_size)
+ })
+ .collect();
+
+ for entry in batch_entries {
+ if let &Some((batch_idx, row_idx)) = entry {
+ keys_builder.append_value(keys[batch_idx].row(row_idx).as_ref());
+ for col_builder in data_cols_builder.iter_mut() {
+ col_builder.extend(batch_idx, row_idx, row_idx + 1);
+ }
+ } else {
+ keys_builder.append_null();
+ for col_builder in data_cols_builder.iter_mut() {
+ col_builder.extend_nulls(1);
+ }
+ }
+ }
+ let keys: ArrayRef = Arc::new(keys_builder.finish());
+ let data_cols: Vec<ArrayRef> = data_cols_builder
+ .into_iter()
+ .map(|col_builder| make_array(col_builder.freeze()))
+ .collect();
+ let batch_num_rows = keys.len();
+ let hash_map_batch = RecordBatch::try_new_with_options(
+ hash_map_schema.clone(),
+ [data_cols, vec![keys]].concat(),
+ &RecordBatchOptions::new().with_row_count(Some(batch_num_rows)),
+ )?;
+ hash_map_batches.push(hash_map_batch);
+ }
+ Ok(JoinHashMap {
+ keys: join_keys_from_hash_map_batches(&hash_map_batches),
+ data_batches: hash_map_batches
+ .iter()
+ .cloned()
+ .map(|mut batch| {
+ batch.remove_column(batch.num_columns() - 1);
+ batch
+ })
+ .collect(),
+ batches: hash_map_batches,
+ batch_size,
+ num_entries,
+ })
+}
+
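+/// Schema of a hash-map batch: the data columns (made nullable, since empty
+/// slots are stored as all-null rows) followed by the trailing binary "~KEY" column.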
+#[inline]
+pub fn join_hash_map_schema(data_schema: &SchemaRef) -> SchemaRef {
+ Arc::new(Schema::new(
+ data_schema
+ .fields()
+ .iter()
+ .map(|field| Arc::new(field.as_ref().clone().with_nullable(true)))
+ .chain(std::iter::once(join_key_field()))
+ .collect::<Vec<_>>(),
+ ))
+}
+
+#[inline]
+fn join_key_field() -> FieldRef {
+ static BHJ_KEY_FIELD: OnceCell<FieldRef> = OnceCell::new();
+ BHJ_KEY_FIELD
+ .get_or_init(|| Arc::new(Field::new("~KEY", DataType::Binary, true)))
+ .clone()
+}
+
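+/// Hash of a serialized join key: GxHash with a fixed seed, truncated to 32 bits.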
+#[inline]
+fn join_hash(key: impl AsRef<[u8]>) -> u32 {
+ let mut h = GxHasher::with_seed(0x10c736ed99f9c14e);
+ h.write(key.as_ref());
+ h.finish() as u32
+}
+
+#[inline]
+fn join_keys_from_hash_map_batches(hash_map_batches: &[RecordBatch]) -> Vec<BinaryArray> {
+ hash_map_batches
+ .iter()
+ .map(|batch| {
+ batch
+ .columns()
+ .last()
+ .expect("hashmap key column not found")
+ .as_binary()
+ .clone()
+ })
+ .collect()
+}
+
+#[cfg(test)]
+mod tests {
+ use std::sync::Arc;
+
+ use arrow::{
+ array::{AsArray, Int32Array, RecordBatch, StringArray},
+ datatypes::{DataType, Field, Int32Type, Schema},
+ row::{RowConverter, SortField},
+ };
+ use datafusion::{assert_batches_sorted_eq, common::Result};
+
+ #[test]
+ fn test_join_hash_map() -> Result<()> {
+ // generate a string key-value record batch for testing
+ let batch_schema = Arc::new(Schema::new(vec![
+ Field::new("key", DataType::Utf8, false),
+ Field::new("value", DataType::Int32, false),
+ ]));
+ let batch = RecordBatch::try_new(
+ batch_schema.clone(),
+ vec![
+ Arc::new(StringArray::from(vec![
+ "a0", "a111", "a222", "a333", "a333", "a444",
+ ])),
+ Arc::new(Int32Array::from(vec![0, 111, 222, 3331, 3332, 444])),
+ ],
+ )?;
+
+ // generate hashmap
+ let row_converter = RowConverter::new(vec![SortField::new(DataType::Utf8)])?;
+ let keys = row_converter.convert_columns(&[batch.column(0).clone()])?;
+ let key333 = keys.row(3).owned();
+ let hashmap = super::build_join_hash_map(&[batch], batch_schema, &[keys])?;
+
+ // test searching a333
+ let mut iter = hashmap.search(key333);
+ let idx = iter.next().unwrap();
+ let hash_map_batch = &hashmap.hash_map_batches()[idx.0];
+ assert_eq!(
+ hash_map_batch.column(0).as_string::<i32>().value(idx.1),
+ "a333"
+ );
+ assert_eq!(
+ hash_map_batch
+ .column(1)
+ .as_primitive::<Int32Type>()
+ .value(idx.1),
+ 3331
+ );
+ let idx = iter.next().unwrap();
+ let hash_map_batch = &hashmap.hash_map_batches()[idx.0];
+ assert_eq!(
+ hash_map_batch.column(0).as_string::<i32>().value(idx.1),
+ "a333"
+ );
+ assert_eq!(
+ hash_map_batch
+ .column(1)
+ .as_primitive::<Int32Type>()
+ .value(idx.1),
+ 3332
+ );
+ let idx = iter.next();
+ assert_eq!(idx, None);
+
+ // test searching a nonexistent key
+ let mut iter = hashmap.search(b"nonexistent");
+ let idx = iter.next();
+ assert_eq!(idx, None);
+
+ // test iter values
+ let value_batch = hashmap.iter_values().next().unwrap();
+ assert_batches_sorted_eq!(
+ vec![
+ "+------+-------+",
+ "| key | value |",
+ "+------+-------+",
+ "| a0 | 0 |",
+ "| a111 | 111 |",
+ "| a222 | 222 |",
+ "| a333 | 3331 |",
+ "| a333 | 3332 |",
+ "| a444 | 444 |",
+ "+------+-------+",
+ ],
+ &[value_batch]
+ );
+ Ok(())
+ }
+}
diff --git a/native-engine/datafusion-ext-plans/src/common/join_utils.rs b/native-engine/datafusion-ext-plans/src/joins/join_utils.rs
similarity index 100%
rename from native-engine/datafusion-ext-plans/src/common/join_utils.rs
rename to native-engine/datafusion-ext-plans/src/joins/join_utils.rs
diff --git a/native-engine/datafusion-ext-plans/src/joins/mod.rs b/native-engine/datafusion-ext-plans/src/joins/mod.rs
new file mode 100644
index 0000000..243fc60
--- /dev/null
+++ b/native-engine/datafusion-ext-plans/src/joins/mod.rs
@@ -0,0 +1,46 @@
+// Copyright 2022 The Blaze Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use arrow::{
+ compute::SortOptions,
+ datatypes::{DataType, SchemaRef},
+};
+use datafusion::physical_expr::PhysicalExprRef;
+
+use crate::joins::{join_utils::JoinType, stream_cursor::StreamCursor};
+
+pub mod join_hash_map;
+pub mod join_utils;
+pub mod stream_cursor;
+
+// join implementations
+pub mod bhj;
+pub mod smj;
+mod test;
+
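+/// Common parameters shared by the join implementations in this module.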
+#[derive(Debug, Clone)]
+pub struct JoinParams {
+ pub join_type: JoinType,
+ pub left_schema: SchemaRef,
+ pub right_schema: SchemaRef,
+ pub output_schema: SchemaRef,
+ pub left_keys: Vec<PhysicalExprRef>,
+ pub right_keys: Vec<PhysicalExprRef>,
+ pub key_data_types: Vec<DataType>,
+ pub sort_options: Vec<SortOptions>,
+ pub batch_size: usize,
+}
+
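+/// (batch index, row index within that batch)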
+pub type Idx = (usize, usize);
+pub type StreamCursors = (StreamCursor, StreamCursor);
diff --git a/native-engine/datafusion-ext-plans/src/smj/existence_join.rs b/native-engine/datafusion-ext-plans/src/joins/smj/existence_join.rs
similarity index 95%
rename from native-engine/datafusion-ext-plans/src/smj/existence_join.rs
rename to native-engine/datafusion-ext-plans/src/joins/smj/existence_join.rs
index 405ebe9..5a4324e 100644
--- a/native-engine/datafusion-ext-plans/src/smj/existence_join.rs
+++ b/native-engine/datafusion-ext-plans/src/joins/smj/existence_join.rs
@@ -22,7 +22,8 @@
use crate::{
common::{batch_selection::interleave_batches, output::WrappedRecordBatchSender},
compare_cursor, cur_forward,
- sort_merge_join_exec::{Idx, JoinParams, Joiner, StreamCursors},
+ joins::{Idx, JoinParams, StreamCursors},
+ sort_merge_join_exec::Joiner,
};
pub struct ExistenceJoiner {
@@ -31,6 +32,7 @@
indices: Vec<Idx>,
exists: Vec<bool>,
send_output_time: Time,
+ output_rows: usize,
}
impl ExistenceJoiner {
@@ -41,6 +43,7 @@
indices: vec![],
exists: vec![],
send_output_time: Time::new(),
+ output_rows: 0,
}
}
@@ -76,6 +79,8 @@
)?;
if output_batch.num_rows() > 0 {
+ self.output_rows += output_batch.num_rows();
+
let timer = self.send_output_time.timer();
self.output_sender.send(Ok(output_batch), None).await;
drop(timer);
@@ -167,4 +172,8 @@
fn total_send_output_time(&self) -> usize {
self.send_output_time.value()
}
+
+ fn num_output_rows(&self) -> usize {
+ self.output_rows
+ }
}
diff --git a/native-engine/datafusion-ext-plans/src/smj/full_join.rs b/native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs
similarity index 96%
rename from native-engine/datafusion-ext-plans/src/smj/full_join.rs
rename to native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs
index de359a9..7baafd0 100644
--- a/native-engine/datafusion-ext-plans/src/smj/full_join.rs
+++ b/native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs
@@ -23,7 +23,8 @@
use crate::{
common::{batch_selection::interleave_batches, output::WrappedRecordBatchSender},
compare_cursor, cur_forward,
- sort_merge_join_exec::{Idx, JoinParams, Joiner, StreamCursors},
+ joins::{Idx, JoinParams, StreamCursors},
+ sort_merge_join_exec::Joiner,
};
pub struct FullJoiner<const L_OUTER: bool, const R_OUTER: bool> {
@@ -32,6 +33,7 @@
lindices: Vec<Idx>,
rindices: Vec<Idx>,
send_output_time: Time,
+ output_rows: usize,
}
pub type InnerJoiner = FullJoiner<false, false>;
@@ -47,6 +49,7 @@
lindices: vec![],
rindices: vec![],
send_output_time: Time::new(),
+ output_rows: 0,
}
}
@@ -87,6 +90,8 @@
)?;
if output_batch.num_rows() > 0 {
+ self.output_rows += output_batch.num_rows();
+
let timer = self.send_output_time.timer();
self.output_sender.send(Ok(output_batch), None).await;
drop(timer);
@@ -225,4 +230,8 @@
fn total_send_output_time(&self) -> usize {
self.send_output_time.value()
}
+
+ fn num_output_rows(&self) -> usize {
+ self.output_rows
+ }
}
diff --git a/native-engine/datafusion-ext-plans/src/smj/mod.rs b/native-engine/datafusion-ext-plans/src/joins/smj/mod.rs
similarity index 96%
rename from native-engine/datafusion-ext-plans/src/smj/mod.rs
rename to native-engine/datafusion-ext-plans/src/joins/smj/mod.rs
index 6ee6586..8bcdadf 100644
--- a/native-engine/datafusion-ext-plans/src/smj/mod.rs
+++ b/native-engine/datafusion-ext-plans/src/joins/smj/mod.rs
@@ -15,4 +15,3 @@
pub mod existence_join;
pub mod full_join;
pub mod semi_join;
-pub mod stream_cursor;
diff --git a/native-engine/datafusion-ext-plans/src/smj/semi_join.rs b/native-engine/datafusion-ext-plans/src/joins/smj/semi_join.rs
similarity index 66%
rename from native-engine/datafusion-ext-plans/src/smj/semi_join.rs
rename to native-engine/datafusion-ext-plans/src/joins/smj/semi_join.rs
index 195693f..fbd60cc 100644
--- a/native-engine/datafusion-ext-plans/src/smj/semi_join.rs
+++ b/native-engine/datafusion-ext-plans/src/joins/smj/semi_join.rs
@@ -22,35 +22,56 @@
use crate::{
common::{batch_selection::interleave_batches, output::WrappedRecordBatchSender},
compare_cursor, cur_forward,
- sort_merge_join_exec::{Idx, JoinParams, Joiner, StreamCursors},
+ joins::{
+ smj::semi_join::SemiJoinSide::{L, R},
+ Idx, JoinParams, StreamCursors,
+ },
+ sort_merge_join_exec::Joiner,
};
-pub struct SemiJoiner<
- const L_SEMI: bool,
- const R_SEMI: bool,
- const L_ANTI: bool,
- const R_ANTI: bool,
-> {
+#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)]
+pub enum SemiJoinSide {
+ L,
+ R,
+}
+
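+/// Const-generic parameters of `SemiJoiner` (enabled by the `adt_const_params`
+/// feature), replacing the former four bool const parameters.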
+#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)]
+pub struct JoinerParams {
+ join_side: SemiJoinSide,
+ semi: bool,
+}
+
+impl JoinerParams {
+ const fn new(join_side: SemiJoinSide, semi: bool) -> Self {
+ Self { join_side, semi }
+ }
+}
+pub struct SemiJoiner<const P: JoinerParams> {
join_params: JoinParams,
output_sender: Arc<WrappedRecordBatchSender>,
indices: Vec<Idx>,
send_output_time: Time,
+ output_rows: usize,
}
-pub type LeftSemiJoiner = SemiJoiner<true, false, false, false>;
-pub type RightSemiJoiner = SemiJoiner<false, true, false, false>;
-pub type LeftAntiJoiner = SemiJoiner<false, false, true, false>;
-pub type RightAntiJoiner = SemiJoiner<false, false, false, true>;
+const LEFT_SEMI: JoinerParams = JoinerParams::new(L, true);
+const LEFT_ANTI: JoinerParams = JoinerParams::new(L, false);
+const RIGHT_SEMI: JoinerParams = JoinerParams::new(R, true);
+const RIGHT_ANTI: JoinerParams = JoinerParams::new(R, false);
-impl<const L_SEMI: bool, const R_SEMI: bool, const L_ANTI: bool, const R_ANTI: bool>
- SemiJoiner<L_SEMI, R_SEMI, L_ANTI, R_ANTI>
-{
+pub type LeftSemiJoiner = SemiJoiner<LEFT_SEMI>;
+pub type LeftAntiJoiner = SemiJoiner<LEFT_ANTI>;
+pub type RightSemiJoiner = SemiJoiner<RIGHT_SEMI>;
+pub type RightAntiJoiner = SemiJoiner<RIGHT_ANTI>;
+
+impl<const P: JoinerParams> SemiJoiner<P> {
pub fn new(join_params: JoinParams, output_sender: Arc<WrappedRecordBatchSender>) -> Self {
Self {
join_params,
output_sender,
indices: vec![],
send_output_time: Time::new(),
+ output_rows: 0,
}
}
@@ -63,10 +84,9 @@
&& curs.0.mem_size() + curs.1.mem_size() > suggested_output_batch_mem_size()
{
if let Some(first_idx) = self.indices.first() {
- let cur_idx = if L_SEMI || L_ANTI {
- curs.0.cur_idx
- } else {
- curs.1.cur_idx
+ let cur_idx = match P.join_side {
+ L => curs.0.cur_idx,
+ R => curs.1.cur_idx,
};
if first_idx.0 < cur_idx.0 {
return true;
@@ -79,10 +99,9 @@
async fn flush(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
let indices = std::mem::take(&mut self.indices);
let num_rows = indices.len();
- let cols = if L_SEMI || L_ANTI {
- interleave_batches(curs.0.batch_schema.clone(), &curs.0.batches, &indices)?
- } else {
- interleave_batches(curs.1.batch_schema.clone(), &curs.1.batches, &indices)?
+ let cols = match P.join_side {
+ L => interleave_batches(curs.0.batch_schema.clone(), &curs.0.batches, &indices)?,
+ R => interleave_batches(curs.1.batch_schema.clone(), &curs.1.batches, &indices)?,
};
let output_batch = RecordBatch::try_new_with_options(
self.join_params.output_schema.clone(),
@@ -91,6 +110,8 @@
)?;
if output_batch.num_rows() > 0 {
+ self.output_rows += output_batch.num_rows();
+
let timer = self.send_output_time.timer();
self.output_sender.send(Ok(output_batch), None).await;
drop(timer);
@@ -100,37 +121,33 @@
async fn join_less(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
let lidx = curs.0.cur_idx;
- if L_ANTI {
+ if P.join_side == L && !P.semi {
self.indices.push(lidx);
}
cur_forward!(curs.0);
if self.should_flush(curs) {
self.as_mut().flush(curs).await?;
}
- if L_SEMI || L_ANTI {
- curs.0
- .set_min_reserved_idx(*self.indices.first().unwrap_or(&lidx));
- } else {
- curs.0.set_min_reserved_idx(lidx);
- }
+ curs.0.set_min_reserved_idx(match P.join_side {
+ L => *self.indices.first().unwrap_or(&lidx),
+ R => lidx,
+ });
Ok(())
}
async fn join_greater(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
let ridx = curs.1.cur_idx;
- if R_ANTI {
+ if P.join_side == R && !P.semi {
self.indices.push(ridx);
}
cur_forward!(curs.1);
if self.should_flush(curs) {
self.as_mut().flush(curs).await?;
}
- if R_SEMI || R_ANTI {
- curs.1
- .set_min_reserved_idx(*self.indices.first().unwrap_or(&ridx));
- } else {
- curs.1.set_min_reserved_idx(ridx);
- }
+ curs.1.set_min_reserved_idx(match P.join_side {
+ L => ridx,
+ R => *self.indices.first().unwrap_or(&ridx),
+ });
Ok(())
}
@@ -140,17 +157,16 @@
// output/skip left equal rows
loop {
- if L_SEMI {
+ if P.join_side == L && P.semi {
self.indices.push(lidx);
if self.should_flush(curs) {
self.as_mut().flush(curs).await?;
}
}
cur_forward!(curs.0);
- curs.0.set_min_reserved_idx(*if L_SEMI || L_ANTI {
- self.indices.first().unwrap_or(&lidx)
- } else {
- &lidx
+ curs.0.set_min_reserved_idx(match P.join_side {
+ L => *self.indices.first().unwrap_or(&lidx),
+ R => lidx,
});
if !curs.0.finished && curs.0.key(curs.0.cur_idx) == curs.0.key(lidx) {
@@ -162,17 +178,16 @@
// output/skip right equal rows
loop {
- if R_SEMI {
+ if P.join_side == R && P.semi {
self.indices.push(ridx);
if self.should_flush(curs) {
self.as_mut().flush(curs).await?;
}
}
cur_forward!(curs.1);
- curs.1.set_min_reserved_idx(*if R_SEMI || R_ANTI {
- self.indices.first().unwrap_or(&ridx)
- } else {
- &ridx
+ curs.1.set_min_reserved_idx(match P.join_side {
+ L => ridx,
+ R => *self.indices.first().unwrap_or(&ridx),
});
if !curs.1.finished && curs.1.key(curs.1.cur_idx) == curs.1.key(ridx) {
@@ -186,9 +201,7 @@
}
#[async_trait]
-impl<const L_SEMI: bool, const R_SEMI: bool, const L_ANTI: bool, const R_ANTI: bool> Joiner
- for SemiJoiner<L_SEMI, R_SEMI, L_ANTI, R_ANTI>
-{
+impl<const P: JoinerParams> Joiner for SemiJoiner<P> {
async fn join(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
while !curs.0.finished && !curs.1.finished {
match compare_cursor!(curs) {
@@ -205,11 +218,13 @@
}
// at least one side is finished, consume the other side if it is an anti side
- while L_ANTI && !curs.0.finished {
- self.as_mut().join_less(curs).await?;
- }
- while R_ANTI && !curs.1.finished {
- self.as_mut().join_greater(curs).await?;
+ if !P.semi {
+ while P.join_side == L && !curs.0.finished {
+ self.as_mut().join_less(curs).await?;
+ }
+ while P.join_side == R && !curs.1.finished {
+ self.as_mut().join_greater(curs).await?;
+ }
}
if !self.indices.is_empty() {
self.flush(curs).await?;
@@ -220,4 +235,8 @@
fn total_send_output_time(&self) -> usize {
self.send_output_time.value()
}
+
+ fn num_output_rows(&self) -> usize {
+ self.output_rows
+ }
}
diff --git a/native-engine/datafusion-ext-plans/src/smj/stream_cursor.rs b/native-engine/datafusion-ext-plans/src/joins/stream_cursor.rs
similarity index 99%
rename from native-engine/datafusion-ext-plans/src/smj/stream_cursor.rs
rename to native-engine/datafusion-ext-plans/src/joins/stream_cursor.rs
index f8d24c6..81a1317 100644
--- a/native-engine/datafusion-ext-plans/src/smj/stream_cursor.rs
+++ b/native-engine/datafusion-ext-plans/src/joins/stream_cursor.rs
@@ -32,7 +32,7 @@
use crate::{
common::batch_selection::take_batch_opt,
- sort_merge_join_exec::{Idx, JoinParams},
+ joins::{Idx, JoinParams},
};
pub struct StreamCursor {
diff --git a/native-engine/datafusion-ext-plans/src/joins/test.rs b/native-engine/datafusion-ext-plans/src/joins/test.rs
new file mode 100644
index 0000000..c6a8874
--- /dev/null
+++ b/native-engine/datafusion-ext-plans/src/joins/test.rs
@@ -0,0 +1,945 @@
+// Copyright 2022 The Blaze Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#[cfg(test)]
+mod tests {
+ use std::sync::Arc;
+
+ use arrow::{
+ self,
+ array::*,
+ compute::SortOptions,
+ datatypes::{DataType, Field, Schema, SchemaRef},
+ record_batch::RecordBatch,
+ };
+ use datafusion::{
+ assert_batches_sorted_eq,
+ common::JoinSide,
+ error::Result,
+ physical_expr::expressions::Column,
+ physical_plan::{common, joins::utils::*, memory::MemoryExec, ExecutionPlan},
+ prelude::SessionContext,
+ };
+ use TestType::*;
+
+ use crate::{
+ broadcast_join_build_hash_map_exec::BroadcastJoinBuildHashMapExec,
+ broadcast_join_exec::BroadcastJoinExec,
+ joins::join_utils::{JoinType, JoinType::*},
+ sort_merge_join_exec::SortMergeJoinExec,
+ };
+
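+ // each test case runs three ways: as a sort-merge join, and as broadcast hash
+ // joins with the hash map built on the right side (left side probed) and on
+ // the left side (right side probed)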
+ #[derive(Clone, Copy)]
+ enum TestType {
+ SMJ,
+ BHJLeftProbed,
+ BHJRightProbed,
+ }
+
+ fn columns(schema: &Schema) -> Vec<String> {
+ schema.fields().iter().map(|f| f.name().clone()).collect()
+ }
+
+ fn build_table_i32(
+ a: (&str, &Vec<i32>),
+ b: (&str, &Vec<i32>),
+ c: (&str, &Vec<i32>),
+ ) -> RecordBatch {
+ let schema = Schema::new(vec![
+ Field::new(a.0, DataType::Int32, false),
+ Field::new(b.0, DataType::Int32, false),
+ Field::new(c.0, DataType::Int32, false),
+ ]);
+
+ RecordBatch::try_new(
+ Arc::new(schema),
+ vec![
+ Arc::new(Int32Array::from(a.1.clone())),
+ Arc::new(Int32Array::from(b.1.clone())),
+ Arc::new(Int32Array::from(c.1.clone())),
+ ],
+ )
+ .unwrap()
+ }
+
+ fn build_table(
+ a: (&str, &Vec<i32>),
+ b: (&str, &Vec<i32>),
+ c: (&str, &Vec<i32>),
+ ) -> Arc<dyn ExecutionPlan> {
+ let batch = build_table_i32(a, b, c);
+ let schema = batch.schema();
+ Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
+ }
+
+ fn build_table_from_batches(batches: Vec<RecordBatch>) -> Arc<dyn ExecutionPlan> {
+ let schema = batches.first().unwrap().schema();
+ Arc::new(MemoryExec::try_new(&[batches], schema, None).unwrap())
+ }
+
+ fn build_date_table(
+ a: (&str, &Vec<i32>),
+ b: (&str, &Vec<i32>),
+ c: (&str, &Vec<i32>),
+ ) -> Arc<dyn ExecutionPlan> {
+ let schema = Schema::new(vec![
+ Field::new(a.0, DataType::Date32, false),
+ Field::new(b.0, DataType::Date32, false),
+ Field::new(c.0, DataType::Date32, false),
+ ]);
+
+ let batch = RecordBatch::try_new(
+ Arc::new(schema),
+ vec![
+ Arc::new(Date32Array::from(a.1.clone())),
+ Arc::new(Date32Array::from(b.1.clone())),
+ Arc::new(Date32Array::from(c.1.clone())),
+ ],
+ )
+ .unwrap();
+
+ let schema = batch.schema();
+ Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
+ }
+
+ fn build_date64_table(
+ a: (&str, &Vec<i64>),
+ b: (&str, &Vec<i64>),
+ c: (&str, &Vec<i64>),
+ ) -> Arc<dyn ExecutionPlan> {
+ let schema = Schema::new(vec![
+ Field::new(a.0, DataType::Date64, false),
+ Field::new(b.0, DataType::Date64, false),
+ Field::new(c.0, DataType::Date64, false),
+ ]);
+
+ let batch = RecordBatch::try_new(
+ Arc::new(schema),
+ vec![
+ Arc::new(Date64Array::from(a.1.clone())),
+ Arc::new(Date64Array::from(b.1.clone())),
+ Arc::new(Date64Array::from(c.1.clone())),
+ ],
+ )
+ .unwrap();
+
+ let schema = batch.schema();
+ Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
+ }
+
+ /// returns a table with 3 columns of i32 in memory
+ pub fn build_table_i32_nullable(
+ a: (&str, &Vec<Option<i32>>),
+ b: (&str, &Vec<Option<i32>>),
+ c: (&str, &Vec<Option<i32>>),
+ ) -> Arc<dyn ExecutionPlan> {
+ let schema = Arc::new(Schema::new(vec![
+ Field::new(a.0, DataType::Int32, true),
+ Field::new(b.0, DataType::Int32, true),
+ Field::new(c.0, DataType::Int32, true),
+ ]));
+ let batch = RecordBatch::try_new(
+ schema.clone(),
+ vec![
+ Arc::new(Int32Array::from(a.1.clone())),
+ Arc::new(Int32Array::from(b.1.clone())),
+ Arc::new(Int32Array::from(c.1.clone())),
+ ],
+ )
+ .unwrap();
+ Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
+ }
+
+ fn build_join_schema_for_test(
+ left: &Schema,
+ right: &Schema,
+ join_type: JoinType,
+ ) -> Result<SchemaRef> {
+ if join_type == Existence {
+ let exists_field = Arc::new(Field::new("exists#0", DataType::Boolean, false));
+ return Ok(Arc::new(Schema::new(
+ [left.fields().to_vec(), vec![exists_field]].concat(),
+ )));
+ }
+ Ok(Arc::new(
+ build_join_schema(left, right, &join_type.try_into()?).0,
+ ))
+ }
+
+ async fn join_collect(
+ test_type: TestType,
+ left: Arc<dyn ExecutionPlan>,
+ right: Arc<dyn ExecutionPlan>,
+ on: JoinOn,
+ join_type: JoinType,
+ ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
+ let session_ctx = SessionContext::new();
+ let task_ctx = session_ctx.task_ctx();
+ let schema = build_join_schema_for_test(&left.schema(), &right.schema(), join_type)?;
+
+ let join: Arc<dyn ExecutionPlan> = match test_type {
+ SMJ => {
+ let sort_options = vec![SortOptions::default(); on.len()];
+ Arc::new(SortMergeJoinExec::try_new(
+ schema,
+ left,
+ right,
+ on,
+ join_type,
+ sort_options,
+ )?)
+ }
+ BHJLeftProbed => {
+ let right = Arc::new(BroadcastJoinBuildHashMapExec::new(
+ right,
+ on.iter().map(|(_, right_key)| right_key.clone()).collect(),
+ ));
+ Arc::new(BroadcastJoinExec::try_new(
+ schema,
+ left,
+ right,
+ on,
+ join_type,
+ JoinSide::Right,
+ )?)
+ }
+ BHJRightProbed => {
+ let left = Arc::new(BroadcastJoinBuildHashMapExec::new(
+ left,
+ on.iter().map(|(left_key, _)| left_key.clone()).collect(),
+ ));
+ Arc::new(BroadcastJoinExec::try_new(
+ schema,
+ left,
+ right,
+ on,
+ join_type,
+ JoinSide::Left,
+ )?)
+ }
+ };
+ let columns = columns(&join.schema());
+ let stream = join.execute(0, task_ctx)?;
+ let batches = common::collect(stream).await?;
+ Ok((columns, batches))
+ }
+
+ #[tokio::test]
+ async fn join_inner_one() -> Result<()> {
+ for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
+ let left = build_table(
+ ("a1", &vec![1, 2, 3]),
+ ("b1", &vec![4, 5, 5]), // this has a repetition
+ ("c1", &vec![7, 8, 9]),
+ );
+ let right = build_table(
+ ("a2", &vec![10, 20, 30]),
+ ("b1", &vec![4, 5, 6]),
+ ("c2", &vec![70, 80, 90]),
+ );
+
+ let on: JoinOn = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?),
+ Arc::new(Column::new_with_schema("b1", &right.schema())?),
+ )];
+
+ let (_, batches) = join_collect(test_type, left, right, on, Inner).await?;
+ let expected = vec![
+ "+----+----+----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | b1 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| 1 | 4 | 7 | 10 | 4 | 70 |",
+ "| 2 | 5 | 8 | 20 | 5 | 80 |",
+ "| 3 | 5 | 9 | 20 | 5 | 80 |",
+ "+----+----+----+----+----+----+",
+ ];
+ // The output order is important as SMJ preserves sortedness
+ assert_batches_sorted_eq!(expected, &batches);
+ }
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_inner_two() -> Result<()> {
+ for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
+ let left = build_table(
+ ("a1", &vec![1, 2, 2]),
+ ("b2", &vec![1, 2, 2]),
+ ("c1", &vec![7, 8, 9]),
+ );
+ let right = build_table(
+ ("a1", &vec![1, 2, 3]),
+ ("b2", &vec![1, 2, 2]),
+ ("c2", &vec![70, 80, 90]),
+ );
+ let on: JoinOn = vec![
+ (
+ Arc::new(Column::new_with_schema("a1", &left.schema())?),
+ Arc::new(Column::new_with_schema("a1", &right.schema())?),
+ ),
+ (
+ Arc::new(Column::new_with_schema("b2", &left.schema())?),
+ Arc::new(Column::new_with_schema("b2", &right.schema())?),
+ ),
+ ];
+
+ let (_columns, batches) = join_collect(test_type, left, right, on, Inner).await?;
+ let expected = vec![
+ "+----+----+----+----+----+----+",
+ "| a1 | b2 | c1 | a1 | b2 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| 1 | 1 | 7 | 1 | 1 | 70 |",
+ "| 2 | 2 | 8 | 2 | 2 | 80 |",
+ "| 2 | 2 | 9 | 2 | 2 | 80 |",
+ "+----+----+----+----+----+----+",
+ ];
+ // The output order is important as SMJ preserves sortedness
+ assert_batches_sorted_eq!(expected, &batches);
+ }
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_inner_two_two() -> Result<()> {
+ for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
+ let left = build_table(
+ ("a1", &vec![1, 1, 2]),
+ ("b2", &vec![1, 1, 2]),
+ ("c1", &vec![7, 8, 9]),
+ );
+ let right = build_table(
+ ("a1", &vec![1, 1, 3]),
+ ("b2", &vec![1, 1, 2]),
+ ("c2", &vec![70, 80, 90]),
+ );
+ let on: JoinOn = vec![
+ (
+ Arc::new(Column::new_with_schema("a1", &left.schema())?),
+ Arc::new(Column::new_with_schema("a1", &right.schema())?),
+ ),
+ (
+ Arc::new(Column::new_with_schema("b2", &left.schema())?),
+ Arc::new(Column::new_with_schema("b2", &right.schema())?),
+ ),
+ ];
+
+ let (_columns, batches) = join_collect(test_type, left, right, on, Inner).await?;
+ let expected = vec![
+ "+----+----+----+----+----+----+",
+ "| a1 | b2 | c1 | a1 | b2 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| 1 | 1 | 7 | 1 | 1 | 70 |",
+ "| 1 | 1 | 7 | 1 | 1 | 80 |",
+ "| 1 | 1 | 8 | 1 | 1 | 70 |",
+ "| 1 | 1 | 8 | 1 | 1 | 80 |",
+ "+----+----+----+----+----+----+",
+ ];
+ // The output order is important as SMJ preserves sortedness
+ assert_batches_sorted_eq!(expected, &batches);
+ }
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_inner_with_nulls() -> Result<()> {
+ for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
+ let left = build_table_i32_nullable(
+ ("a1", &vec![Some(1), Some(1), Some(2), Some(2)]),
+ ("b2", &vec![None, Some(1), Some(2), Some(2)]), // null in key field
+ ("c1", &vec![Some(1), None, Some(8), Some(9)]), // null in non-key field
+ );
+ let right = build_table_i32_nullable(
+ ("a1", &vec![Some(1), Some(1), Some(2), Some(3)]),
+ ("b2", &vec![None, Some(1), Some(2), Some(2)]),
+ ("c2", &vec![Some(10), Some(70), Some(80), Some(90)]),
+ );
+ let on: JoinOn = vec![
+ (
+ Arc::new(Column::new_with_schema("a1", &left.schema())?),
+ Arc::new(Column::new_with_schema("a1", &right.schema())?),
+ ),
+ (
+ Arc::new(Column::new_with_schema("b2", &left.schema())?),
+ Arc::new(Column::new_with_schema("b2", &right.schema())?),
+ ),
+ ];
+
+ let (_, batches) = join_collect(test_type, left, right, on, Inner).await?;
+ let expected = vec![
+ "+----+----+----+----+----+----+",
+ "| a1 | b2 | c1 | a1 | b2 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| 1 | 1 | | 1 | 1 | 70 |",
+ "| 2 | 2 | 8 | 2 | 2 | 80 |",
+ "| 2 | 2 | 9 | 2 | 2 | 80 |",
+ "+----+----+----+----+----+----+",
+ ];
+ // The output order is important as SMJ preserves sortedness
+ assert_batches_sorted_eq!(expected, &batches);
+ }
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_left_one() -> Result<()> {
+ for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
+ let left = build_table(
+ ("a1", &vec![1, 2, 3]),
+ ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
+ ("c1", &vec![7, 8, 9]),
+ );
+ let right = build_table(
+ ("a2", &vec![10, 20, 30]),
+ ("b1", &vec![4, 5, 6]),
+ ("c2", &vec![70, 80, 90]),
+ );
+ let on: JoinOn = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?),
+ Arc::new(Column::new_with_schema("b1", &right.schema())?),
+ )];
+
+ let (_, batches) = join_collect(test_type, left, right, on, Left).await?;
+ let expected = vec![
+ "+----+----+----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | b1 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| 1 | 4 | 7 | 10 | 4 | 70 |",
+ "| 2 | 5 | 8 | 20 | 5 | 80 |",
+ "| 3 | 7 | 9 | | | |",
+ "+----+----+----+----+----+----+",
+ ];
+ // The output order is important as SMJ preserves sortedness
+ assert_batches_sorted_eq!(expected, &batches);
+ }
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_right_one() -> Result<()> {
+ for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
+ let left = build_table(
+ ("a1", &vec![1, 2, 3]),
+ ("b1", &vec![4, 5, 7]),
+ ("c1", &vec![7, 8, 9]),
+ );
+ let right = build_table(
+ ("a2", &vec![10, 20, 30]),
+ ("b1", &vec![4, 5, 6]), // 6 does not exist on the left
+ ("c2", &vec![70, 80, 90]),
+ );
+ let on: JoinOn = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?),
+ Arc::new(Column::new_with_schema("b1", &right.schema())?),
+ )];
+
+ let (_, batches) = join_collect(test_type, left, right, on, Right).await?;
+ let expected = vec![
+ "+----+----+----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | b1 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| 1 | 4 | 7 | 10 | 4 | 70 |",
+ "| 2 | 5 | 8 | 20 | 5 | 80 |",
+ "| | | | 30 | 6 | 90 |",
+ "+----+----+----+----+----+----+",
+ ];
+ // The output order is important as SMJ preserves sortedness
+ assert_batches_sorted_eq!(expected, &batches);
+ }
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_full_one() -> Result<()> {
+ for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
+ let left = build_table(
+ ("a1", &vec![1, 2, 3]),
+ ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
+ ("c1", &vec![7, 8, 9]),
+ );
+ let right = build_table(
+ ("a2", &vec![10, 20, 30]),
+ ("b2", &vec![4, 5, 6]),
+ ("c2", &vec![70, 80, 90]),
+ );
+ let on: JoinOn = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?),
+ Arc::new(Column::new_with_schema("b2", &right.schema())?),
+ )];
+
+ let (_, batches) = join_collect(test_type, left, right, on, Full).await?;
+ let expected = vec![
+ "+----+----+----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | b2 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| | | | 30 | 6 | 90 |",
+ "| 1 | 4 | 7 | 10 | 4 | 70 |",
+ "| 2 | 5 | 8 | 20 | 5 | 80 |",
+ "| 3 | 7 | 9 | | | |",
+ "+----+----+----+----+----+----+",
+ ];
+ assert_batches_sorted_eq!(expected, &batches);
+ }
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_anti() -> Result<()> {
+ for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
+ let left = build_table(
+ ("a1", &vec![1, 2, 2, 3, 5]),
+ ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right
+ ("c1", &vec![7, 8, 8, 9, 11]),
+ );
+ let right = build_table(
+ ("a2", &vec![10, 20, 30]),
+ ("b1", &vec![4, 5, 6]),
+ ("c2", &vec![70, 80, 90]),
+ );
+ let on: JoinOn = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?),
+ Arc::new(Column::new_with_schema("b1", &right.schema())?),
+ )];
+
+ let (_, batches) = join_collect(test_type, left, right, on, LeftAnti).await?;
+ let expected = vec![
+ "+----+----+----+",
+ "| a1 | b1 | c1 |",
+ "+----+----+----+",
+ "| 3 | 7 | 9 |",
+ "| 5 | 7 | 11 |",
+ "+----+----+----+",
+ ];
+ // The output order is important as SMJ preserves sortedness
+ assert_batches_sorted_eq!(expected, &batches);
+ }
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_semi() -> Result<()> {
+ for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
+ let left = build_table(
+ ("a1", &vec![1, 2, 2, 3]),
+ ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right
+ ("c1", &vec![7, 8, 8, 9]),
+ );
+ let right = build_table(
+ ("a2", &vec![10, 20, 30]),
+ ("b1", &vec![4, 5, 6]), // 5 is double on the right
+ ("c2", &vec![70, 80, 90]),
+ );
+ let on: JoinOn = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?),
+ Arc::new(Column::new_with_schema("b1", &right.schema())?),
+ )];
+
+ let (_, batches) = join_collect(test_type, left, right, on, LeftSemi).await?;
+ let expected = vec![
+ "+----+----+----+",
+ "| a1 | b1 | c1 |",
+ "+----+----+----+",
+ "| 1 | 4 | 7 |",
+ "| 2 | 5 | 8 |",
+ "| 2 | 5 | 8 |",
+ "+----+----+----+",
+ ];
+ // The output order is important as SMJ preserves sortedness
+ assert_batches_sorted_eq!(expected, &batches);
+ }
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_with_duplicated_column_names() -> Result<()> {
+ for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
+ let left = build_table(
+ ("a", &vec![1, 2, 3]),
+ ("b", &vec![4, 5, 7]),
+ ("c", &vec![7, 8, 9]),
+ );
+ let right = build_table(
+ ("a", &vec![10, 20, 30]),
+ ("b", &vec![1, 2, 7]),
+ ("c", &vec![70, 80, 90]),
+ );
+ let on: JoinOn = vec![(
+ // join on a=b so there are duplicate column names on unjoined columns
+ Arc::new(Column::new_with_schema("a", &left.schema())?),
+ Arc::new(Column::new_with_schema("b", &right.schema())?),
+ )];
+
+ let (_, batches) = join_collect(test_type, left, right, on, Inner).await?;
+ let expected = vec![
+ "+---+---+---+----+---+----+",
+ "| a | b | c | a | b | c |",
+ "+---+---+---+----+---+----+",
+ "| 1 | 4 | 7 | 10 | 1 | 70 |",
+ "| 2 | 5 | 8 | 20 | 2 | 80 |",
+ "+---+---+---+----+---+----+",
+ ];
+ // The output order is important as SMJ preserves sortedness
+ assert_batches_sorted_eq!(expected, &batches);
+ }
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_date32() -> Result<()> {
+ for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
+ let left = build_date_table(
+ ("a1", &vec![1, 2, 3]),
+ ("b1", &vec![19107, 19108, 19108]), // this has a repetition
+ ("c1", &vec![7, 8, 9]),
+ );
+ let right = build_date_table(
+ ("a2", &vec![10, 20, 30]),
+ ("b1", &vec![19107, 19108, 19109]),
+ ("c2", &vec![70, 80, 90]),
+ );
+
+ let on: JoinOn = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?),
+ Arc::new(Column::new_with_schema("b1", &right.schema())?),
+ )];
+
+ let (_, batches) = join_collect(test_type, left, right, on, Inner).await?;
+
+ let expected = vec![
+ "+------------+------------+------------+------------+------------+------------+",
+ "| a1 | b1 | c1 | a2 | b1 | c2 |",
+ "+------------+------------+------------+------------+------------+------------+",
+ "| 1970-01-02 | 2022-04-25 | 1970-01-08 | 1970-01-11 | 2022-04-25 | 1970-03-12 |",
+ "| 1970-01-03 | 2022-04-26 | 1970-01-09 | 1970-01-21 | 2022-04-26 | 1970-03-22 |",
+ "| 1970-01-04 | 2022-04-26 | 1970-01-10 | 1970-01-21 | 2022-04-26 | 1970-03-22 |",
+ "+------------+------------+------------+------------+------------+------------+",
+ ];
+ // The output order is important as SMJ preserves sortedness
+ assert_batches_sorted_eq!(expected, &batches);
+ }
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_date64() -> Result<()> {
+ for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
+ let left = build_date64_table(
+ ("a1", &vec![1, 2, 3]),
+ ("b1", &vec![1650703441000, 1650903441000, 1650903441000]), /* this has a
+ * repetition */
+ ("c1", &vec![7, 8, 9]),
+ );
+ let right = build_date64_table(
+ ("a2", &vec![10, 20, 30]),
+ ("b1", &vec![1650703441000, 1650503441000, 1650903441000]),
+ ("c2", &vec![70, 80, 90]),
+ );
+
+ let on: JoinOn = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?),
+ Arc::new(Column::new_with_schema("b1", &right.schema())?),
+ )];
+
+ let (_, batches) = join_collect(test_type, left, right, on, Inner).await?;
+ let expected = vec![
+ "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+",
+ "| a1 | b1 | c1 | a2 | b1 | c2 |",
+ "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+",
+ "| 1970-01-01T00:00:00.001 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.007 | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.070 |",
+ "| 1970-01-01T00:00:00.002 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |",
+ "| 1970-01-01T00:00:00.003 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |",
+ "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+",
+ ];
+
+ // The output order is important as SMJ preserves sortedness
+ assert_batches_sorted_eq!(expected, &batches);
+ }
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_left_sort_order() -> Result<()> {
+ for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
+ let left = build_table(
+ ("a1", &vec![0, 1, 2, 3, 4, 5]),
+ ("b1", &vec![3, 4, 5, 6, 6, 7]),
+ ("c1", &vec![4, 5, 6, 7, 8, 9]),
+ );
+ let right = build_table(
+ ("a2", &vec![0, 10, 20, 30, 40]),
+ ("b2", &vec![2, 4, 6, 6, 8]),
+ ("c2", &vec![50, 60, 70, 80, 90]),
+ );
+ let on: JoinOn = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?),
+ Arc::new(Column::new_with_schema("b2", &right.schema())?),
+ )];
+
+ let (_, batches) = join_collect(test_type, left, right, on, Left).await?;
+ let expected = vec![
+ "+----+----+----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | b2 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| 0 | 3 | 4 | | | |",
+ "| 1 | 4 | 5 | 10 | 4 | 60 |",
+ "| 2 | 5 | 6 | | | |",
+ "| 3 | 6 | 7 | 20 | 6 | 70 |",
+ "| 3 | 6 | 7 | 30 | 6 | 80 |",
+ "| 4 | 6 | 8 | 20 | 6 | 70 |",
+ "| 4 | 6 | 8 | 30 | 6 | 80 |",
+ "| 5 | 7 | 9 | | | |",
+ "+----+----+----+----+----+----+",
+ ];
+ assert_batches_sorted_eq!(expected, &batches);
+ }
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_right_sort_order() -> Result<()> {
+ for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
+ let left = build_table(
+ ("a1", &vec![0, 1, 2, 3]),
+ ("b1", &vec![3, 4, 5, 7]),
+ ("c1", &vec![6, 7, 8, 9]),
+ );
+ let right = build_table(
+ ("a2", &vec![0, 10, 20, 30]),
+ ("b2", &vec![2, 4, 5, 6]),
+ ("c2", &vec![60, 70, 80, 90]),
+ );
+ let on: JoinOn = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?),
+ Arc::new(Column::new_with_schema("b2", &right.schema())?),
+ )];
+
+ let (_, batches) = join_collect(test_type, left, right, on, Right).await?;
+ let expected = vec![
+ "+----+----+----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | b2 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| | | | 0 | 2 | 60 |",
+ "| 1 | 4 | 7 | 10 | 4 | 70 |",
+ "| 2 | 5 | 8 | 20 | 5 | 80 |",
+ "| | | | 30 | 6 | 90 |",
+ "+----+----+----+----+----+----+",
+ ];
+ assert_batches_sorted_eq!(expected, &batches);
+ }
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_left_multiple_batches() -> Result<()> {
+ for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
+ let left_batch_1 = build_table_i32(
+ ("a1", &vec![0, 1, 2]),
+ ("b1", &vec![3, 4, 5]),
+ ("c1", &vec![4, 5, 6]),
+ );
+ let left_batch_2 = build_table_i32(
+ ("a1", &vec![3, 4, 5, 6]),
+ ("b1", &vec![6, 6, 7, 9]),
+ ("c1", &vec![7, 8, 9, 9]),
+ );
+ let right_batch_1 = build_table_i32(
+ ("a2", &vec![0, 10, 20]),
+ ("b2", &vec![2, 4, 6]),
+ ("c2", &vec![50, 60, 70]),
+ );
+ let right_batch_2 = build_table_i32(
+ ("a2", &vec![30, 40]),
+ ("b2", &vec![6, 8]),
+ ("c2", &vec![80, 90]),
+ );
+ let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
+ let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
+ let on: JoinOn = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?),
+ Arc::new(Column::new_with_schema("b2", &right.schema())?),
+ )];
+
+ let (_, batches) = join_collect(test_type, left, right, on, Left).await?;
+ let expected = vec![
+ "+----+----+----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | b2 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| 0 | 3 | 4 | | | |",
+ "| 1 | 4 | 5 | 10 | 4 | 60 |",
+ "| 2 | 5 | 6 | | | |",
+ "| 3 | 6 | 7 | 20 | 6 | 70 |",
+ "| 3 | 6 | 7 | 30 | 6 | 80 |",
+ "| 4 | 6 | 8 | 20 | 6 | 70 |",
+ "| 4 | 6 | 8 | 30 | 6 | 80 |",
+ "| 5 | 7 | 9 | | | |",
+ "| 6 | 9 | 9 | | | |",
+ "+----+----+----+----+----+----+",
+ ];
+ assert_batches_sorted_eq!(expected, &batches);
+ }
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_right_multiple_batches() -> Result<()> {
+ for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
+ let right_batch_1 = build_table_i32(
+ ("a2", &vec![0, 1, 2]),
+ ("b2", &vec![3, 4, 5]),
+ ("c2", &vec![4, 5, 6]),
+ );
+ let right_batch_2 = build_table_i32(
+ ("a2", &vec![3, 4, 5, 6]),
+ ("b2", &vec![6, 6, 7, 9]),
+ ("c2", &vec![7, 8, 9, 9]),
+ );
+ let left_batch_1 = build_table_i32(
+ ("a1", &vec![0, 10, 20]),
+ ("b1", &vec![2, 4, 6]),
+ ("c1", &vec![50, 60, 70]),
+ );
+ let left_batch_2 = build_table_i32(
+ ("a1", &vec![30, 40]),
+ ("b1", &vec![6, 8]),
+ ("c1", &vec![80, 90]),
+ );
+ let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
+ let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
+ let on: JoinOn = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?),
+ Arc::new(Column::new_with_schema("b2", &right.schema())?),
+ )];
+
+ let (_, batches) = join_collect(test_type, left, right, on, Right).await?;
+ let expected = vec![
+ "+----+----+----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | b2 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| | | | 0 | 3 | 4 |",
+ "| 10 | 4 | 60 | 1 | 4 | 5 |",
+ "| | | | 2 | 5 | 6 |",
+ "| 20 | 6 | 70 | 3 | 6 | 7 |",
+ "| 30 | 6 | 80 | 3 | 6 | 7 |",
+ "| 20 | 6 | 70 | 4 | 6 | 8 |",
+ "| 30 | 6 | 80 | 4 | 6 | 8 |",
+ "| | | | 5 | 7 | 9 |",
+ "| | | | 6 | 9 | 9 |",
+ "+----+----+----+----+----+----+",
+ ];
+ assert_batches_sorted_eq!(expected, &batches);
+ }
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_full_multiple_batches() -> Result<()> {
+ for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
+ let left_batch_1 = build_table_i32(
+ ("a1", &vec![0, 1, 2]),
+ ("b1", &vec![3, 4, 5]),
+ ("c1", &vec![4, 5, 6]),
+ );
+ let left_batch_2 = build_table_i32(
+ ("a1", &vec![3, 4, 5, 6]),
+ ("b1", &vec![6, 6, 7, 9]),
+ ("c1", &vec![7, 8, 9, 9]),
+ );
+ let right_batch_1 = build_table_i32(
+ ("a2", &vec![0, 10, 20]),
+ ("b2", &vec![2, 4, 6]),
+ ("c2", &vec![50, 60, 70]),
+ );
+ let right_batch_2 = build_table_i32(
+ ("a2", &vec![30, 40]),
+ ("b2", &vec![6, 8]),
+ ("c2", &vec![80, 90]),
+ );
+ let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
+ let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
+ let on: JoinOn = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?),
+ Arc::new(Column::new_with_schema("b2", &right.schema())?),
+ )];
+
+ let (_, batches) = join_collect(test_type, left, right, on, Full).await?;
+ let expected = vec![
+ "+----+----+----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | b2 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| | | | 0 | 2 | 50 |",
+ "| | | | 40 | 8 | 90 |",
+ "| 0 | 3 | 4 | | | |",
+ "| 1 | 4 | 5 | 10 | 4 | 60 |",
+ "| 2 | 5 | 6 | | | |",
+ "| 3 | 6 | 7 | 20 | 6 | 70 |",
+ "| 3 | 6 | 7 | 30 | 6 | 80 |",
+ "| 4 | 6 | 8 | 20 | 6 | 70 |",
+ "| 4 | 6 | 8 | 30 | 6 | 80 |",
+ "| 5 | 7 | 9 | | | |",
+ "| 6 | 9 | 9 | | | |",
+ "+----+----+----+----+----+----+",
+ ];
+ assert_batches_sorted_eq!(expected, &batches);
+ }
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_existence_multiple_batches() -> Result<()> {
+ for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
+ let left_batch_1 = build_table_i32(
+ ("a1", &vec![0, 1, 2]),
+ ("b1", &vec![3, 4, 5]),
+ ("c1", &vec![4, 5, 6]),
+ );
+ let left_batch_2 = build_table_i32(
+ ("a1", &vec![3, 4, 5, 6]),
+ ("b1", &vec![6, 6, 7, 9]),
+ ("c1", &vec![7, 8, 9, 9]),
+ );
+ let right_batch_1 = build_table_i32(
+ ("a2", &vec![0, 10, 20]),
+ ("b2", &vec![2, 4, 6]),
+ ("c2", &vec![50, 60, 70]),
+ );
+ let right_batch_2 = build_table_i32(
+ ("a2", &vec![30, 40]),
+ ("b2", &vec![6, 8]),
+ ("c2", &vec![80, 90]),
+ );
+ let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
+ let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
+ let on: JoinOn = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?),
+ Arc::new(Column::new_with_schema("b2", &right.schema())?),
+ )];
+
+ let (_, batches) = join_collect(test_type, left, right, on, Existence).await?;
+ let expected = vec![
+ "+----+----+----+----------+",
+ "| a1 | b1 | c1 | exists#0 |",
+ "+----+----+----+----------+",
+ "| 0 | 3 | 4 | false |",
+ "| 1 | 4 | 5 | true |",
+ "| 2 | 5 | 6 | false |",
+ "| 3 | 6 | 7 | true |",
+ "| 4 | 6 | 8 | true |",
+ "| 5 | 7 | 9 | false |",
+ "| 6 | 9 | 9 | false |",
+ "+----+----+----+----------+",
+ ];
+ assert_batches_sorted_eq!(expected, &batches);
+ }
+ Ok(())
+ }
+}
diff --git a/native-engine/datafusion-ext-plans/src/lib.rs b/native-engine/datafusion-ext-plans/src/lib.rs
index 253a9ea..744cbb4 100644
--- a/native-engine/datafusion-ext-plans/src/lib.rs
+++ b/native-engine/datafusion-ext-plans/src/lib.rs
@@ -14,23 +14,22 @@
#![feature(get_mut_unchecked)]
#![feature(io_error_other)]
+#![feature(adt_const_params)]
-pub mod agg;
+// execution plan implementations
pub mod agg_exec;
+pub mod broadcast_join_build_hash_map_exec;
pub mod broadcast_join_exec;
pub mod broadcast_nested_loop_join_exec;
-pub mod common;
pub mod debug_exec;
pub mod empty_partitions_exec;
pub mod expand_exec;
pub mod ffi_reader_exec;
pub mod filter_exec;
-pub mod generate;
pub mod generate_exec;
pub mod ipc_reader_exec;
pub mod ipc_writer_exec;
pub mod limit_exec;
-pub mod memmgr;
pub mod parquet_exec;
pub mod parquet_sink_exec;
pub mod project_exec;
@@ -39,8 +38,15 @@
pub mod shuffle_writer_exec;
pub mod sort_exec;
pub mod sort_merge_join_exec;
-pub mod window;
pub mod window_exec;
+// memory management
+pub mod memmgr;
+
+// helper modules
+pub mod agg;
+pub mod common;
+pub mod generate;
+pub mod joins;
mod shuffle;
-mod smj;
+pub mod window;
diff --git a/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs b/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs
index f2dff1d..69b46cf 100644
--- a/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs
+++ b/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs
@@ -35,7 +35,6 @@
SendableRecordBatchStream, Statistics,
},
};
-use datafusion_ext_commons::df_execution_err;
use futures::{Stream, StreamExt};
use crate::agg::AGG_BUF_COLUMN_NAME;
@@ -56,7 +55,12 @@
let input_schema = input.schema();
let mut new_names = vec![];
- for (i, field) in input_schema.fields().iter().enumerate() {
+ for (i, field) in input_schema
+ .fields()
+ .iter()
+ .take(renamed_column_names.len())
+ .enumerate()
+ {
if field.name() != AGG_BUF_COLUMN_NAME {
new_names.push(renamed_column_names[i].clone());
} else {
@@ -64,11 +68,9 @@
break;
}
}
- if new_names.len() != input_schema.fields().len() {
- df_execution_err!(
- "renamed_column_names length not matched with input schema, \
- renames: {renamed_column_names:?}, input schema: {input_schema}",
- )?;
+
+ while new_names.len() < input_schema.fields().len() {
+ new_names.push(input_schema.field(new_names.len()).name().clone());
}
let renamed_column_names = new_names;
let renamed_schema = Arc::new(Schema::new(
diff --git a/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs b/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs
index 6e3e3b2..508731d 100644
--- a/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs
+++ b/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs
@@ -20,10 +20,7 @@
time::{Duration, Instant},
};
-use arrow::{
- compute::SortOptions,
- datatypes::{DataType, SchemaRef},
-};
+use arrow::{compute::SortOptions, datatypes::SchemaRef};
use async_trait::async_trait;
use datafusion::{
common::JoinSide,
@@ -44,16 +41,17 @@
use futures::TryStreamExt;
use crate::{
- common::{
- join_utils::{JoinType, JoinType::*},
- output::{TaskOutputter, WrappedRecordBatchSender},
- },
+ common::output::{TaskOutputter, WrappedRecordBatchSender},
cur_forward,
- smj::{
- existence_join::ExistenceJoiner,
- full_join::{FullOuterJoiner, InnerJoiner, LeftOuterJoiner, RightOuterJoiner},
- semi_join::{LeftAntiJoiner, LeftSemiJoiner, RightAntiJoiner, RightSemiJoiner},
+ joins::{
+ join_utils::{JoinType, JoinType::*},
+ smj::{
+ existence_join::ExistenceJoiner,
+ full_join::{FullOuterJoiner, InnerJoiner, LeftOuterJoiner, RightOuterJoiner},
+ semi_join::{LeftAntiJoiner, LeftSemiJoiner, RightAntiJoiner, RightSemiJoiner},
+ },
stream_cursor::StreamCursor,
+ JoinParams, StreamCursors,
},
};
@@ -117,6 +115,8 @@
// overflowing
Ok(JoinParams {
join_type: self.join_type,
+ left_schema,
+ right_schema,
output_schema: self.schema(),
left_keys,
right_keys,
@@ -211,17 +211,6 @@
}
}
-#[derive(Clone)]
-pub struct JoinParams {
- pub join_type: JoinType,
- pub output_schema: SchemaRef,
- pub left_keys: Vec<PhysicalExprRef>,
- pub right_keys: Vec<PhysicalExprRef>,
- pub key_data_types: Vec<DataType>,
- pub sort_options: Vec<SortOptions>,
- pub batch_size: usize,
-}
-
pub async fn execute_join(
lstream: SendableRecordBatchStream,
rstream: SendableRecordBatchStream,
@@ -251,6 +240,7 @@
Existence => Box::pin(ExistenceJoiner::new(join_params, sender)),
};
joiner.as_mut().join(&mut curs).await?;
+ metrics.record_output(joiner.num_output_rows());
// discount poll input and send output batch time
let mut join_time_ns = (Instant::now() - start_time).as_nanos() as u64;
@@ -263,9 +253,6 @@
Ok(())
}
-pub type Idx = (usize, usize);
-pub type StreamCursors = (StreamCursor, StreamCursor);
-
#[macro_export]
macro_rules! compare_cursor {
($curs:expr) => {{
@@ -281,848 +268,5 @@
pub trait Joiner {
async fn join(self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()>;
fn total_send_output_time(&self) -> usize;
-}
-
-#[cfg(test)]
-mod tests {
- use std::sync::Arc;
-
- use arrow::{
- self,
- array::*,
- compute::SortOptions,
- datatypes::{DataType, Field, Schema, SchemaRef},
- record_batch::RecordBatch,
- };
- use datafusion::{
- assert_batches_sorted_eq,
- error::Result,
- physical_expr::expressions::Column,
- physical_plan::{common, joins::utils::*, memory::MemoryExec, ExecutionPlan},
- prelude::SessionContext,
- };
-
- use crate::{
- common::join_utils::{JoinType, JoinType::*},
- sort_merge_join_exec::SortMergeJoinExec,
- };
-
- fn columns(schema: &Schema) -> Vec<String> {
- schema.fields().iter().map(|f| f.name().clone()).collect()
- }
-
- fn build_table_i32(
- a: (&str, &Vec<i32>),
- b: (&str, &Vec<i32>),
- c: (&str, &Vec<i32>),
- ) -> RecordBatch {
- let schema = Schema::new(vec![
- Field::new(a.0, DataType::Int32, false),
- Field::new(b.0, DataType::Int32, false),
- Field::new(c.0, DataType::Int32, false),
- ]);
-
- RecordBatch::try_new(
- Arc::new(schema),
- vec![
- Arc::new(Int32Array::from(a.1.clone())),
- Arc::new(Int32Array::from(b.1.clone())),
- Arc::new(Int32Array::from(c.1.clone())),
- ],
- )
- .unwrap()
- }
-
- fn build_table(
- a: (&str, &Vec<i32>),
- b: (&str, &Vec<i32>),
- c: (&str, &Vec<i32>),
- ) -> Arc<dyn ExecutionPlan> {
- let batch = build_table_i32(a, b, c);
- let schema = batch.schema();
- Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
- }
-
- fn build_table_from_batches(batches: Vec<RecordBatch>) -> Arc<dyn ExecutionPlan> {
- let schema = batches.first().unwrap().schema();
- Arc::new(MemoryExec::try_new(&[batches], schema, None).unwrap())
- }
-
- fn build_date_table(
- a: (&str, &Vec<i32>),
- b: (&str, &Vec<i32>),
- c: (&str, &Vec<i32>),
- ) -> Arc<dyn ExecutionPlan> {
- let schema = Schema::new(vec![
- Field::new(a.0, DataType::Date32, false),
- Field::new(b.0, DataType::Date32, false),
- Field::new(c.0, DataType::Date32, false),
- ]);
-
- let batch = RecordBatch::try_new(
- Arc::new(schema),
- vec![
- Arc::new(Date32Array::from(a.1.clone())),
- Arc::new(Date32Array::from(b.1.clone())),
- Arc::new(Date32Array::from(c.1.clone())),
- ],
- )
- .unwrap();
-
- let schema = batch.schema();
- Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
- }
-
- fn build_date64_table(
- a: (&str, &Vec<i64>),
- b: (&str, &Vec<i64>),
- c: (&str, &Vec<i64>),
- ) -> Arc<dyn ExecutionPlan> {
- let schema = Schema::new(vec![
- Field::new(a.0, DataType::Date64, false),
- Field::new(b.0, DataType::Date64, false),
- Field::new(c.0, DataType::Date64, false),
- ]);
-
- let batch = RecordBatch::try_new(
- Arc::new(schema),
- vec![
- Arc::new(Date64Array::from(a.1.clone())),
- Arc::new(Date64Array::from(b.1.clone())),
- Arc::new(Date64Array::from(c.1.clone())),
- ],
- )
- .unwrap();
-
- let schema = batch.schema();
- Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
- }
-
- /// returns a table with 3 columns of i32 in memory
- pub fn build_table_i32_nullable(
- a: (&str, &Vec<Option<i32>>),
- b: (&str, &Vec<Option<i32>>),
- c: (&str, &Vec<Option<i32>>),
- ) -> Arc<dyn ExecutionPlan> {
- let schema = Arc::new(Schema::new(vec![
- Field::new(a.0, DataType::Int32, true),
- Field::new(b.0, DataType::Int32, true),
- Field::new(c.0, DataType::Int32, true),
- ]));
- let batch = RecordBatch::try_new(
- schema.clone(),
- vec![
- Arc::new(Int32Array::from(a.1.clone())),
- Arc::new(Int32Array::from(b.1.clone())),
- Arc::new(Int32Array::from(c.1.clone())),
- ],
- )
- .unwrap();
- Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
- }
-
- fn build_join_schema_for_test(
- left: &Schema,
- right: &Schema,
- join_type: JoinType,
- ) -> Result<SchemaRef> {
- if join_type == Existence {
- let exists_field = Arc::new(Field::new("exists#0", DataType::Boolean, false));
- return Ok(Arc::new(Schema::new(
- [left.fields().to_vec(), vec![exists_field]].concat(),
- )));
- }
- Ok(Arc::new(
- build_join_schema(left, right, &join_type.try_into()?).0,
- ))
- }
-
- async fn join_collect(
- left: Arc<dyn ExecutionPlan>,
- right: Arc<dyn ExecutionPlan>,
- on: JoinOn,
- join_type: JoinType,
- ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
- let session_ctx = SessionContext::new();
- let task_ctx = session_ctx.task_ctx();
- let sort_options = vec![SortOptions::default(); on.len()];
- let schema = build_join_schema_for_test(&left.schema(), &right.schema(), join_type)?;
- let join = SortMergeJoinExec::try_new(schema, left, right, on, join_type, sort_options)?;
- let columns = columns(&join.schema());
- let stream = join.execute(0, task_ctx)?;
- let batches = common::collect(stream).await?;
- Ok((columns, batches))
- }
-
- #[tokio::test]
- async fn join_inner_one() -> Result<()> {
- let left = build_table(
- ("a1", &vec![1, 2, 3]),
- ("b1", &vec![4, 5, 5]), // this has a repetition
- ("c1", &vec![7, 8, 9]),
- );
- let right = build_table(
- ("a2", &vec![10, 20, 30]),
- ("b1", &vec![4, 5, 6]),
- ("c2", &vec![70, 80, 90]),
- );
-
- let on: JoinOn = vec![(
- Arc::new(Column::new_with_schema("b1", &left.schema())?),
- Arc::new(Column::new_with_schema("b1", &right.schema())?),
- )];
-
- let (_, batches) = join_collect(left, right, on, Inner).await?;
-
- let expected = vec![
- "+----+----+----+----+----+----+",
- "| a1 | b1 | c1 | a2 | b1 | c2 |",
- "+----+----+----+----+----+----+",
- "| 1 | 4 | 7 | 10 | 4 | 70 |",
- "| 2 | 5 | 8 | 20 | 5 | 80 |",
- "| 3 | 5 | 9 | 20 | 5 | 80 |",
- "+----+----+----+----+----+----+",
- ];
- // The output order is important as SMJ preserves sortedness
- assert_batches_sorted_eq!(expected, &batches);
- Ok(())
- }
-
- #[tokio::test]
- async fn join_inner_two() -> Result<()> {
- let left = build_table(
- ("a1", &vec![1, 2, 2]),
- ("b2", &vec![1, 2, 2]),
- ("c1", &vec![7, 8, 9]),
- );
- let right = build_table(
- ("a1", &vec![1, 2, 3]),
- ("b2", &vec![1, 2, 2]),
- ("c2", &vec![70, 80, 90]),
- );
- let on: JoinOn = vec![
- (
- Arc::new(Column::new_with_schema("a1", &left.schema())?),
- Arc::new(Column::new_with_schema("a1", &right.schema())?),
- ),
- (
- Arc::new(Column::new_with_schema("b2", &left.schema())?),
- Arc::new(Column::new_with_schema("b2", &right.schema())?),
- ),
- ];
-
- let (_columns, batches) = join_collect(left, right, on, Inner).await?;
- let expected = vec![
- "+----+----+----+----+----+----+",
- "| a1 | b2 | c1 | a1 | b2 | c2 |",
- "+----+----+----+----+----+----+",
- "| 1 | 1 | 7 | 1 | 1 | 70 |",
- "| 2 | 2 | 8 | 2 | 2 | 80 |",
- "| 2 | 2 | 9 | 2 | 2 | 80 |",
- "+----+----+----+----+----+----+",
- ];
- // The output order is important as SMJ preserves sortedness
- assert_batches_sorted_eq!(expected, &batches);
- Ok(())
- }
-
- #[tokio::test]
- async fn join_inner_two_two() -> Result<()> {
- let left = build_table(
- ("a1", &vec![1, 1, 2]),
- ("b2", &vec![1, 1, 2]),
- ("c1", &vec![7, 8, 9]),
- );
- let right = build_table(
- ("a1", &vec![1, 1, 3]),
- ("b2", &vec![1, 1, 2]),
- ("c2", &vec![70, 80, 90]),
- );
- let on: JoinOn = vec![
- (
- Arc::new(Column::new_with_schema("a1", &left.schema())?),
- Arc::new(Column::new_with_schema("a1", &right.schema())?),
- ),
- (
- Arc::new(Column::new_with_schema("b2", &left.schema())?),
- Arc::new(Column::new_with_schema("b2", &right.schema())?),
- ),
- ];
-
- let (_columns, batches) = join_collect(left, right, on, Inner).await?;
- let expected = vec![
- "+----+----+----+----+----+----+",
- "| a1 | b2 | c1 | a1 | b2 | c2 |",
- "+----+----+----+----+----+----+",
- "| 1 | 1 | 7 | 1 | 1 | 70 |",
- "| 1 | 1 | 7 | 1 | 1 | 80 |",
- "| 1 | 1 | 8 | 1 | 1 | 70 |",
- "| 1 | 1 | 8 | 1 | 1 | 80 |",
- "+----+----+----+----+----+----+",
- ];
- // The output order is important as SMJ preserves sortedness
- assert_batches_sorted_eq!(expected, &batches);
- Ok(())
- }
-
- #[tokio::test]
- async fn join_inner_with_nulls() -> Result<()> {
- let left = build_table_i32_nullable(
- ("a1", &vec![Some(1), Some(1), Some(2), Some(2)]),
- ("b2", &vec![None, Some(1), Some(2), Some(2)]), // null in key field
- ("c1", &vec![Some(1), None, Some(8), Some(9)]), // null in non-key field
- );
- let right = build_table_i32_nullable(
- ("a1", &vec![Some(1), Some(1), Some(2), Some(3)]),
- ("b2", &vec![None, Some(1), Some(2), Some(2)]),
- ("c2", &vec![Some(10), Some(70), Some(80), Some(90)]),
- );
- let on: JoinOn = vec![
- (
- Arc::new(Column::new_with_schema("a1", &left.schema())?),
- Arc::new(Column::new_with_schema("a1", &right.schema())?),
- ),
- (
- Arc::new(Column::new_with_schema("b2", &left.schema())?),
- Arc::new(Column::new_with_schema("b2", &right.schema())?),
- ),
- ];
-
- let (_, batches) = join_collect(left, right, on, Inner).await?;
- let expected = vec![
- "+----+----+----+----+----+----+",
- "| a1 | b2 | c1 | a1 | b2 | c2 |",
- "+----+----+----+----+----+----+",
- "| 1 | 1 | | 1 | 1 | 70 |",
- "| 2 | 2 | 8 | 2 | 2 | 80 |",
- "| 2 | 2 | 9 | 2 | 2 | 80 |",
- "+----+----+----+----+----+----+",
- ];
- // The output order is important as SMJ preserves sortedness
- assert_batches_sorted_eq!(expected, &batches);
- Ok(())
- }
-
- #[tokio::test]
- async fn join_left_one() -> Result<()> {
- let left = build_table(
- ("a1", &vec![1, 2, 3]),
- ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
- ("c1", &vec![7, 8, 9]),
- );
- let right = build_table(
- ("a2", &vec![10, 20, 30]),
- ("b1", &vec![4, 5, 6]),
- ("c2", &vec![70, 80, 90]),
- );
- let on: JoinOn = vec![(
- Arc::new(Column::new_with_schema("b1", &left.schema())?),
- Arc::new(Column::new_with_schema("b1", &right.schema())?),
- )];
-
- let (_, batches) = join_collect(left, right, on, Left).await?;
- let expected = vec![
- "+----+----+----+----+----+----+",
- "| a1 | b1 | c1 | a2 | b1 | c2 |",
- "+----+----+----+----+----+----+",
- "| 1 | 4 | 7 | 10 | 4 | 70 |",
- "| 2 | 5 | 8 | 20 | 5 | 80 |",
- "| 3 | 7 | 9 | | | |",
- "+----+----+----+----+----+----+",
- ];
- // The output order is important as SMJ preserves sortedness
- assert_batches_sorted_eq!(expected, &batches);
- Ok(())
- }
-
- #[tokio::test]
- async fn join_right_one() -> Result<()> {
- let left = build_table(
- ("a1", &vec![1, 2, 3]),
- ("b1", &vec![4, 5, 7]),
- ("c1", &vec![7, 8, 9]),
- );
- let right = build_table(
- ("a2", &vec![10, 20, 30]),
- ("b1", &vec![4, 5, 6]), // 6 does not exist on the left
- ("c2", &vec![70, 80, 90]),
- );
- let on: JoinOn = vec![(
- Arc::new(Column::new_with_schema("b1", &left.schema())?),
- Arc::new(Column::new_with_schema("b1", &right.schema())?),
- )];
-
- let (_, batches) = join_collect(left, right, on, Right).await?;
- let expected = vec![
- "+----+----+----+----+----+----+",
- "| a1 | b1 | c1 | a2 | b1 | c2 |",
- "+----+----+----+----+----+----+",
- "| 1 | 4 | 7 | 10 | 4 | 70 |",
- "| 2 | 5 | 8 | 20 | 5 | 80 |",
- "| | | | 30 | 6 | 90 |",
- "+----+----+----+----+----+----+",
- ];
- // The output order is important as SMJ preserves sortedness
- assert_batches_sorted_eq!(expected, &batches);
- Ok(())
- }
-
- #[tokio::test]
- async fn join_full_one() -> Result<()> {
- let left = build_table(
- ("a1", &vec![1, 2, 3]),
- ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
- ("c1", &vec![7, 8, 9]),
- );
- let right = build_table(
- ("a2", &vec![10, 20, 30]),
- ("b2", &vec![4, 5, 6]),
- ("c2", &vec![70, 80, 90]),
- );
- let on: JoinOn = vec![(
- Arc::new(Column::new_with_schema("b1", &left.schema())?),
- Arc::new(Column::new_with_schema("b2", &right.schema())?),
- )];
-
- let (_, batches) = join_collect(left, right, on, Full).await?;
- let expected = vec![
- "+----+----+----+----+----+----+",
- "| a1 | b1 | c1 | a2 | b2 | c2 |",
- "+----+----+----+----+----+----+",
- "| | | | 30 | 6 | 90 |",
- "| 1 | 4 | 7 | 10 | 4 | 70 |",
- "| 2 | 5 | 8 | 20 | 5 | 80 |",
- "| 3 | 7 | 9 | | | |",
- "+----+----+----+----+----+----+",
- ];
- assert_batches_sorted_eq!(expected, &batches);
- Ok(())
- }
-
- #[tokio::test]
- async fn join_anti() -> Result<()> {
- let left = build_table(
- ("a1", &vec![1, 2, 2, 3, 5]),
- ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right
- ("c1", &vec![7, 8, 8, 9, 11]),
- );
- let right = build_table(
- ("a2", &vec![10, 20, 30]),
- ("b1", &vec![4, 5, 6]),
- ("c2", &vec![70, 80, 90]),
- );
- let on: JoinOn = vec![(
- Arc::new(Column::new_with_schema("b1", &left.schema())?),
- Arc::new(Column::new_with_schema("b1", &right.schema())?),
- )];
-
- let (_, batches) = join_collect(left, right, on, LeftAnti).await?;
- let expected = vec![
- "+----+----+----+",
- "| a1 | b1 | c1 |",
- "+----+----+----+",
- "| 3 | 7 | 9 |",
- "| 5 | 7 | 11 |",
- "+----+----+----+",
- ];
- // The output order is important as SMJ preserves sortedness
- assert_batches_sorted_eq!(expected, &batches);
- Ok(())
- }
-
- #[tokio::test]
- async fn join_semi() -> Result<()> {
- let left = build_table(
- ("a1", &vec![1, 2, 2, 3]),
- ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right
- ("c1", &vec![7, 8, 8, 9]),
- );
- let right = build_table(
- ("a2", &vec![10, 20, 30]),
- ("b1", &vec![4, 5, 6]), // 5 is double on the right
- ("c2", &vec![70, 80, 90]),
- );
- let on: JoinOn = vec![(
- Arc::new(Column::new_with_schema("b1", &left.schema())?),
- Arc::new(Column::new_with_schema("b1", &right.schema())?),
- )];
-
- let (_, batches) = join_collect(left, right, on, LeftSemi).await?;
- let expected = vec![
- "+----+----+----+",
- "| a1 | b1 | c1 |",
- "+----+----+----+",
- "| 1 | 4 | 7 |",
- "| 2 | 5 | 8 |",
- "| 2 | 5 | 8 |",
- "+----+----+----+",
- ];
- // The output order is important as SMJ preserves sortedness
- assert_batches_sorted_eq!(expected, &batches);
- Ok(())
- }
-
- #[tokio::test]
- async fn join_with_duplicated_column_names() -> Result<()> {
- let left = build_table(
- ("a", &vec![1, 2, 3]),
- ("b", &vec![4, 5, 7]),
- ("c", &vec![7, 8, 9]),
- );
- let right = build_table(
- ("a", &vec![10, 20, 30]),
- ("b", &vec![1, 2, 7]),
- ("c", &vec![70, 80, 90]),
- );
- let on: JoinOn = vec![(
- // join on a=b so there are duplicate column names on unjoined columns
- Arc::new(Column::new_with_schema("a", &left.schema())?),
- Arc::new(Column::new_with_schema("b", &right.schema())?),
- )];
-
- let (_, batches) = join_collect(left, right, on, Inner).await?;
- let expected = vec![
- "+---+---+---+----+---+----+",
- "| a | b | c | a | b | c |",
- "+---+---+---+----+---+----+",
- "| 1 | 4 | 7 | 10 | 1 | 70 |",
- "| 2 | 5 | 8 | 20 | 2 | 80 |",
- "+---+---+---+----+---+----+",
- ];
- // The output order is important as SMJ preserves sortedness
- assert_batches_sorted_eq!(expected, &batches);
- Ok(())
- }
-
- #[tokio::test]
- async fn join_date32() -> Result<()> {
- let left = build_date_table(
- ("a1", &vec![1, 2, 3]),
- ("b1", &vec![19107, 19108, 19108]), // this has a repetition
- ("c1", &vec![7, 8, 9]),
- );
- let right = build_date_table(
- ("a2", &vec![10, 20, 30]),
- ("b1", &vec![19107, 19108, 19109]),
- ("c2", &vec![70, 80, 90]),
- );
-
- let on: JoinOn = vec![(
- Arc::new(Column::new_with_schema("b1", &left.schema())?),
- Arc::new(Column::new_with_schema("b1", &right.schema())?),
- )];
-
- let (_, batches) = join_collect(left, right, on, Inner).await?;
-
- let expected = vec![
- "+------------+------------+------------+------------+------------+------------+",
- "| a1 | b1 | c1 | a2 | b1 | c2 |",
- "+------------+------------+------------+------------+------------+------------+",
- "| 1970-01-02 | 2022-04-25 | 1970-01-08 | 1970-01-11 | 2022-04-25 | 1970-03-12 |",
- "| 1970-01-03 | 2022-04-26 | 1970-01-09 | 1970-01-21 | 2022-04-26 | 1970-03-22 |",
- "| 1970-01-04 | 2022-04-26 | 1970-01-10 | 1970-01-21 | 2022-04-26 | 1970-03-22 |",
- "+------------+------------+------------+------------+------------+------------+",
- ];
- // The output order is important as SMJ preserves sortedness
- assert_batches_sorted_eq!(expected, &batches);
- Ok(())
- }
-
- #[tokio::test]
- async fn join_date64() -> Result<()> {
- let left = build_date64_table(
- ("a1", &vec![1, 2, 3]),
- ("b1", &vec![1650703441000, 1650903441000, 1650903441000]), // this has a repetition
- ("c1", &vec![7, 8, 9]),
- );
- let right = build_date64_table(
- ("a2", &vec![10, 20, 30]),
- ("b1", &vec![1650703441000, 1650503441000, 1650903441000]),
- ("c2", &vec![70, 80, 90]),
- );
-
- let on: JoinOn = vec![(
- Arc::new(Column::new_with_schema("b1", &left.schema())?),
- Arc::new(Column::new_with_schema("b1", &right.schema())?),
- )];
-
- let (_, batches) = join_collect(left, right, on, Inner).await?;
- let expected = vec![
- "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+",
- "| a1 | b1 | c1 | a2 | b1 | c2 |",
- "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+",
- "| 1970-01-01T00:00:00.001 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.007 | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.070 |",
- "| 1970-01-01T00:00:00.002 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |",
- "| 1970-01-01T00:00:00.003 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |",
- "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+",
- ];
-
- // The output order is important as SMJ preserves sortedness
- assert_batches_sorted_eq!(expected, &batches);
- Ok(())
- }
-
- #[tokio::test]
- async fn join_left_sort_order() -> Result<()> {
- let left = build_table(
- ("a1", &vec![0, 1, 2, 3, 4, 5]),
- ("b1", &vec![3, 4, 5, 6, 6, 7]),
- ("c1", &vec![4, 5, 6, 7, 8, 9]),
- );
- let right = build_table(
- ("a2", &vec![0, 10, 20, 30, 40]),
- ("b2", &vec![2, 4, 6, 6, 8]),
- ("c2", &vec![50, 60, 70, 80, 90]),
- );
- let on: JoinOn = vec![(
- Arc::new(Column::new_with_schema("b1", &left.schema())?),
- Arc::new(Column::new_with_schema("b2", &right.schema())?),
- )];
-
- let (_, batches) = join_collect(left, right, on, Left).await?;
- let expected = vec![
- "+----+----+----+----+----+----+",
- "| a1 | b1 | c1 | a2 | b2 | c2 |",
- "+----+----+----+----+----+----+",
- "| 0 | 3 | 4 | | | |",
- "| 1 | 4 | 5 | 10 | 4 | 60 |",
- "| 2 | 5 | 6 | | | |",
- "| 3 | 6 | 7 | 20 | 6 | 70 |",
- "| 3 | 6 | 7 | 30 | 6 | 80 |",
- "| 4 | 6 | 8 | 20 | 6 | 70 |",
- "| 4 | 6 | 8 | 30 | 6 | 80 |",
- "| 5 | 7 | 9 | | | |",
- "+----+----+----+----+----+----+",
- ];
- assert_batches_sorted_eq!(expected, &batches);
- Ok(())
- }
-
- #[tokio::test]
- async fn join_right_sort_order() -> Result<()> {
- let left = build_table(
- ("a1", &vec![0, 1, 2, 3]),
- ("b1", &vec![3, 4, 5, 7]),
- ("c1", &vec![6, 7, 8, 9]),
- );
- let right = build_table(
- ("a2", &vec![0, 10, 20, 30]),
- ("b2", &vec![2, 4, 5, 6]),
- ("c2", &vec![60, 70, 80, 90]),
- );
- let on: JoinOn = vec![(
- Arc::new(Column::new_with_schema("b1", &left.schema())?),
- Arc::new(Column::new_with_schema("b2", &right.schema())?),
- )];
-
- let (_, batches) = join_collect(left, right, on, Right).await?;
- let expected = vec![
- "+----+----+----+----+----+----+",
- "| a1 | b1 | c1 | a2 | b2 | c2 |",
- "+----+----+----+----+----+----+",
- "| | | | 0 | 2 | 60 |",
- "| 1 | 4 | 7 | 10 | 4 | 70 |",
- "| 2 | 5 | 8 | 20 | 5 | 80 |",
- "| | | | 30 | 6 | 90 |",
- "+----+----+----+----+----+----+",
- ];
- assert_batches_sorted_eq!(expected, &batches);
- Ok(())
- }
-
- #[tokio::test]
- async fn join_left_multiple_batches() -> Result<()> {
- let left_batch_1 = build_table_i32(
- ("a1", &vec![0, 1, 2]),
- ("b1", &vec![3, 4, 5]),
- ("c1", &vec![4, 5, 6]),
- );
- let left_batch_2 = build_table_i32(
- ("a1", &vec![3, 4, 5, 6]),
- ("b1", &vec![6, 6, 7, 9]),
- ("c1", &vec![7, 8, 9, 9]),
- );
- let right_batch_1 = build_table_i32(
- ("a2", &vec![0, 10, 20]),
- ("b2", &vec![2, 4, 6]),
- ("c2", &vec![50, 60, 70]),
- );
- let right_batch_2 = build_table_i32(
- ("a2", &vec![30, 40]),
- ("b2", &vec![6, 8]),
- ("c2", &vec![80, 90]),
- );
- let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
- let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
- let on: JoinOn = vec![(
- Arc::new(Column::new_with_schema("b1", &left.schema())?),
- Arc::new(Column::new_with_schema("b2", &right.schema())?),
- )];
-
- let (_, batches) = join_collect(left, right, on, Left).await?;
- let expected = vec![
- "+----+----+----+----+----+----+",
- "| a1 | b1 | c1 | a2 | b2 | c2 |",
- "+----+----+----+----+----+----+",
- "| 0 | 3 | 4 | | | |",
- "| 1 | 4 | 5 | 10 | 4 | 60 |",
- "| 2 | 5 | 6 | | | |",
- "| 3 | 6 | 7 | 20 | 6 | 70 |",
- "| 3 | 6 | 7 | 30 | 6 | 80 |",
- "| 4 | 6 | 8 | 20 | 6 | 70 |",
- "| 4 | 6 | 8 | 30 | 6 | 80 |",
- "| 5 | 7 | 9 | | | |",
- "| 6 | 9 | 9 | | | |",
- "+----+----+----+----+----+----+",
- ];
- assert_batches_sorted_eq!(expected, &batches);
- Ok(())
- }
-
- #[tokio::test]
- async fn join_right_multiple_batches() -> Result<()> {
- let right_batch_1 = build_table_i32(
- ("a2", &vec![0, 1, 2]),
- ("b2", &vec![3, 4, 5]),
- ("c2", &vec![4, 5, 6]),
- );
- let right_batch_2 = build_table_i32(
- ("a2", &vec![3, 4, 5, 6]),
- ("b2", &vec![6, 6, 7, 9]),
- ("c2", &vec![7, 8, 9, 9]),
- );
- let left_batch_1 = build_table_i32(
- ("a1", &vec![0, 10, 20]),
- ("b1", &vec![2, 4, 6]),
- ("c1", &vec![50, 60, 70]),
- );
- let left_batch_2 = build_table_i32(
- ("a1", &vec![30, 40]),
- ("b1", &vec![6, 8]),
- ("c1", &vec![80, 90]),
- );
- let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
- let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
- let on: JoinOn = vec![(
- Arc::new(Column::new_with_schema("b1", &left.schema())?),
- Arc::new(Column::new_with_schema("b2", &right.schema())?),
- )];
-
- let (_, batches) = join_collect(left, right, on, Right).await?;
- let expected = vec![
- "+----+----+----+----+----+----+",
- "| a1 | b1 | c1 | a2 | b2 | c2 |",
- "+----+----+----+----+----+----+",
- "| | | | 0 | 3 | 4 |",
- "| 10 | 4 | 60 | 1 | 4 | 5 |",
- "| | | | 2 | 5 | 6 |",
- "| 20 | 6 | 70 | 3 | 6 | 7 |",
- "| 30 | 6 | 80 | 3 | 6 | 7 |",
- "| 20 | 6 | 70 | 4 | 6 | 8 |",
- "| 30 | 6 | 80 | 4 | 6 | 8 |",
- "| | | | 5 | 7 | 9 |",
- "| | | | 6 | 9 | 9 |",
- "+----+----+----+----+----+----+",
- ];
- assert_batches_sorted_eq!(expected, &batches);
- Ok(())
- }
-
- #[tokio::test]
- async fn join_full_multiple_batches() -> Result<()> {
- let left_batch_1 = build_table_i32(
- ("a1", &vec![0, 1, 2]),
- ("b1", &vec![3, 4, 5]),
- ("c1", &vec![4, 5, 6]),
- );
- let left_batch_2 = build_table_i32(
- ("a1", &vec![3, 4, 5, 6]),
- ("b1", &vec![6, 6, 7, 9]),
- ("c1", &vec![7, 8, 9, 9]),
- );
- let right_batch_1 = build_table_i32(
- ("a2", &vec![0, 10, 20]),
- ("b2", &vec![2, 4, 6]),
- ("c2", &vec![50, 60, 70]),
- );
- let right_batch_2 = build_table_i32(
- ("a2", &vec![30, 40]),
- ("b2", &vec![6, 8]),
- ("c2", &vec![80, 90]),
- );
- let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
- let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
- let on: JoinOn = vec![(
- Arc::new(Column::new_with_schema("b1", &left.schema())?),
- Arc::new(Column::new_with_schema("b2", &right.schema())?),
- )];
-
- let (_, batches) = join_collect(left, right, on, Full).await?;
- let expected = vec![
- "+----+----+----+----+----+----+",
- "| a1 | b1 | c1 | a2 | b2 | c2 |",
- "+----+----+----+----+----+----+",
- "| | | | 0 | 2 | 50 |",
- "| | | | 40 | 8 | 90 |",
- "| 0 | 3 | 4 | | | |",
- "| 1 | 4 | 5 | 10 | 4 | 60 |",
- "| 2 | 5 | 6 | | | |",
- "| 3 | 6 | 7 | 20 | 6 | 70 |",
- "| 3 | 6 | 7 | 30 | 6 | 80 |",
- "| 4 | 6 | 8 | 20 | 6 | 70 |",
- "| 4 | 6 | 8 | 30 | 6 | 80 |",
- "| 5 | 7 | 9 | | | |",
- "| 6 | 9 | 9 | | | |",
- "+----+----+----+----+----+----+",
- ];
- assert_batches_sorted_eq!(expected, &batches);
- Ok(())
- }
-
- #[tokio::test]
- async fn join_existence_multiple_batches() -> Result<()> {
- let left_batch_1 = build_table_i32(
- ("a1", &vec![0, 1, 2]),
- ("b1", &vec![3, 4, 5]),
- ("c1", &vec![4, 5, 6]),
- );
- let left_batch_2 = build_table_i32(
- ("a1", &vec![3, 4, 5, 6]),
- ("b1", &vec![6, 6, 7, 9]),
- ("c1", &vec![7, 8, 9, 9]),
- );
- let right_batch_1 = build_table_i32(
- ("a2", &vec![0, 10, 20]),
- ("b2", &vec![2, 4, 6]),
- ("c2", &vec![50, 60, 70]),
- );
- let right_batch_2 = build_table_i32(
- ("a2", &vec![30, 40]),
- ("b2", &vec![6, 8]),
- ("c2", &vec![80, 90]),
- );
- let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
- let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
- let on: JoinOn = vec![(
- Arc::new(Column::new_with_schema("b1", &left.schema())?),
- Arc::new(Column::new_with_schema("b2", &right.schema())?),
- )];
-
- let (_, batches) = join_collect(left, right, on, Existence).await?;
- let expected = vec![
- "+----+----+----+----------+",
- "| a1 | b1 | c1 | exists#0 |",
- "+----+----+----+----------+",
- "| 0 | 3 | 4 | false |",
- "| 1 | 4 | 5 | true |",
- "| 2 | 5 | 6 | false |",
- "| 3 | 6 | 7 | true |",
- "| 4 | 6 | 8 | true |",
- "| 5 | 7 | 9 | false |",
- "| 6 | 9 | 9 | false |",
- "+----+----+----+----------+",
- ];
- assert_batches_sorted_eq!(expected, &batches);
- Ok(())
- }
+ fn num_output_rows(&self) -> usize;
}
diff --git a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
index 1d867a5..ba127ef 100644
--- a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
+++ b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
@@ -150,7 +150,7 @@
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
- condition: Option[Expression]): NativeBroadcastJoinBase =
+ broadcastSide: BroadcastSide): NativeBroadcastJoinBase =
NativeBroadcastJoinExec(
left,
right,
@@ -158,7 +158,7 @@
leftKeys,
rightKeys,
joinType,
- condition)
+ broadcastSide)
override def createNativeBroadcastNestedLoopJoinExec(
left: SparkPlan,
diff --git a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala
index 3fc6649..de3f5f8 100644
--- a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala
+++ b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala
@@ -21,12 +21,16 @@
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.optimizer.BuildLeft
+import org.apache.spark.sql.catalyst.optimizer.BuildRight
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.physical.BroadcastDistribution
import org.apache.spark.sql.catalyst.plans.physical.Distribution
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.blaze.plan.BroadcastLeft
+import org.apache.spark.sql.execution.blaze.plan.BroadcastRight
+import org.apache.spark.sql.execution.blaze.plan.BroadcastSide
import org.apache.spark.sql.execution.blaze.plan.NativeBroadcastJoinBase
import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
import org.apache.spark.sql.execution.joins.HashedRelationInfo
@@ -39,7 +43,7 @@
override val leftKeys: Seq[Expression],
override val rightKeys: Seq[Expression],
override val joinType: JoinType,
- override val condition: Option[Expression])
+ broadcastSide: BroadcastSide)
extends NativeBroadcastJoinBase(
left,
right,
@@ -47,9 +51,11 @@
leftKeys,
rightKeys,
joinType,
- condition)
+ broadcastSide)
with HashJoin {
+ override def condition: Option[Expression] = None
+
override def requiredChildDistribution: Seq[Distribution] = {
val mode = HashedRelationBroadcastMode(buildBoundKeys, isNullAware = false)
BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
@@ -65,7 +71,10 @@
throw new NotImplementedError("NativeBroadcastJoin does not support codegen")
}
- override def buildSide: BuildSide = BuildLeft
+ override def buildSide: BuildSide = broadcastSide match {
+ case BroadcastLeft => BuildLeft
+ case BroadcastRight => BuildRight
+ }
override protected def withNewChildrenInternal(
newLeft: SparkPlan,
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala
index 3a73a89..361c633 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala
@@ -337,8 +337,8 @@
logDebug(s"Converting SortMergeJoinExec: ${Shims.get.simpleStringWithNodeId(exec)}")
Shims.get.createNativeSortMergeJoinExec(
- convertToNative(left),
- convertToNative(right),
+ addRenameColumnsExec(convertToNative(left)),
+ addRenameColumnsExec(convertToNative(right)),
leftKeys,
rightKeys,
joinType,
@@ -356,84 +356,33 @@
exec.left,
exec.right)
logDebug(s"Converting BroadcastHashJoinExec: ${Shims.get.simpleStringWithNodeId(exec)}")
- logDebug(s" leftKeys: ${exec.leftKeys}")
- logDebug(s" rightKeys: ${exec.rightKeys}")
- logDebug(s" joinType: ${exec.joinType}")
- logDebug(s" buildSide: ${exec.buildSide}")
- logDebug(s" condition: ${exec.condition}")
- var (hashed, hashedKeys, nativeProbed, probedKeys) = buildSide match {
+ logDebug(s" leftKeys: $leftKeys")
+ logDebug(s" rightKeys: $rightKeys")
+ logDebug(s" joinType: $joinType")
+ logDebug(s" buildSide: $buildSide")
+ logDebug(s" condition: $condition")
+ assert(condition.isEmpty, "join condition is not supported")
+
+ // verify build side is native
+ buildSide match {
case BuildRight =>
assert(NativeHelper.isNative(right), "broadcast join build side is not native")
- val convertedLeft = convertToNative(left)
- (right, rightKeys, convertedLeft, leftKeys)
-
case BuildLeft =>
assert(NativeHelper.isNative(left), "broadcast join build side is not native")
- val convertedRight = convertToNative(right)
- (left, leftKeys, convertedRight, rightKeys)
-
- case _ =>
- // scalastyle:off throwerror
- throw new NotImplementedError(
- "Ignore BroadcastHashJoin with unsupported children structure")
}
- var modifiedHashedKeys = hashedKeys
- var modifiedProbedKeys = probedKeys
- var needPostProject = false
+ Shims.get.createNativeBroadcastJoinExec(
+ addRenameColumnsExec(convertToNative(left)),
+ addRenameColumnsExec(convertToNative(right)),
+ exec.outputPartitioning,
+ leftKeys,
+ rightKeys,
+ joinType,
+ buildSide match {
+ case BuildLeft => BroadcastLeft
+ case BuildRight => BroadcastRight
+ })
- if (hashedKeys.exists(!_.isInstanceOf[AttributeReference])) {
- val (keys, exec) = buildJoinColumnsProject(hashed, hashedKeys)
- modifiedHashedKeys = keys
- hashed = exec
- needPostProject = true
- }
- if (probedKeys.exists(!_.isInstanceOf[AttributeReference])) {
- val (keys, exec) = buildJoinColumnsProject(nativeProbed, probedKeys)
- modifiedProbedKeys = keys
- nativeProbed = exec
- needPostProject = true
- }
-
- val modifiedJoinType = buildSide match {
- case BuildLeft => joinType
- case BuildRight =>
- needPostProject = true
- val modifiedJoinType = joinType match { // reverse join type
- case Inner => Inner
- case FullOuter => FullOuter
- case LeftOuter => RightOuter
- case RightOuter => LeftOuter
- case _ =>
- throw new NotImplementedError(
- "BHJ Semi/Anti join with BuildRight is not yet supported")
- }
- modifiedJoinType
- }
-
- val bhjOrig = BroadcastHashJoinExec(
- modifiedHashedKeys,
- modifiedProbedKeys,
- modifiedJoinType,
- BuildLeft,
- condition,
- addRenameColumnsExec(hashed),
- addRenameColumnsExec(nativeProbed))
-
- val bhj = Shims.get.createNativeBroadcastJoinExec(
- bhjOrig.left,
- bhjOrig.right,
- bhjOrig.outputPartitioning,
- bhjOrig.leftKeys,
- bhjOrig.rightKeys,
- bhjOrig.joinType,
- bhjOrig.condition)
-
- if (needPostProject) {
- buildPostJoinProject(bhj, exec.output)
- } else {
- bhj
- }
} catch {
case e @ (_: NotImplementedError | _: Exception) =>
val underlyingBroadcast = exec.buildSide match {
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala
index a8aaad2..fe2bd2c 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala
@@ -79,7 +79,7 @@
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
- condition: Option[Expression]): NativeBroadcastJoinBase
+ broadcastSide: BroadcastSide): NativeBroadcastJoinBase
def createNativeBroadcastNestedLoopJoinExec(
left: SparkPlan,
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeBase.scala
index a0b8351..1f10768 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeBase.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeBase.scala
@@ -61,6 +61,7 @@
import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.types.BinaryType
abstract class NativeBroadcastExchangeBase(mode: BroadcastMode, override val child: SparkPlan)
extends BroadcastExchangeLike
@@ -91,9 +92,6 @@
override def doPrepare(): Unit = {
// Materialize the future.
relationFuture
- relationFuture
- relationFuture
- relationFuture
}
override def doExecuteBroadcast[T](): Broadcast[T] = {
@@ -152,13 +150,33 @@
Channels.newChannel(new ByteArrayInputStream(bytes))
})
}
+
+ // native hash map schema = nullable native schema + key column
+ val nativeHashMapSchema = pb.Schema
+ .newBuilder()
+ .addAllColumns(nativeSchema
+ .getColumnsList
+ .asScala
+ .map(field => pb.Field
+ .newBuilder()
+ .setName(field.getName)
+ .setArrowType(field.getArrowType)
+ .setNullable(true)
+ .build())
+ .asJava)
+ .addColumns(pb.Field
+ .newBuilder()
+ .setName("~key")
+ .setArrowType(NativeConverters.convertDataType(BinaryType))
+ .setNullable(true))
+
JniBridge.resourcesMap.put(resourceId, () => provideIpcIterator())
pb.PhysicalPlanNode
.newBuilder()
.setIpcReader(
pb.IpcReaderExecNode
.newBuilder()
- .setSchema(nativeSchema)
+ .setSchema(nativeHashMapSchema)
.setNumPartitions(1)
.setIpcProviderResourceId(resourceId)
.build())
@@ -265,39 +283,21 @@
keys: Seq[Expression],
nativeSchema: pb.Schema): Array[Array[Byte]] = {
- if (!BlazeConf.BHJ_FALLBACKS_TO_SMJ_ENABLE.booleanConf() || keys.isEmpty) {
- return collectedData // no need to sort data in driver side
- }
-
val readerIpcProviderResourceId = s"BuildBroadcastDataReader:${UUID.randomUUID()}"
val readerExec = pb.IpcReaderExecNode
.newBuilder()
.setSchema(nativeSchema)
.setIpcProviderResourceId(readerIpcProviderResourceId)
- val sortExec = pb.SortExecNode
+ val buildHashMapExec = pb.BroadcastJoinBuildHashMapExecNode
.newBuilder()
.setInput(pb.PhysicalPlanNode.newBuilder().setIpcReader(readerExec))
- .addAllExpr(
- keys
- .map(key => {
- pb.PhysicalExprNode
- .newBuilder()
- .setSort(
- pb.PhysicalSortExprNode
- .newBuilder()
- .setExpr(NativeConverters.convertExpr(key))
- .setAsc(true)
- .setNullsFirst(true)
- .build())
- .build()
- })
- .asJava)
+ .addAllKeys(keys.map(key => NativeConverters.convertExpr(key)).asJava)
val writerIpcProviderResourceId = s"BuildBroadcastDataWriter:${UUID.randomUUID()}"
val writerExec = pb.IpcWriterExecNode
.newBuilder()
- .setInput(pb.PhysicalPlanNode.newBuilder().setSort(sortExec))
+ .setInput(pb.PhysicalPlanNode.newBuilder().setBroadcastJoinBuildHashMap(buildHashMapExec))
.setIpcConsumerResourceId(writerIpcProviderResourceId)
// build native sorter
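
The hunk above replaces the driver-side sort of the broadcast data with a `BroadcastJoinBuildHashMapExecNode` and widens the IPC reader schema with a binary `~key` column. A minimal, self-contained sketch (not part of this patch; `Row`, `buildHashMap`, `probe`, and the sample data are made-up placeholders) of why a keyed hash map is built once on the build side instead of sorting: every probe key then becomes a constant-time lookup rather than a merge scan.

```scala
// Hypothetical illustration only; Blaze's real build happens natively over Arrow batches.
object BroadcastBuildSketch {
  final case class Row(a: Int, b: String)

  // Build side: group every broadcast row under its join key once, up front.
  def buildHashMap(rows: Seq[Row], joinKey: Row => Int): Map[Int, Seq[Row]] =
    rows.groupBy(joinKey)

  // Probe side: each streamed key is a lookup instead of a sorted-merge step.
  def probe(hashMap: Map[Int, Seq[Row]], key: Int): Seq[Row] =
    hashMap.getOrElse(key, Seq.empty)

  def main(args: Array[String]): Unit = {
    val broadcast = Seq(Row(1, "x"), Row(1, "y"), Row(2, "z"))
    val hashMap   = buildHashMap(broadcast, _.a)
    println(probe(hashMap, 1)) // List(Row(1,x), Row(1,y))
    println(probe(hashMap, 3)) // List()
  }
}
```
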
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala
index db0e7cb..8556d88 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala
@@ -41,17 +41,17 @@
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
- condition: Option[Expression])
+ broadcastSide: BroadcastSide)
extends BinaryExecNode
with NativeSupports {
- assert(condition.isEmpty, "join filter is not supported")
-
override lazy val metrics: Map[String, SQLMetric] = SortedMap[String, SQLMetric]() ++ Map(
NativeHelper
.getDefaultNativeMetrics(sparkContext)
.toSeq: _*)
+ private def nativeSchema = Util.getNativeSchema(output)
+
private def nativeJoinOn = leftKeys.zip(rightKeys).map { case (leftKey, rightKey) =>
val leftKeyExpr = NativeConverters.convertExpr(leftKey)
val rightKeyExpr = NativeConverters.convertExpr(rightKey)
@@ -64,44 +64,67 @@
private def nativeJoinType = NativeConverters.convertJoinType(joinType)
- private def nativeJoinFilter =
- condition.map(NativeConverters.convertJoinFilter(_, left.output, right.output))
+ private def nativeBroadcastSide = broadcastSide match {
+ case BroadcastLeft => pb.JoinSide.LEFT_SIDE
+ case BroadcastRight => pb.JoinSide.RIGHT_SIDE
+ }
// check whether native converting is supported
+ nativeSchema
nativeJoinType
- nativeJoinFilter
+ nativeJoinOn
+ nativeBroadcastSide
override def doExecuteNative(): NativeRDD = {
val leftRDD = NativeHelper.executeNative(left)
val rightRDD = NativeHelper.executeNative(right)
val nativeMetrics = MetricNode(metrics, leftRDD.metrics :: rightRDD.metrics :: Nil)
+ val nativeSchema = this.nativeSchema
val nativeJoinType = this.nativeJoinType
val nativeJoinOn = this.nativeJoinOn
- val nativeJoinFilter = this.nativeJoinFilter
- val partitions = rightRDD.partitions
+ val partitions = broadcastSide match {
+ case BroadcastLeft => rightRDD.partitions
+ case BroadcastRight => leftRDD.partitions
+ }
new NativeRDD(
sparkContext,
nativeMetrics,
partitions,
- rddDependencies = new OneToOneDependency(rightRDD) :: Nil,
+ rddDependencies = broadcastSide match {
+ case BroadcastLeft => new OneToOneDependency(rightRDD) :: Nil
+ case BroadcastRight => new OneToOneDependency(leftRDD) :: Nil
+ },
rightRDD.isShuffleReadFull,
(partition, context) => {
val partition0 = new Partition() {
override def index: Int = 0
}
- val leftChild = leftRDD.nativePlan(partition0, context)
- val rightChild = rightRDD.nativePlan(rightRDD.partitions(partition.index), context)
+ val (leftChild, rightChild) = broadcastSide match {
+ case BroadcastLeft => (
+ leftRDD.nativePlan(partition0, context),
+ rightRDD.nativePlan(rightRDD.partitions(partition.index), context),
+ )
+ case BroadcastRight => (
+ leftRDD.nativePlan(leftRDD.partitions(partition.index), context),
+ rightRDD.nativePlan(partition0, context),
+ )
+ }
val broadcastJoinExec = pb.BroadcastJoinExecNode
.newBuilder()
+ .setSchema(nativeSchema)
.setLeft(leftChild)
.setRight(rightChild)
.setJoinType(nativeJoinType)
+ .setBroadcastSide(nativeBroadcastSide)
.addAllOn(nativeJoinOn.asJava)
- nativeJoinFilter.foreach(joinFilter => broadcastJoinExec.setJoinFilter(joinFilter))
pb.PhysicalPlanNode.newBuilder().setBroadcastJoin(broadcastJoinExec).build()
},
friendlyName = "NativeRDD.BroadcastJoin")
}
}
+
+class BroadcastSide {}
+case object BroadcastLeft extends BroadcastSide {}
+case object BroadcastRight extends BroadcastSide {}
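
For reference, a minimal standalone sketch (not Blaze code; `Side`, `Child`, and the sample values are assumptions) of the idea behind the new `BroadcastSide` marker: the streamed child keeps its own partitioning and drives the output partitions, while the broadcast child is read as a single partition, mirroring the partition/dependency switch in `doExecuteNative` above.

```scala
// Hypothetical illustration of streamed-vs-broadcast child selection.
sealed trait Side
case object SketchBroadcastLeft extends Side
case object SketchBroadcastRight extends Side

final case class Child(name: String, numPartitions: Int)

object BroadcastSideSketch {
  // Returns (streamed child, broadcast child) for a given broadcast side.
  def streamedAndBroadcast(left: Child, right: Child, side: Side): (Child, Child) =
    side match {
      case SketchBroadcastLeft  => (right, left) // left is broadcast, right is streamed
      case SketchBroadcastRight => (left, right) // right is broadcast, left is streamed
    }

  def main(args: Array[String]): Unit = {
    val (streamed, broadcast) =
      streamedAndBroadcast(Child("left", 1), Child("right", 200), SketchBroadcastLeft)
    println(s"streamed=${streamed.name}, partitions=${streamed.numPartitions}")
    println(s"broadcast=${broadcast.name}, read as a single partition")
  }
}
```
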