joins refactoring
diff --git a/Cargo.lock b/Cargo.lock index ef23afd..51ae230 100644 --- a/Cargo.lock +++ b/Cargo.lock
@@ -923,6 +923,7 @@ "arrow", "async-trait", "base64 0.22.1", + "bitvec", "blaze-jni-bridge", "byteorder", "bytes",
diff --git a/native-engine/blaze-jni-bridge/src/conf.rs b/native-engine/blaze-jni-bridge/src/conf.rs index dd476ed..9eccc0e 100644 --- a/native-engine/blaze-jni-bridge/src/conf.rs +++ b/native-engine/blaze-jni-bridge/src/conf.rs
@@ -41,6 +41,8 @@ define_conf!(BooleanConf, PARTIAL_AGG_SKIPPING_ENABLE); define_conf!(DoubleConf, PARTIAL_AGG_SKIPPING_RATIO); define_conf!(IntConf, PARTIAL_AGG_SKIPPING_MIN_ROWS); +define_conf!(BooleanConf, PARQUET_ENABLE_PAGE_FILTERING); +define_conf!(BooleanConf, PARQUET_ENABLE_BLOOM_FILTER); pub trait BooleanConf { fn key(&self) -> &'static str;
diff --git a/native-engine/blaze-serde/proto/blaze.proto b/native-engine/blaze-serde/proto/blaze.proto index 9818a13..3e096a8 100644 --- a/native-engine/blaze-serde/proto/blaze.proto +++ b/native-engine/blaze-serde/proto/blaze.proto
@@ -35,19 +35,20 @@ FilterExecNode filter = 8; UnionExecNode union = 9; SortMergeJoinExecNode sort_merge_join = 10; - BroadcastJoinExecNode broadcast_join = 11; - RenameColumnsExecNode rename_columns = 12; - EmptyPartitionsExecNode empty_partitions = 13; - AggExecNode agg = 14; - LimitExecNode limit = 15; - FFIReaderExecNode ffi_reader = 16; - CoalesceBatchesExecNode coalesce_batches = 17; - ExpandExecNode expand = 18; - RssShuffleWriterExecNode rss_shuffle_writer= 19; - WindowExecNode window = 20; - GenerateExecNode generate = 21; - ParquetSinkExecNode parquet_sink = 22; - BroadcastNestedLoopJoinExecNode broadcast_nested_loop_join = 23; + BroadcastJoinBuildHashMapExecNode broadcast_join_build_hash_map = 11; + BroadcastJoinExecNode broadcast_join = 12; + RenameColumnsExecNode rename_columns = 13; + EmptyPartitionsExecNode empty_partitions = 14; + AggExecNode agg = 15; + LimitExecNode limit = 16; + FFIReaderExecNode ffi_reader = 17; + CoalesceBatchesExecNode coalesce_batches = 18; + ExpandExecNode expand = 19; + RssShuffleWriterExecNode rss_shuffle_writer= 20; + WindowExecNode window = 21; + GenerateExecNode generate = 22; + ParquetSinkExecNode parquet_sink = 23; + BroadcastNestedLoopJoinExecNode broadcast_nested_loop_join = 24; } } @@ -398,20 +399,28 @@ } message SortMergeJoinExecNode { - PhysicalPlanNode left = 1; - PhysicalPlanNode right = 2; - repeated JoinOn on = 3; - repeated SortOptions sort_options = 4; - JoinType join_type = 5; - JoinFilter join_filter = 6; + Schema schema = 1; + PhysicalPlanNode left = 2; + PhysicalPlanNode right = 3; + repeated JoinOn on = 4; + repeated SortOptions sort_options = 5; + JoinType join_type = 6; + JoinFilter join_filter = 7; +} + +message BroadcastJoinBuildHashMapExecNode { + PhysicalPlanNode input = 1; + repeated PhysicalExprNode keys =2; } message BroadcastJoinExecNode { - PhysicalPlanNode left = 1; - PhysicalPlanNode right = 2; - repeated JoinOn on = 3; - JoinType join_type = 4; - JoinFilter join_filter = 5; + Schema 
schema = 1; + PhysicalPlanNode left = 2; + PhysicalPlanNode right = 3; + repeated JoinOn on = 4; + JoinType join_type = 5; + JoinSide broadcast_side = 6; + string cached_build_hash_map_id = 7; } message BroadcastNestedLoopJoinExecNode { @@ -438,6 +447,7 @@ FULL = 3; SEMI = 4; ANTI = 5; + EXISTENCE = 6; } message SortOptions { @@ -456,8 +466,8 @@ } message JoinOn { - PhysicalColumn left = 1; - PhysicalColumn right = 2; + PhysicalExprNode left = 1; + PhysicalExprNode right = 2; } message ProjectionExecNode {
diff --git a/native-engine/blaze-serde/src/from_proto.rs b/native-engine/blaze-serde/src/from_proto.rs index 1f4e824..bbd97c2 100644 --- a/native-engine/blaze-serde/src/from_proto.rs +++ b/native-engine/blaze-serde/src/from_proto.rs
@@ -61,6 +61,7 @@ use datafusion_ext_plans::{ agg::{create_agg, AggExecMode, AggExpr, AggFunction, AggMode, GroupingExpr}, agg_exec::AggExec, + broadcast_join_build_hash_map_exec::BroadcastJoinBuildHashMapExec, broadcast_join_exec::BroadcastJoinExec, broadcast_nested_loop_join_exec::BroadcastNestedLoopJoinExec, debug_exec::DebugExec, @@ -72,6 +73,7 @@ generate_exec::GenerateExec, ipc_reader_exec::IpcReaderExec, ipc_writer_exec::IpcWriterExec, + joins::join_utils::JoinType, limit_exec::LimitExec, parquet_exec::ParquetExec, parquet_sink_exec::ParquetSinkExec, @@ -89,7 +91,7 @@ use crate::{ convert_box_required, convert_required, error::PlanSerDeError, - from_proto_binary_op, into_required, proto_error, protobuf, + from_proto_binary_op, proto_error, protobuf, protobuf::{ physical_expr_node::ExprType, physical_plan_node::PhysicalPlanType, GenerateFunction, }, @@ -182,19 +184,20 @@ ))) } PhysicalPlanType::SortMergeJoin(sort_merge_join) => { + let schema = Arc::new(convert_required!(sort_merge_join.schema)?); let left: Arc<dyn ExecutionPlan> = convert_box_required!(sort_merge_join.left)?; let right: Arc<dyn ExecutionPlan> = convert_box_required!(sort_merge_join.right)?; let on: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = sort_merge_join .on .iter() .map(|col| { - let left_col: Column = into_required!(col.left)?; - let left_col_binded: Arc<dyn PhysicalExpr> = - Arc::new(Column::new_with_schema(left_col.name(), &left.schema())?); - let right_col: Column = into_required!(col.right)?; - let right_col_binded: Arc<dyn PhysicalExpr> = - Arc::new(Column::new_with_schema(right_col.name(), &right.schema())?); - Ok((left_col_binded, right_col_binded)) + let left_key = + try_parse_physical_expr(&col.left.as_ref().unwrap(), &left.schema())?; + let left_key_binded = bind(left_key, &left.schema())?; + let right_key = + try_parse_physical_expr(&col.right.as_ref().unwrap(), &right.schema())?; + let right_key_binded = bind(right_key, &right.schema())?; + Ok((left_key_binded, 
right_key_binded)) }) .collect::<Result<_, Self::Error>>()?; @@ -210,38 +213,14 @@ let join_type = protobuf::JoinType::try_from(sort_merge_join.join_type) .expect("invalid JoinType"); - let join_filter = sort_merge_join - .join_filter - .as_ref() - .map(|f| { - let schema = Arc::new(convert_required!(f.schema)?); - let expression = try_parse_physical_expr_required(&f.expression, &schema)?; - let column_indices = f - .column_indices - .iter() - .map(|i| { - let side = - protobuf::JoinSide::try_from(i.side).expect("invalid JoinSide"); - Ok(ColumnIndex { - index: i.index as usize, - side: side.into(), - }) - }) - .collect::<Result<Vec<_>, PlanSerDeError>>()?; - - Ok(JoinFilter::new( - bind(expression, &schema)?, - column_indices, - schema.as_ref().clone(), - )) - }) - .map_or(Ok(None), |v: Result<_, PlanSerDeError>| v.map(Some))?; Ok(Arc::new(SortMergeJoinExec::try_new( + schema, left, right, on, - join_type.into(), - join_filter, + join_type + .try_into() + .map_err(|_| proto_error("invalid JoinType"))?, sort_options, )?)) } @@ -306,7 +285,7 @@ self )) })?; - if let protobuf::physical_expr_node::ExprType::Sort(sort_expr) = expr { + if let ExprType::Sort(sort_expr) = expr { let expr = sort_expr .expr .as_ref() @@ -342,58 +321,58 @@ sort.fetch_limit.as_ref().map(|limit| limit.limit as usize), ))) } + PhysicalPlanType::BroadcastJoinBuildHashMap(bhm) => { + let input: Arc<dyn ExecutionPlan> = convert_box_required!(bhm.input)?; + let keys = bhm + .keys + .iter() + .map(|expr| { + Ok(bind( + try_parse_physical_expr(expr, &input.schema())?, + &input.schema(), + )?) 
+ }) + .collect::<Result<Vec<Arc<dyn PhysicalExpr>>, Self::Error>>()?; + Ok(Arc::new(BroadcastJoinBuildHashMapExec::new(input, keys))) + } PhysicalPlanType::BroadcastJoin(broadcast_join) => { + let schema = Arc::new(convert_required!(broadcast_join.schema)?); let left: Arc<dyn ExecutionPlan> = convert_box_required!(broadcast_join.left)?; let right: Arc<dyn ExecutionPlan> = convert_box_required!(broadcast_join.right)?; let on: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = broadcast_join .on .iter() .map(|col| { - let left_col: Column = into_required!(col.left)?; - let left_col_binded: Arc<dyn PhysicalExpr> = - Arc::new(Column::new_with_schema(left_col.name(), &left.schema())?); - let right_col: Column = into_required!(col.right)?; - let right_col_binded: Arc<dyn PhysicalExpr> = - Arc::new(Column::new_with_schema(right_col.name(), &right.schema())?); - Ok((left_col_binded, right_col_binded)) + let left_key = + try_parse_physical_expr(&col.left.as_ref().unwrap(), &left.schema())?; + let left_key_binded = bind(left_key, &left.schema())?; + let right_key = + try_parse_physical_expr(&col.right.as_ref().unwrap(), &right.schema())?; + let right_key_binded = bind(right_key, &right.schema())?; + Ok((left_key_binded, right_key_binded)) }) .collect::<Result<_, Self::Error>>()?; let join_type = protobuf::JoinType::try_from(broadcast_join.join_type) .expect("invalid JoinType"); - let join_filter = broadcast_join - .join_filter - .as_ref() - .map(|f| { - let schema = Arc::new(convert_required!(f.schema)?); - let expression = try_parse_physical_expr_required(&f.expression, &schema)?; - let column_indices = f - .column_indices - .iter() - .map(|i| { - let side = - protobuf::JoinSide::try_from(i.side).expect("invalid JoinSide"); - Ok(ColumnIndex { - index: i.index as usize, - side: side.into(), - }) - }) - .collect::<Result<Vec<_>, PlanSerDeError>>()?; - Ok(JoinFilter::new( - bind(expression, &schema)?, - column_indices, - schema.as_ref().clone(), - )) - }) - 
.map_or(Ok(None), |v: Result<_, PlanSerDeError>| v.map(Some))?; + let broadcast_side = protobuf::JoinSide::try_from(broadcast_join.broadcast_side) + .expect("invalid BroadcastSide"); + + let cached_build_hash_map_id = broadcast_join.cached_build_hash_map_id.clone(); Ok(Arc::new(BroadcastJoinExec::try_new( + schema, left, right, on, - join_type.into(), - join_filter, + join_type + .try_into() + .map_err(|_| proto_error("invalid JoinType"))?, + broadcast_side + .try_into() + .map_err(|_| proto_error("invalid BroadcastSide"))?, + Some(cached_build_hash_map_id), )?)) } PhysicalPlanType::BroadcastNestedLoopJoin(bnlj) => { @@ -428,10 +407,15 @@ }) .map_or(Ok(None), |v: Result<_, PlanSerDeError>| v.map(Some))?; + let blaze_join_type: JoinType = join_type + .try_into() + .map_err(|_| proto_error("invalid JoinType"))?; Ok(Arc::new(BroadcastNestedLoopJoinExec::try_new( left, right, - join_type.into(), + blaze_join_type + .try_into() + .map_err(|_| proto_error("invalid JoinType"))?, join_filter, )?)) }
diff --git a/native-engine/blaze-serde/src/lib.rs b/native-engine/blaze-serde/src/lib.rs index 30bd4c2..56cd4a6 100644 --- a/native-engine/blaze-serde/src/lib.rs +++ b/native-engine/blaze-serde/src/lib.rs
@@ -15,10 +15,8 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, Schema, TimeUnit}; -use datafusion::{ - common::JoinSide, logical_expr::Operator, prelude::JoinType, scalar::ScalarValue, -}; -use datafusion_ext_plans::agg::AggFunction; +use datafusion::{common::JoinSide, logical_expr::Operator, scalar::ScalarValue}; +use datafusion_ext_plans::{agg::AggFunction, joins::join_utils::JoinType}; use crate::error::PlanSerDeError; @@ -111,6 +109,7 @@ protobuf::JoinType::Full => JoinType::Full, protobuf::JoinType::Semi => JoinType::LeftSemi, protobuf::JoinType::Anti => JoinType::LeftAnti, + protobuf::JoinType::Existence => JoinType::Existence, } } }
diff --git a/native-engine/datafusion-ext-plans/Cargo.toml b/native-engine/datafusion-ext-plans/Cargo.toml index a233274..82fad59 100644 --- a/native-engine/datafusion-ext-plans/Cargo.toml +++ b/native-engine/datafusion-ext-plans/Cargo.toml
@@ -11,6 +11,7 @@ arrow = { workspace = true } async-trait = "0.1.80" base64 = "0.22.1" +bitvec = "1.0.1" byteorder = "1.5.0" bytes = "1.6.0" blaze-jni-bridge = { workspace = true }
diff --git a/native-engine/datafusion-ext-plans/src/broadcast_join_build_hash_map_exec.rs b/native-engine/datafusion-ext-plans/src/broadcast_join_build_hash_map_exec.rs new file mode 100644 index 0000000..3f1ca6d --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/broadcast_join_build_hash_map_exec.rs
@@ -0,0 +1,150 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + any::Any, + fmt::{Debug, Formatter}, + sync::Arc, +}; + +use arrow::{compute::concat_batches, datatypes::SchemaRef}; +use datafusion::{ + common::Result, + execution::{SendableRecordBatchStream, TaskContext}, + physical_expr::{Partitioning, PhysicalExpr, PhysicalSortExpr}, + physical_plan::{ + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + stream::RecordBatchStreamAdapter, + DisplayAs, DisplayFormatType, ExecutionPlan, + }, +}; +use futures::{stream::once, TryStreamExt}; + +use crate::{ + common::output::{NextBatchWithTimer, TaskOutputter}, + joins::join_hash_map::{join_hash_map_schema, JoinHashMap}, +}; + +pub struct BroadcastJoinBuildHashMapExec { + input: Arc<dyn ExecutionPlan>, + keys: Vec<Arc<dyn PhysicalExpr>>, + metrics: ExecutionPlanMetricsSet, +} + +impl BroadcastJoinBuildHashMapExec { + pub fn new(input: Arc<dyn ExecutionPlan>, keys: Vec<Arc<dyn PhysicalExpr>>) -> Self { + Self { + input, + keys, + metrics: ExecutionPlanMetricsSet::new(), + } + } +} + +impl Debug for BroadcastJoinBuildHashMapExec { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "BroadcastJoinBuildHashMap [{:?}]", self.keys) + } +} + +impl DisplayAs for BroadcastJoinBuildHashMapExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + write!(f, "BroadcastJoinBuildHashMapExec [{:?}]", 
self.keys) + } +} + +impl ExecutionPlan for BroadcastJoinBuildHashMapExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + join_hash_map_schema(&self.input.schema()) + } + + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(self.input.output_partitioning().partition_count()) + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> { + vec![self.input.clone()] + } + + fn with_new_children( + self: Arc<Self>, + children: Vec<Arc<dyn ExecutionPlan>>, + ) -> Result<Arc<dyn ExecutionPlan>> { + Ok(Arc::new(Self::new(children[0].clone(), self.keys.clone()))) + } + + fn execute( + &self, + partition: usize, + context: Arc<TaskContext>, + ) -> Result<SendableRecordBatchStream> { + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + let input = self.input.execute(partition, context.clone())?; + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + once(execute_build_hash_map( + context, + input, + self.keys.clone(), + baseline_metrics, + )) + .try_flatten(), + ))) + } + + fn metrics(&self) -> Option<MetricsSet> { + Some(self.metrics.clone_inner()) + } +} + +async fn execute_build_hash_map( + context: Arc<TaskContext>, + mut input: SendableRecordBatchStream, + keys: Vec<Arc<dyn PhysicalExpr>>, + metrics: BaselineMetrics, +) -> Result<SendableRecordBatchStream> { + let elapsed_compute = metrics.elapsed_compute().clone(); + let mut timer = elapsed_compute.timer(); + + let mut data_batches = vec![]; + let data_schema = input.schema(); + + // collect all input batches + while let Some(batch) = input.next_batch(Some(&mut timer)).await? 
{ + data_batches.push(batch); + } + let data_batch = concat_batches(&data_schema, data_batches.iter())?; + + // build hash map + let hash_map_schema = join_hash_map_schema(&data_schema); + let hash_map = JoinHashMap::try_from_data_batch(data_batch, &keys)?; + drop(timer); + + // output hash map batches as stream + context.output_with_sender("BuildHashMap", hash_map_schema, move |sender| async move { + let mut timer = elapsed_compute.timer(); + sender + .send(Ok(hash_map.into_hash_map_batch()?), Some(&mut timer)) + .await; + Ok(()) + }) +}
diff --git a/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs b/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs index 201173c..cc7b5d4 100644 --- a/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs +++ b/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs
@@ -15,90 +15,134 @@ use std::{ any::Any, fmt::{Debug, Formatter}, - sync::Arc, - task::Poll, - time::Duration, + future::Future, + pin::Pin, + sync::{Arc, Weak}, + time::{Duration, Instant}, }; -use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; -use blaze_jni_bridge::{ - conf, - conf::{BooleanConf, IntConf}, +use arrow::{ + array::RecordBatch, + compute::SortOptions, + datatypes::{DataType, SchemaRef}, }; +use async_trait::async_trait; use datafusion::{ - common::{Result, Statistics}, + common::{JoinSide, Result, Statistics}, execution::context::TaskContext, - logical_expr::JoinType, - physical_expr::PhysicalSortExpr, + physical_expr::{PhysicalExprRef, PhysicalSortExpr}, physical_plan::{ - expressions::Column, - joins::{ - utils::{build_join_schema, check_join_is_valid, JoinFilter, JoinOn}, - HashJoinExec, PartitionMode, - }, - memory::MemoryStream, - metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + joins::utils::JoinOn, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, Time}, stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, }, }; -use datafusion_ext_commons::{df_execution_err, downcast_any}; -use futures::{stream::once, StreamExt, TryStreamExt}; +use datafusion_ext_commons::{ + batch_size, df_execution_err, streams::coalesce_stream::CoalesceInput, +}; +use futures::{StreamExt, TryStreamExt}; +use hashbrown::HashMap; +use once_cell::sync::OnceCell; use parking_lot::Mutex; -use crate::{sort_exec::SortExec, sort_merge_join_exec::SortMergeJoinExec}; +use crate::{ + common::{ + batch_statisitcs::{stat_input, InputBatchStatistics}, + output::{TaskOutputter, WrappedRecordBatchSender}, + }, + joins::{ + bhj::{ + full_join::{ + LProbedFullOuterJoiner, LProbedInnerJoiner, LProbedLeftJoiner, LProbedRightJoiner, + RProbedFullOuterJoiner, RProbedInnerJoiner, RProbedLeftJoiner, RProbedRightJoiner, + }, + semi_join::{ + LProbedExistenceJoiner, 
LProbedLeftAntiJoiner, LProbedLeftSemiJoiner, + LProbedRightAntiJoiner, LProbedRightSemiJoiner, RProbedExistenceJoiner, + RProbedLeftAntiJoiner, RProbedLeftSemiJoiner, RProbedRightAntiJoiner, + RProbedRightSemiJoiner, + }, + }, + join_hash_map::JoinHashMap, + join_utils::{JoinType, JoinType::*}, + JoinParams, + }, +}; #[derive(Debug)] pub struct BroadcastJoinExec { - /// Left sorted joining execution plan left: Arc<dyn ExecutionPlan>, - /// Right sorting joining execution plan right: Arc<dyn ExecutionPlan>, - /// Set of common columns used to join on on: JoinOn, - /// How the join is performed join_type: JoinType, - /// Optional filter before outputting - join_filter: Option<JoinFilter>, - /// The schema once the join is applied + broadcast_side: JoinSide, schema: SchemaRef, - /// Execution metrics + cached_build_hash_map_id: Option<String>, metrics: ExecutionPlanMetricsSet, } impl BroadcastJoinExec { pub fn try_new( + schema: SchemaRef, left: Arc<dyn ExecutionPlan>, right: Arc<dyn ExecutionPlan>, on: JoinOn, join_type: JoinType, - join_filter: Option<JoinFilter>, + broadcast_side: JoinSide, + cached_build_hash_map_id: Option<String>, ) -> Result<Self> { - if matches!( - join_type, - JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightSemi | JoinType::RightAnti, - ) { - if join_filter.is_some() { - df_execution_err!("Semi/Anti join with filter is not supported yet")?; - } - } - - let left_schema = left.schema(); - let right_schema = right.schema(); - - check_join_is_valid(&left_schema, &right_schema, &on)?; - let schema = Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0); - Ok(Self { left, right, on, join_type, - join_filter, + broadcast_side, schema, + cached_build_hash_map_id, metrics: ExecutionPlanMetricsSet::new(), }) } + + fn create_join_params(&self) -> Result<JoinParams> { + let left_schema = self.left.schema(); + let right_schema = self.right.schema(); + let (left_keys, right_keys): (Vec<PhysicalExprRef>, Vec<PhysicalExprRef>) = + 
self.on.iter().cloned().unzip(); + let key_data_types: Vec<DataType> = self + .on + .iter() + .map(|(left_key, right_key)| { + Ok({ + let left_dt = left_key.data_type(&left_schema)?; + let right_dt = right_key.data_type(&right_schema)?; + if left_dt != right_dt { + df_execution_err!( + "join key data type differs {left_dt:?} <-> {right_dt:?}" + )?; + } + left_dt + }) + }) + .collect::<Result<_>>()?; + + // use smaller batch size and coalesce batches at the end, to avoid buffer + // overflowing + let batch_size = batch_size(); + let sub_batch_size = batch_size / batch_size.ilog10() as usize; + + Ok(JoinParams { + join_type: self.join_type, + left_schema, + right_schema, + output_schema: self.schema(), + left_keys, + right_keys, + batch_size: sub_batch_size, + sort_options: vec![SortOptions::default(); self.on.len()], + key_data_types, + }) + } } impl ExecutionPlan for BroadcastJoinExec { @@ -111,7 +155,10 @@ } fn output_partitioning(&self) -> Partitioning { - self.right.output_partitioning() + match self.broadcast_side { + JoinSide::Left => self.right.output_partitioning(), + JoinSide::Right => self.left.output_partitioning(), + } } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { @@ -127,11 +174,13 @@ children: Vec<Arc<dyn ExecutionPlan>>, ) -> Result<Arc<dyn ExecutionPlan>> { Ok(Arc::new(Self::try_new( + self.schema.clone(), children[0].clone(), children[1].clone(), self.on.iter().cloned().collect(), self.join_type, - self.join_filter.clone(), + self.broadcast_side, + None, )?)) } @@ -140,21 +189,42 @@ partition: usize, context: Arc<TaskContext>, ) -> Result<SendableRecordBatchStream> { - let stream = execute_broadcast_join( - self.left.clone(), - self.right.clone(), - partition, - context, - self.on.clone(), - self.join_type, - self.join_filter.clone(), - BaselineMetrics::new(&self.metrics, partition), - ); + let metrics = Arc::new(BaselineMetrics::new(&self.metrics, partition)); + let join_params = self.create_join_params()?; + let left = 
self.left.execute(partition, context.clone())?; + let right = self.right.execute(partition, context.clone())?; + let broadcast_side = self.broadcast_side; + let cached_build_hash_map_id = self.cached_build_hash_map_id.clone(); + let output_schema = self.schema(); - Ok(Box::pin(RecordBatchStreamAdapter::new( - self.schema(), - once(stream).try_flatten(), - ))) + // stat probed side + let input_batch_stat = + InputBatchStatistics::from_metrics_set_and_blaze_conf(&self.metrics, partition)?; + let (left, right) = match broadcast_side { + JoinSide::Left => (left, stat_input(input_batch_stat, right)?), + JoinSide::Right => (stat_input(input_batch_stat, left)?, right), + }; + + let metrics_cloned = metrics.clone(); + let context_cloned = context.clone(); + let output_stream = Box::pin(RecordBatchStreamAdapter::new( + output_schema.clone(), + futures::stream::once(async move { + context_cloned.output_with_sender("BroadcastJoin", output_schema, move |sender| { + execute_join( + left, + right, + join_params, + broadcast_side, + cached_build_hash_map_id, + metrics_cloned, + sender, + ) + }) + }) + .try_flatten(), + )); + Ok(context.coalesce_with_default_batch_size(output_stream, &metrics)?) } fn metrics(&self) -> Option<MetricsSet> { @@ -172,221 +242,188 @@ } } -async fn execute_broadcast_join( - left: Arc<dyn ExecutionPlan>, - right: Arc<dyn ExecutionPlan>, - partition: usize, - context: Arc<TaskContext>, - on: JoinOn, - join_type: JoinType, - join_filter: Option<JoinFilter>, - metrics: BaselineMetrics, -) -> Result<SendableRecordBatchStream> { - let enabled_fallback_to_smj = conf::BHJ_FALLBACKS_TO_SMJ_ENABLE.value()?; - let bhj_num_rows_limit = conf::BHJ_FALLBACKS_TO_SMJ_ROWS_THRESHOLD.value()? as usize; - let bhj_mem_size_limit = conf::BHJ_FALLBACKS_TO_SMJ_MEM_THRESHOLD.value()? 
as usize; +async fn execute_join( + left: SendableRecordBatchStream, + right: SendableRecordBatchStream, + join_params: JoinParams, + broadcast_side: JoinSide, + cached_build_hash_map_id: Option<String>, + metrics: Arc<BaselineMetrics>, + sender: Arc<WrappedRecordBatchSender>, +) -> Result<()> { + let start_time = Instant::now(); + let mut excluded_time_ns = 0; + let poll_time = Time::new(); - // if broadcasted size is small enough, use hash join - // otherwise use sort-merge join - #[derive(Debug)] - enum JoinMode { - Hash, - SortMerge, - } - let mut join_mode = JoinMode::Hash; - - let left_schema = left.schema(); - let mut left = left; - - if enabled_fallback_to_smj { - let mut left_stream = left.execute(0, context.clone())?.fuse(); - let mut left_cached: Vec<RecordBatch> = vec![]; - let mut left_num_rows = 0; - let mut left_mem_size = 0; - - // read and cache batches from broadcasted side until reached limits - while let Some(batch) = left_stream.next().await.transpose()? { - left_num_rows += batch.num_rows(); - left_mem_size += batch.get_array_memory_size(); - left_cached.push(batch); - if left_num_rows > bhj_num_rows_limit || left_mem_size > bhj_mem_size_limit { - join_mode = JoinMode::SortMerge; - break; - } + let (mut probed, _keys, mut joiner): (_, _, Pin<Box<dyn Joiner + Send>>) = match broadcast_side + { + JoinSide::Left => { + let right_schema = right.schema(); + let mut right_peeked = Box::pin(right.peekable()); + let (_, lmap_result) = futures::join!( + // fetch two sides asynchronously + async { + let timer = poll_time.timer(); + right_peeked.as_mut().peek().await; + drop(timer); + }, + collect_join_hash_map( + cached_build_hash_map_id, + left, + &join_params.left_keys, + poll_time.clone(), + ), + ); + let lmap = lmap_result?; + ( + Box::pin(RecordBatchStreamAdapter::new(right_schema, right_peeked)), + join_params.right_keys.clone(), + match join_params.join_type { + Inner => Box::pin(RProbedInnerJoiner::new(join_params, lmap, sender)), + Left => 
Box::pin(RProbedLeftJoiner::new(join_params, lmap, sender)), + Right => Box::pin(RProbedRightJoiner::new(join_params, lmap, sender)), + Full => Box::pin(RProbedFullOuterJoiner::new(join_params, lmap, sender)), + LeftSemi => Box::pin(RProbedLeftSemiJoiner::new(join_params, lmap, sender)), + LeftAnti => Box::pin(RProbedLeftAntiJoiner::new(join_params, lmap, sender)), + RightSemi => Box::pin(RProbedRightSemiJoiner::new(join_params, lmap, sender)), + RightAnti => Box::pin(RProbedRightAntiJoiner::new(join_params, lmap, sender)), + Existence => Box::pin(RProbedExistenceJoiner::new(join_params, lmap, sender)), + }, + ) } - - // convert left cached and rest batches into execution plan - let left_cached_stream: SendableRecordBatchStream = Box::pin(MemoryStream::try_new( - left_cached, - left_schema.clone(), - None, - )?); - let left_rest_stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new( - left_schema.clone(), - left_stream, - )); - let left_stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new( - left_schema.clone(), - left_cached_stream.chain(left_rest_stream), - )); - left = Arc::new(RecordBatchStreamsWrapperExec { - schema: left_schema.clone(), - stream: Mutex::new(Some(left_stream)), - output_partitioning: right.output_partitioning(), - }); - } - - match join_mode { - JoinMode::Hash => { - let join = Arc::new(HashJoinExec::try_new( - left.clone(), - right.clone(), - on, - join_filter, - &join_type, - PartitionMode::CollectLeft, - false, - )?); - log::info!("BroadcastJoin is using hash join mode: {:?}", &join); - - let join_schema = join.schema(); - let completed = join - .execute(partition, context)? 
- .chain(futures::stream::poll_fn(move |_| { - // update metrics - let join_metrics = join.metrics().unwrap(); - metrics.record_output(join_metrics.output_rows().unwrap_or(0)); - metrics.elapsed_compute().add_duration(Duration::from_nanos( - [ - join_metrics - .sum_by_name("build_time") - .map(|v| v.as_usize() as u64), - join_metrics - .sum_by_name("join_time") - .map(|v| v.as_usize() as u64), - ] - .into_iter() - .flatten() - .sum(), - )); - Poll::Ready(None) - })); - Ok(Box::pin(RecordBatchStreamAdapter::new( - join_schema, - completed, - ))) + JoinSide::Right => { + let left_schema = left.schema(); + let mut left_peeked = Box::pin(left.peekable()); + let (_, rmap_result) = futures::join!( + // fetch two sides asynchronizely + async { + let timer = poll_time.timer(); + left_peeked.as_mut().peek().await; + drop(timer); + }, + collect_join_hash_map( + cached_build_hash_map_id, + right, + &join_params.right_keys, + poll_time.clone(), + ), + ); + let rmap = rmap_result?; + ( + Box::pin(RecordBatchStreamAdapter::new(left_schema, left_peeked)), + join_params.left_keys.clone(), + match join_params.join_type { + Inner => Box::pin(LProbedInnerJoiner::new(join_params, rmap, sender)), + Left => Box::pin(LProbedLeftJoiner::new(join_params, rmap, sender)), + Right => Box::pin(LProbedRightJoiner::new(join_params, rmap, sender)), + Full => Box::pin(LProbedFullOuterJoiner::new(join_params, rmap, sender)), + LeftSemi => Box::pin(LProbedLeftSemiJoiner::new(join_params, rmap, sender)), + LeftAnti => Box::pin(LProbedLeftAntiJoiner::new(join_params, rmap, sender)), + RightSemi => Box::pin(LProbedRightSemiJoiner::new(join_params, rmap, sender)), + RightAnti => Box::pin(LProbedRightAntiJoiner::new(join_params, rmap, sender)), + Existence => Box::pin(LProbedExistenceJoiner::new(join_params, rmap, sender)), + }, + ) } - JoinMode::SortMerge => { - let sort_exprs: Vec<PhysicalSortExpr> = on - .iter() - .map(|(_col_left, col_right)| PhysicalSortExpr { - expr: Arc::new(Column::new( - "", - 
downcast_any!(col_right, Column) - .expect("requires column") - .index(), - )), - options: Default::default(), - }) - .collect(); + }; - let right_sorted = Arc::new(SortExec::new(right, sort_exprs.clone(), None)); - let join = Arc::new(SortMergeJoinExec::try_new( - left.clone(), - right_sorted.clone(), - on, - join_type, - join_filter, - sort_exprs.into_iter().map(|se| se.options).collect(), - )?); - log::info!("BroadcastJoin is using sort-merge join mode: {:?}", &join); + while let Some(batch) = { + let timer = poll_time.timer(); + let batch = probed.next().await.transpose()?; + drop(timer); + batch + } { + joiner.as_mut().join(batch).await?; + } + joiner.as_mut().finish().await?; + metrics.record_output(joiner.num_output_rows()); - let join_schema = join.schema(); - let completed = join - .execute(partition, context)? - .chain(futures::stream::poll_fn(move |_| { - // update metrics - let right_sorted_metrics = right_sorted.metrics().unwrap(); - let join_metrics = join.metrics().unwrap(); - metrics.record_output(join_metrics.output_rows().unwrap_or(0)); - metrics.elapsed_compute().add_duration(Duration::from_nanos( - [ - right_sorted_metrics.elapsed_compute(), - join_metrics.elapsed_compute(), - ] - .into_iter() - .flatten() - .sum::<usize>() as u64, - )); - Poll::Ready(None) - })); - Ok(Box::pin(RecordBatchStreamAdapter::new( - join_schema, - completed, - ))) + excluded_time_ns += poll_time.value(); + excluded_time_ns += joiner.total_send_output_time(); + + // discount poll input and send output batch time + let mut join_time_ns = (Instant::now() - start_time).as_nanos() as u64; + join_time_ns -= excluded_time_ns as u64; + metrics + .elapsed_compute() + .add_duration(Duration::from_nanos(join_time_ns)); + Ok(()) +} + +async fn collect_join_hash_map( + cached_build_hash_map_id: Option<String>, + input: SendableRecordBatchStream, + key_exprs: &[PhysicalExprRef], + poll_time: Time, +) -> Result<Arc<JoinHashMap>> { + Ok(match cached_build_hash_map_id { + 
Some(cached_id) => { + get_cached_join_hash_map(&cached_id, || async { + collect_join_hash_map_without_caching(input, key_exprs, poll_time).await + }) + .await? } + None => { + let map = collect_join_hash_map_without_caching(input, key_exprs, poll_time).await?; + Arc::new(map) + } + }) +} + +async fn collect_join_hash_map_without_caching( + mut input: SendableRecordBatchStream, + key_exprs: &[PhysicalExprRef], + poll_time: Time, +) -> Result<JoinHashMap> { + let mut hash_map_batches = vec![]; + while let Some(batch) = { + let timer = poll_time.timer(); + let batch = input.next().await.transpose()?; + drop(timer); + batch + } { + hash_map_batches.push(batch); + } + match hash_map_batches.len() { + 0 => Ok(JoinHashMap::try_new_empty(input.schema(), key_exprs)?), + 1 => Ok(JoinHashMap::try_from_hash_map_batch( + hash_map_batches[0].clone(), + key_exprs, + )?), + n => df_execution_err!("expect zero or one hash map batch, got {n}"), } } -pub struct RecordBatchStreamsWrapperExec { - pub schema: SchemaRef, - pub stream: Mutex<Option<SendableRecordBatchStream>>, - pub output_partitioning: Partitioning, +#[async_trait] +pub trait Joiner { + async fn join(self: Pin<&mut Self>, probed_batch: RecordBatch) -> Result<()>; + async fn finish(self: Pin<&mut Self>) -> Result<()>; + + fn total_send_output_time(&self) -> usize; + fn num_output_rows(&self) -> usize; } -impl Debug for RecordBatchStreamsWrapperExec { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "RecordBatchStreamsWrapper") - } -} +async fn get_cached_join_hash_map<Fut: Future<Output = Result<JoinHashMap>> + Send>( + cached_id: &str, + init: impl FnOnce() -> Fut, +) -> Result<Arc<JoinHashMap>> { + type Slot = Arc<tokio::sync::Mutex<Weak<JoinHashMap>>>; + static CACHED_JOIN_HASH_MAP: OnceCell<Arc<Mutex<HashMap<String, Slot>>>> = OnceCell::new(); -impl DisplayAs for RecordBatchStreamsWrapperExec { - fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - write!(f, 
"RecordBatchStreamsWrapper") - } -} + // TODO: remove expired keys from cached join hash map + let cached_join_hash_map = CACHED_JOIN_HASH_MAP.get_or_init(|| Arc::default()); + let slot = cached_join_hash_map + .lock() + .entry(cached_id.to_string()) + .or_default() + .clone(); -impl ExecutionPlan for RecordBatchStreamsWrapperExec { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - self.schema.clone() - } - - fn output_partitioning(&self) -> Partitioning { - self.output_partitioning.clone() - } - - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - None - } - - fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> { - vec![] - } - - fn with_new_children( - self: Arc<Self>, - _: Vec<Arc<dyn ExecutionPlan>>, - ) -> Result<Arc<dyn ExecutionPlan>> { - unimplemented!() - } - - fn execute( - &self, - _partition: usize, - _context: Arc<TaskContext>, - ) -> Result<SendableRecordBatchStream> { - let stream = std::mem::take(&mut *self.stream.lock()); - Ok(Box::pin(RecordBatchStreamAdapter::new( - self.schema.clone(), - Box::pin(futures::stream::iter(stream).flatten()), - ))) - } - - fn statistics(&self) -> Result<Statistics> { - unimplemented!() + let mut slot = slot.lock().await; + if let Some(cached) = slot.upgrade() { + Ok(cached) + } else { + let new = Arc::new(init().await?); + *slot = Arc::downgrade(&new); + Ok(new) } }
diff --git a/native-engine/datafusion-ext-plans/src/broadcast_nested_loop_join_exec.rs b/native-engine/datafusion-ext-plans/src/broadcast_nested_loop_join_exec.rs index b52e77f..7ddd741 100644 --- a/native-engine/datafusion-ext-plans/src/broadcast_nested_loop_join_exec.rs +++ b/native-engine/datafusion-ext-plans/src/broadcast_nested_loop_join_exec.rs
@@ -34,8 +34,6 @@ use futures::{stream::once, StreamExt, TryStreamExt}; use parking_lot::Mutex; -use crate::broadcast_join_exec::RecordBatchStreamsWrapperExec; - #[derive(Debug)] pub struct BroadcastNestedLoopJoinExec { left: Arc<dyn ExecutionPlan>, @@ -250,3 +248,66 @@ JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | JoinType::Full ) } + +struct RecordBatchStreamsWrapperExec { + pub schema: SchemaRef, + pub stream: Mutex<Option<SendableRecordBatchStream>>, + pub output_partitioning: Partitioning, +} + +impl std::fmt::Debug for RecordBatchStreamsWrapperExec { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "RecordBatchStreamsWrapper") + } +} + +impl DisplayAs for RecordBatchStreamsWrapperExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + write!(f, "RecordBatchStreamsWrapper") + } +} + +impl ExecutionPlan for RecordBatchStreamsWrapperExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn output_partitioning(&self) -> Partitioning { + self.output_partitioning.clone() + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> { + vec![] + } + + fn with_new_children( + self: Arc<Self>, + _: Vec<Arc<dyn ExecutionPlan>>, + ) -> Result<Arc<dyn ExecutionPlan>> { + unimplemented!() + } + + fn execute( + &self, + _partition: usize, + _context: Arc<TaskContext>, + ) -> Result<SendableRecordBatchStream> { + let stream = std::mem::take(&mut *self.stream.lock()); + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema.clone(), + Box::pin(futures::stream::iter(stream).flatten()), + ))) + } + + fn statistics(&self) -> Result<Statistics> { + unimplemented!() + } +}
diff --git a/native-engine/datafusion-ext-plans/src/common/batch_selection.rs b/native-engine/datafusion-ext-plans/src/common/batch_selection.rs index 6aa8395..a5e789c 100644 --- a/native-engine/datafusion-ext-plans/src/common/batch_selection.rs +++ b/native-engine/datafusion-ext-plans/src/common/batch_selection.rs
@@ -41,16 +41,33 @@ take_batch_internal(batch, indices) } +pub fn take_cols<T: num::PrimInt>( + cols: &[ArrayRef], + indices: impl IntoIterator<Item = T>, +) -> Result<Vec<ArrayRef>> { + let indices: UInt32Array = + PrimitiveArray::from_iter(indices.into_iter().map(|idx| idx.to_u32().unwrap())); + take_cols_internal(cols, &indices) +} + +pub fn take_cols_opt<T: num::PrimInt>( + cols: &[ArrayRef], + indices: impl IntoIterator<Item = Option<T>>, +) -> Result<Vec<ArrayRef>> { + let indices: UInt32Array = PrimitiveArray::from_iter( + indices + .into_iter() + .map(|opt| opt.map(|idx| idx.to_u32().unwrap())), + ); + take_cols_internal(cols, &indices) +} + fn take_batch_internal(batch: RecordBatch, indices: UInt32Array) -> Result<RecordBatch> { let taken_num_batch_rows = indices.len(); let schema = batch.schema(); - let cols = batch.columns().to_vec(); - drop(batch); // we would like to release batch as soon as possible + let cols = batch.columns(); - let cols = cols - .into_iter() - .map(|c| Ok(arrow::compute::take(&c, &indices, None)?)) - .collect::<Result<_>>()?; + let cols = take_cols_internal(cols, &indices)?; drop(indices); let taken = RecordBatch::try_new_with_options( @@ -61,6 +78,14 @@ Ok(taken) } +fn take_cols_internal(cols: &[ArrayRef], indices: &UInt32Array) -> Result<Vec<ArrayRef>> { + let cols = cols + .into_iter() + .map(|c| Ok(arrow::compute::take(&c, indices, None)?)) + .collect::<Result<_>>()?; + Ok(cols) +} + pub fn interleave_batches( schema: SchemaRef, batches: &[RecordBatch],
diff --git a/native-engine/datafusion-ext-plans/src/common/output.rs b/native-engine/datafusion-ext-plans/src/common/output.rs index b1a0a28..d888026 100644 --- a/native-engine/datafusion-ext-plans/src/common/output.rs +++ b/native-engine/datafusion-ext-plans/src/common/output.rs
@@ -20,6 +20,7 @@ }; use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use async_trait::async_trait; use blaze_jni_bridge::is_task_running; use datafusion::{ common::Result, @@ -221,3 +222,34 @@ WrappedRecordBatchSender::cancel_task(self); } } + +#[async_trait] +pub trait NextBatchWithTimer { + async fn next_batch( + &mut self, + stop_timer: Option<&mut ScopedTimerGuard<'_>>, + ) -> Result<Option<RecordBatch>>; +} + +#[async_trait] +impl NextBatchWithTimer for SendableRecordBatchStream { + async fn next_batch( + &mut self, + stop_timer: Option<&mut ScopedTimerGuard<'_>>, + ) -> Result<Option<RecordBatch>> { + struct StopScopedTimerGuard<'a, 'z>(&'a mut ScopedTimerGuard<'z>); + impl<'a, 'z> StopScopedTimerGuard<'a, 'z> { + fn new(timer: &'a mut ScopedTimerGuard<'z>) -> Self { + timer.stop(); + Self(timer) + } + } + impl Drop for StopScopedTimerGuard<'_, '_> { + fn drop(&mut self) { + self.0.restart(); + } + } + let _stop_timer = stop_timer.map(|timer| StopScopedTimerGuard::new(timer)); + self.next().await.transpose() + } +}
diff --git a/native-engine/datafusion-ext-plans/src/joins/bhj/full_join.rs b/native-engine/datafusion-ext-plans/src/joins/bhj/full_join.rs new file mode 100644 index 0000000..1af9de8 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/bhj/full_join.rs
@@ -0,0 +1,290 @@
+// Copyright 2022 The Blaze Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::{
+    pin::Pin,
+    sync::{
+        atomic::{AtomicUsize, Ordering::Relaxed},
+        Arc,
+    },
+};
+
+use arrow::array::{new_null_array, ArrayRef, RecordBatch};
+use async_trait::async_trait;
+use bitvec::{bitvec, prelude::BitVec};
+use datafusion::{common::Result, physical_plan::metrics::Time};
+
+use crate::{
+    broadcast_join_exec::Joiner,
+    common::{batch_selection::take_cols, output::WrappedRecordBatchSender},
+    joins::{
+        bhj::{
+            filter_joined_indices,
+            full_join::ProbeSide::{L, R},
+            ProbeSide,
+        },
+        join_hash_map::{join_create_hashes, JoinHashMap},
+        JoinParams,
+    },
+};
+
+/// Compile-time configuration of a [`FullJoiner`].
+#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)]
+pub struct JoinerParams {
+    /// input that is streamed against the hash map
+    probe_side: ProbeSide,
+    /// emit probed rows that found no match, padded with nulls
+    probe_side_outer: bool,
+    /// emit hash map rows that were never matched, padded with nulls
+    build_side_outer: bool,
+}
+
+impl JoinerParams {
+    const fn new(probe_side: ProbeSide, probe_side_outer: bool, build_side_outer: bool) -> Self {
+        Self {
+            probe_side,
+            probe_side_outer,
+            build_side_outer,
+        }
+    }
+}
+
+const LEFT_PROBED_INNER: JoinerParams = JoinerParams::new(L, false, false);
+const LEFT_PROBED_LEFT: JoinerParams = JoinerParams::new(L, true, false);
+const LEFT_PROBED_RIGHT: JoinerParams = JoinerParams::new(L, false, true);
+const LEFT_PROBED_OUTER: JoinerParams = JoinerParams::new(L, true, true);
+
+const RIGHT_PROBED_INNER: JoinerParams = JoinerParams::new(R, false, false);
+const RIGHT_PROBED_LEFT: JoinerParams = JoinerParams::new(R, false, true);
+const RIGHT_PROBED_RIGHT: JoinerParams = JoinerParams::new(R, true, false);
+const RIGHT_PROBED_OUTER: JoinerParams = JoinerParams::new(R, true, true);
+
+pub type LProbedInnerJoiner = FullJoiner<LEFT_PROBED_INNER>;
+pub type LProbedLeftJoiner = FullJoiner<LEFT_PROBED_LEFT>;
+pub type LProbedRightJoiner = FullJoiner<LEFT_PROBED_RIGHT>;
+pub type LProbedFullOuterJoiner = FullJoiner<LEFT_PROBED_OUTER>;
+pub type RProbedInnerJoiner = FullJoiner<RIGHT_PROBED_INNER>;
+pub type RProbedLeftJoiner = FullJoiner<RIGHT_PROBED_LEFT>;
+pub type RProbedRightJoiner = FullJoiner<RIGHT_PROBED_RIGHT>;
+pub type RProbedFullOuterJoiner = FullJoiner<RIGHT_PROBED_OUTER>;
+
+/// Inner/left/right/full outer broadcast hash joiner, specialized at compile
+/// time by `P`.
+pub struct FullJoiner<const P: JoinerParams> {
+    join_params: JoinParams,
+    output_sender: Arc<WrappedRecordBatchSender>,
+    map: Arc<JoinHashMap>,
+    /// one bit per hash map row, set once the row has been matched
+    map_joined: BitVec,
+    send_output_time: Time,
+    output_rows: AtomicUsize,
+}
+
+impl<const P: JoinerParams> FullJoiner<P> {
+    pub fn new(
+        join_params: JoinParams,
+        map: Arc<JoinHashMap>,
+        output_sender: Arc<WrappedRecordBatchSender>,
+    ) -> Self {
+        let map_joined = bitvec![0; map.data_batch().num_rows()];
+        Self {
+            join_params,
+            output_sender,
+            map,
+            map_joined,
+            send_output_time: Time::default(),
+            output_rows: AtomicUsize::new(0),
+        }
+    }
+
+    /// Evaluates the probe side key expressions against `probed_batch`.
+    fn create_probed_key_columns(&self, probed_batch: &RecordBatch) -> Result<Vec<ArrayRef>> {
+        let probed_key_exprs = match P.probe_side {
+            L => &self.join_params.left_keys,
+            R => &self.join_params.right_keys,
+        };
+        let probed_key_columns: Vec<ArrayRef> = probed_key_exprs
+            .iter()
+            .map(|expr| {
+                Ok(expr
+                    .evaluate(probed_batch)?
+                    .into_array(probed_batch.num_rows())?)
+            })
+            .collect::<Result<_>>()?;
+        Ok(probed_key_columns)
+    }
+
+    /// Assembles an output batch (left columns first) and sends it
+    /// downstream; empty batches are dropped instead of being sent.
+    async fn flush(&self, probe_cols: Vec<ArrayRef>, build_cols: Vec<ArrayRef>) -> Result<()> {
+        let output_batch = RecordBatch::try_new(
+            self.join_params.output_schema.clone(),
+            match P.probe_side {
+                L => [probe_cols, build_cols].concat(),
+                R => [build_cols, probe_cols].concat(),
+            },
+        )?;
+        if output_batch.num_rows() == 0 {
+            return Ok(());
+        }
+        self.output_rows.fetch_add(output_batch.num_rows(), Relaxed);
+
+        let timer = self.send_output_time.timer();
+        self.output_sender.send(Ok(output_batch), None).await;
+        drop(timer);
+        Ok(())
+    }
+
+    /// Verifies the hash-matched candidate pairs by comparing the real keys,
+    /// records the confirmed matches and outputs the joined rows.
+    async fn flush_hash_joined(
+        mut self: Pin<&mut Self>,
+        probed_batch: &RecordBatch,
+        probed_key_columns: &[ArrayRef],
+        probed_joined: &mut BitVec,
+        mut hash_joined_probe_indices: Vec<u32>,
+        mut hash_joined_build_indices: Vec<u32>,
+    ) -> Result<()> {
+        filter_joined_indices(
+            probed_key_columns,
+            self.map.key_columns(),
+            &mut hash_joined_probe_indices,
+            &mut hash_joined_build_indices,
+        )?;
+        let probe_indices = hash_joined_probe_indices;
+        let build_indices = hash_joined_build_indices;
+        if probe_indices.is_empty() {
+            return Ok(()); // no candidate pair survived the key comparison
+        }
+
+        // the match bitmaps are only read back by the outer join paths
+        if P.probe_side_outer {
+            for &idx in &probe_indices {
+                probed_joined.set(idx as usize, true);
+            }
+        }
+        if P.build_side_outer {
+            for &idx in &build_indices {
+                self.map_joined.set(idx as usize, true);
+            }
+        }
+
+        let pcols = take_cols(probed_batch.columns(), probe_indices)?;
+        let bcols = take_cols(self.map.data_batch().columns(), build_indices)?;
+        self.flush(pcols, bcols).await?;
+        Ok(())
+    }
+}
+
+#[async_trait]
+impl<const P: JoinerParams> Joiner for FullJoiner<P> {
+    async fn join(mut self: Pin<&mut Self>, probed_batch: RecordBatch) -> Result<()> {
+        let mut hash_joined_probe_indices: Vec<u32> = vec![];
+        let mut hash_joined_build_indices: Vec<u32> = vec![];
+        let mut probed_joined = bitvec![0; probed_batch.num_rows()];
+
+        let probed_key_columns = self.create_probed_key_columns(&probed_batch)?;
+        let probed_hashes = join_create_hashes(probed_batch.num_rows(), &probed_key_columns)?;
+
+        // join by hash code
+        for (row_idx, &hash) in probed_hashes.iter().enumerate() {
+            let mut maybe_joined = false;
+            if let Some(entries) = self.map.entry_indices(hash) {
+                for map_idx in entries {
+                    hash_joined_probe_indices.push(row_idx as u32);
+                    hash_joined_build_indices.push(map_idx);
+                }
+                maybe_joined = true;
+            }
+
+            if maybe_joined && hash_joined_probe_indices.len() >= self.join_params.batch_size {
+                self.as_mut()
+                    .flush_hash_joined(
+                        &probed_batch,
+                        &probed_key_columns,
+                        &mut probed_joined,
+                        std::mem::take(&mut hash_joined_probe_indices),
+                        std::mem::take(&mut hash_joined_build_indices),
+                    )
+                    .await?;
+            }
+        }
+        if !hash_joined_probe_indices.is_empty() {
+            self.as_mut()
+                .flush_hash_joined(
+                    &probed_batch,
+                    &probed_key_columns,
+                    &mut probed_joined,
+                    hash_joined_probe_indices,
+                    hash_joined_build_indices,
+                )
+                .await?;
+        }
+
+        // output unjoined rows of probed side
+        if P.probe_side_outer {
+            let probed_unjoined_indices = probed_joined
+                .iter()
+                .enumerate()
+                .filter(|(_, joined)| !**joined)
+                .map(|(idx, _)| idx as u32)
+                .collect::<Vec<_>>();
+
+            let bcols = self
+                .map
+                .data_batch()
+                .columns()
+                .iter()
+                .map(|col| new_null_array(col.data_type(), probed_unjoined_indices.len()))
+                .collect::<Vec<_>>();
+            let pcols = take_cols(probed_batch.columns(), probed_unjoined_indices)?;
+            self.as_mut().flush(pcols, bcols).await?;
+        }
+        Ok(())
+    }
+
+    async fn finish(mut self: Pin<&mut Self>) -> Result<()> {
+        // output unjoined rows of build side
+        let map_joined = std::mem::take(&mut self.map_joined);
+        if P.build_side_outer {
+            let map_unjoined_indices = map_joined
+                .into_iter()
+                .enumerate()
+                .filter(|(_, joined)| !joined)
+                .map(|(idx, _)| idx as u32)
+                .collect::<Vec<_>>();
+
+            let pschema = match P.probe_side {
+                L => &self.join_params.left_schema,
+                R => &self.join_params.right_schema,
+            };
+            let pcols = pschema
+                .fields()
+                .iter()
+                .map(|field| new_null_array(field.data_type(), map_unjoined_indices.len()))
+                .collect::<Vec<_>>();
+            let bcols = take_cols(self.map.data_batch().columns(), map_unjoined_indices)?;
+            self.as_mut().flush(pcols, bcols).await?;
+        }
+        Ok(())
+    }
+
+    fn total_send_output_time(&self) -> usize {
+        self.send_output_time.value()
+    }
+
+    fn num_output_rows(&self) -> usize {
+        self.output_rows.load(Relaxed)
+    }
+}
diff --git a/native-engine/datafusion-ext-plans/src/joins/bhj/mod.rs b/native-engine/datafusion-ext-plans/src/joins/bhj/mod.rs new file mode 100644 index 0000000..57d934c --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/bhj/mod.rs
@@ -0,0 +1,170 @@
+// Copyright 2022 The Blaze Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use arrow::{
+    array::*,
+    datatypes::{DataType, IntervalUnit, TimeUnit},
+};
+use datafusion::common::Result;
+use datafusion_ext_commons::{df_execution_err, downcast_any};
+
+pub mod full_join;
+pub mod semi_join;
+
+/// Identifies which join input is streamed (probed) against the broadcast
+/// hash map; used as a const generic parameter by the joiners.
+#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)]
+pub enum ProbeSide {
+    L,
+    R,
+}
+
+/// Drops every candidate pair whose join keys are not really equal.
+///
+/// `indices1[i]` and `indices2[i]` form one candidate pair of row indices
+/// into `key_columns1` and `key_columns2`. Both vectors are compacted in
+/// place, preserving order; a pair survives only if, for every key column,
+/// both values are non-null and equal.
+///
+/// Returns an error if paired key columns differ in data type or the data
+/// type is not supported as a join key.
+fn filter_joined_indices(
+    key_columns1: &[ArrayRef],
+    key_columns2: &[ArrayRef],
+    indices1: &mut Vec<u32>,
+    indices2: &mut Vec<u32>,
+) -> Result<()> {
+    /// Keeps, in order, only the index pairs accepted by `keep`.
+    fn retain_index_pairs(
+        indices1: &mut Vec<u32>,
+        indices2: &mut Vec<u32>,
+        mut keep: impl FnMut(usize, usize) -> bool,
+    ) {
+        let mut num_kept = 0;
+        for i in 0..indices1.len() {
+            if keep(indices1[i] as usize, indices2[i] as usize) {
+                indices1[num_kept] = indices1[i];
+                indices2[num_kept] = indices2[i];
+                num_kept += 1;
+            }
+        }
+        indices1.truncate(num_kept);
+        indices2.truncate(num_kept);
+    }
+
+    /// Filters the candidate pairs by one pair of key columns.
+    fn filter_one(
+        key_column1: &ArrayRef,
+        key_column2: &ArrayRef,
+        indices1: &mut Vec<u32>,
+        indices2: &mut Vec<u32>,
+    ) -> Result<()> {
+        // compares values of a concretely typed array pair; nulls never match
+        macro_rules! filter_atomic {
+            ($cast_type:ty) => {{
+                let col1 = downcast_any!(key_column1, $cast_type)?;
+                let col2 = downcast_any!(key_column2, $cast_type)?;
+                retain_index_pairs(indices1, indices2, |idx1, idx2| {
+                    col1.is_valid(idx1)
+                        && col2.is_valid(idx2)
+                        && col1.value(idx1) == col2.value(idx2)
+                });
+            }};
+        }
+
+        let dt1 = key_column1.data_type();
+        let dt2 = key_column2.data_type();
+        if dt1 != dt2 {
+            return df_execution_err!("join key data type not matched: {dt1:?} <-> {dt2:?}");
+        }
+        match dt1 {
+            DataType::Null => {
+                indices1.clear();
+                indices2.clear();
+            }
+            DataType::Boolean => filter_atomic!(BooleanArray),
+            DataType::Int8 => filter_atomic!(Int8Array),
+            DataType::Int16 => filter_atomic!(Int16Array),
+            DataType::Int32 => filter_atomic!(Int32Array),
+            DataType::Int64 => filter_atomic!(Int64Array),
+            DataType::UInt8 => filter_atomic!(UInt8Array),
+            DataType::UInt16 => filter_atomic!(UInt16Array),
+            DataType::UInt32 => filter_atomic!(UInt32Array),
+            DataType::UInt64 => filter_atomic!(UInt64Array),
+            DataType::Float16 => filter_atomic!(Float16Array),
+            DataType::Float32 => filter_atomic!(Float32Array),
+            DataType::Float64 => filter_atomic!(Float64Array),
+            DataType::Timestamp(unit, _) => match unit {
+                TimeUnit::Second => filter_atomic!(TimestampSecondArray),
+                TimeUnit::Millisecond => filter_atomic!(TimestampMillisecondArray),
+                TimeUnit::Microsecond => filter_atomic!(TimestampMicrosecondArray),
+                TimeUnit::Nanosecond => filter_atomic!(TimestampNanosecondArray),
+            },
+            DataType::Date32 => filter_atomic!(Date32Array),
+            DataType::Date64 => filter_atomic!(Date64Array),
+            DataType::Time32(unit) => match unit {
+                TimeUnit::Second => filter_atomic!(Time32SecondArray),
+                TimeUnit::Millisecond => filter_atomic!(Time32MillisecondArray),
+                // arrow defines time32 with second/millisecond resolution only
+                _ => return df_execution_err!("unsupported time32 unit: {unit:?}"),
+            },
+            DataType::Time64(unit) => match unit {
+                TimeUnit::Microsecond => filter_atomic!(Time64MicrosecondArray),
+                TimeUnit::Nanosecond => filter_atomic!(Time64NanosecondArray),
+                _ => return df_execution_err!("unsupported time64 unit: {unit:?}"),
+            },
+            DataType::Duration(unit) => match unit {
+                TimeUnit::Second => filter_atomic!(DurationSecondArray),
+                TimeUnit::Millisecond => filter_atomic!(DurationMillisecondArray),
+                TimeUnit::Microsecond => filter_atomic!(DurationMicrosecondArray),
+                TimeUnit::Nanosecond => filter_atomic!(DurationNanosecondArray),
+            },
+            DataType::Interval(unit) => match unit {
+                IntervalUnit::YearMonth => filter_atomic!(IntervalYearMonthArray),
+                IntervalUnit::DayTime => filter_atomic!(IntervalDayTimeArray),
+                IntervalUnit::MonthDayNano => filter_atomic!(IntervalMonthDayNanoArray),
+            },
+            DataType::Binary => filter_atomic!(BinaryArray),
+            DataType::FixedSizeBinary(_) => filter_atomic!(FixedSizeBinaryArray),
+            DataType::LargeBinary => filter_atomic!(LargeBinaryArray),
+            DataType::Utf8 => filter_atomic!(StringArray),
+            DataType::LargeUtf8 => filter_atomic!(LargeStringArray),
+            DataType::List(_) => filter_atomic!(ListArray),
+            DataType::FixedSizeList(..) => filter_atomic!(FixedSizeListArray),
+            DataType::LargeList(_) => filter_atomic!(LargeListArray),
+            DataType::Struct(_) => {
+                let struct1 = key_column1.as_struct();
+                let struct2 = key_column2.as_struct();
+                // child arrays may hold arbitrary values under a null struct,
+                // so struct-level validity has to be checked separately
+                retain_index_pairs(indices1, indices2, |idx1, idx2| {
+                    struct1.is_valid(idx1) && struct2.is_valid(idx2)
+                });
+                filter_joined_indices(struct1.columns(), struct2.columns(), indices1, indices2)?;
+            }
+            DataType::Decimal128(..) => filter_atomic!(Decimal128Array),
+            DataType::Decimal256(..) => filter_atomic!(Decimal256Array),
+            DataType::Map(..) => filter_atomic!(MapArray),
+            dt => {
+                return df_execution_err!("unsupported data type: {dt:?}");
+            }
+        }
+        Ok(())
+    }
+
+    for (key_column1, key_column2) in key_columns1.iter().zip(key_columns2) {
+        filter_one(key_column1, key_column2, indices1, indices2)?;
+    }
+    Ok(())
+}
diff --git a/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs b/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs new file mode 100644 index 0000000..7e99224 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs
@@ -0,0 +1,288 @@
+// Copyright 2022 The Blaze Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::{
+    pin::Pin,
+    sync::{
+        atomic::{AtomicUsize, Ordering::Relaxed},
+        Arc,
+    },
+};
+
+use arrow::array::{ArrayRef, BooleanArray, RecordBatch};
+use async_trait::async_trait;
+use bitvec::{bitvec, prelude::BitVec};
+use datafusion::{common::Result, physical_plan::metrics::Time};
+
+use crate::{
+    broadcast_join_exec::Joiner,
+    common::{batch_selection::take_cols, output::WrappedRecordBatchSender},
+    joins::{
+        bhj::{
+            filter_joined_indices,
+            semi_join::{
+                ProbeSide::{L, R},
+                SemiMode::{Anti, Existence, Semi},
+            },
+            ProbeSide,
+        },
+        join_hash_map::{join_create_hashes, JoinHashMap},
+        JoinParams,
+    },
+};
+
+/// What is emitted for the rows of the join side.
+#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)]
+pub enum SemiMode {
+    /// emit only the rows that have a match
+    Semi,
+    /// emit only the rows that have no match
+    Anti,
+    /// emit every row plus a boolean column telling whether it has a match
+    Existence,
+}
+
+/// Compile-time configuration of a [`SemiJoiner`].
+#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)]
+pub struct JoinerParams {
+    /// input that is streamed against the hash map
+    probe_side: ProbeSide,
+    /// whether the output rows are taken from the probed input (true) or
+    /// from the hash map (false)
+    probe_is_join_side: bool,
+    mode: SemiMode,
+}
+
+impl JoinerParams {
+    const fn new(probe_side: ProbeSide, probe_is_join_side: bool, mode: SemiMode) -> Self {
+        Self {
+            probe_side,
+            probe_is_join_side,
+            mode,
+        }
+    }
+}
+
+const LEFT_PROBED_LEFT_SEMI: JoinerParams = JoinerParams::new(L, true, Semi);
+const LEFT_PROBED_LEFT_ANTI: JoinerParams = JoinerParams::new(L, true, Anti);
+const LEFT_PROBED_RIGHT_SEMI: JoinerParams = JoinerParams::new(L, false, Semi);
+const LEFT_PROBED_RIGHT_ANTI: JoinerParams = JoinerParams::new(L, false, Anti);
+const LEFT_PROBED_EXISTENCE: JoinerParams = JoinerParams::new(L, true, Existence);
+const RIGHT_PROBED_LEFT_SEMI: JoinerParams = JoinerParams::new(R, false, Semi);
+const RIGHT_PROBED_LEFT_ANTI: JoinerParams = JoinerParams::new(R, false, Anti);
+const RIGHT_PROBED_RIGHT_SEMI: JoinerParams = JoinerParams::new(R, true, Semi);
+const RIGHT_PROBED_RIGHT_ANTI: JoinerParams = JoinerParams::new(R, true, Anti);
+const RIGHT_PROBED_EXISTENCE: JoinerParams = JoinerParams::new(R, false, Existence);
+
+pub type LProbedLeftSemiJoiner = SemiJoiner<LEFT_PROBED_LEFT_SEMI>;
+pub type LProbedLeftAntiJoiner = SemiJoiner<LEFT_PROBED_LEFT_ANTI>;
+pub type LProbedRightSemiJoiner = SemiJoiner<LEFT_PROBED_RIGHT_SEMI>;
+pub type LProbedRightAntiJoiner = SemiJoiner<LEFT_PROBED_RIGHT_ANTI>;
+pub type LProbedExistenceJoiner = SemiJoiner<LEFT_PROBED_EXISTENCE>;
+pub type RProbedLeftSemiJoiner = SemiJoiner<RIGHT_PROBED_LEFT_SEMI>;
+pub type RProbedLeftAntiJoiner = SemiJoiner<RIGHT_PROBED_LEFT_ANTI>;
+pub type RProbedRightSemiJoiner = SemiJoiner<RIGHT_PROBED_RIGHT_SEMI>;
+pub type RProbedRightAntiJoiner = SemiJoiner<RIGHT_PROBED_RIGHT_ANTI>;
+pub type RProbedExistenceJoiner = SemiJoiner<RIGHT_PROBED_EXISTENCE>;
+
+/// Left/right semi, anti and existence broadcast hash joiner, specialized at
+/// compile time by `P`.
+pub struct SemiJoiner<const P: JoinerParams> {
+    join_params: JoinParams,
+    output_sender: Arc<WrappedRecordBatchSender>,
+    /// one bit per hash map row, set once the row has been matched
+    map_joined: BitVec,
+    map: Arc<JoinHashMap>,
+    send_output_time: Time,
+    output_rows: AtomicUsize,
+}
+
+impl<const P: JoinerParams> SemiJoiner<P> {
+    pub fn new(
+        join_params: JoinParams,
+        map: Arc<JoinHashMap>,
+        output_sender: Arc<WrappedRecordBatchSender>,
+    ) -> Self {
+        let map_joined = bitvec![0; map.data_batch().num_rows()];
+        Self {
+            join_params,
+            output_sender,
+            map,
+            map_joined,
+            send_output_time: Time::new(),
+            output_rows: AtomicUsize::new(0),
+        }
+    }
+
+    /// Evaluates the probe side key expressions against `probed_batch`.
+    fn create_probed_key_columns(&self, probed_batch: &RecordBatch) -> Result<Vec<ArrayRef>> {
+        let probed_key_exprs = match P.probe_side {
+            L => &self.join_params.left_keys,
+            R => &self.join_params.right_keys,
+        };
+        let probed_key_columns: Vec<ArrayRef> = probed_key_exprs
+            .iter()
+            .map(|expr| {
+                Ok(expr
+                    .evaluate(probed_batch)?
+                    .into_array(probed_batch.num_rows())?)
+            })
+            .collect::<Result<_>>()?;
+        Ok(probed_key_columns)
+    }
+
+    /// Sends `cols` downstream as one output batch; empty batches are dropped
+    /// instead of being sent.
+    async fn flush(&self, cols: Vec<ArrayRef>) -> Result<()> {
+        let output_batch = RecordBatch::try_new(self.join_params.output_schema.clone(), cols)?;
+        if output_batch.num_rows() == 0 {
+            return Ok(());
+        }
+        self.output_rows.fetch_add(output_batch.num_rows(), Relaxed);
+
+        let timer = self.send_output_time.timer();
+        self.output_sender.send(Ok(output_batch), None).await;
+        drop(timer);
+        Ok(())
+    }
+
+    /// Verifies the hash-matched candidate pairs by comparing the real keys
+    /// and records the confirmed matches in the bitmap of the join side.
+    fn flush_hash_joined(
+        mut self: Pin<&mut Self>,
+        probed_key_columns: &[ArrayRef],
+        probed_joined: &mut BitVec,
+        mut hash_joined_probe_indices: Vec<u32>,
+        mut hash_joined_build_indices: Vec<u32>,
+    ) -> Result<()> {
+        filter_joined_indices(
+            probed_key_columns,
+            self.map.key_columns(),
+            &mut hash_joined_probe_indices,
+            &mut hash_joined_build_indices,
+        )?;
+        let probe_indices = hash_joined_probe_indices;
+        let build_indices = hash_joined_build_indices;
+
+        // only the bitmap of the join side is ever read back
+        if P.probe_is_join_side {
+            for &idx in &probe_indices {
+                probed_joined.set(idx as usize, true);
+            }
+        } else {
+            for &idx in &build_indices {
+                self.map_joined.set(idx as usize, true);
+            }
+        }
+        Ok(())
+    }
+}
+
+#[async_trait]
+impl<const P: JoinerParams> Joiner for SemiJoiner<P> {
+    async fn join(mut self: Pin<&mut Self>, probed_batch: RecordBatch) -> Result<()> {
+        let mut hash_joined_probe_indices: Vec<u32> = vec![];
+        let mut hash_joined_build_indices: Vec<u32> = vec![];
+        let mut probed_joined = bitvec![0; probed_batch.num_rows()];
+
+        let probed_key_columns = self.create_probed_key_columns(&probed_batch)?;
+        let probed_hashes = join_create_hashes(probed_batch.num_rows(), &probed_key_columns)?;
+
+        // join by hash code
+        for (row_idx, &hash) in probed_hashes.iter().enumerate() {
+            let mut maybe_joined = false;
+            if let Some(entries) = self.map.entry_indices(hash) {
+                for map_idx in entries {
+                    hash_joined_probe_indices.push(row_idx as u32);
+                    hash_joined_build_indices.push(map_idx);
+                }
+                maybe_joined = true;
+            }
+
+            if maybe_joined && hash_joined_probe_indices.len() >= self.join_params.batch_size {
+                self.as_mut().flush_hash_joined(
+                    &probed_key_columns,
+                    &mut probed_joined,
+                    std::mem::take(&mut hash_joined_probe_indices),
+                    std::mem::take(&mut hash_joined_build_indices),
+                )?;
+            }
+        }
+        if !hash_joined_probe_indices.is_empty() {
+            self.as_mut().flush_hash_joined(
+                &probed_key_columns,
+                &mut probed_joined,
+                hash_joined_probe_indices,
+                hash_joined_build_indices,
+            )?;
+        }
+
+        // output the selected rows of the probed side
+        if P.probe_is_join_side {
+            let pcols = match P.mode {
+                Semi | Anti => {
+                    let probed_indices = probed_joined
+                        .into_iter()
+                        .enumerate()
+                        .filter(|(_, joined)| (P.mode == Semi) ^ !joined)
+                        .map(|(idx, _)| idx as u32)
+                        .collect::<Vec<_>>();
+                    take_cols(probed_batch.columns(), probed_indices)?
+                }
+                Existence => {
+                    let exists_col = Arc::new(BooleanArray::from(
+                        probed_joined.into_iter().collect::<Vec<_>>(),
+                    ));
+                    [probed_batch.columns().to_vec(), vec![exists_col]].concat()
+                }
+            };
+            self.as_mut().flush(pcols).await?;
+        }
+        Ok(())
+    }
+
+    async fn finish(mut self: Pin<&mut Self>) -> Result<()> {
+        // output the selected rows of the build side
+        if !P.probe_is_join_side {
+            let map_joined = std::mem::take(&mut self.map_joined);
+            let pcols = match P.mode {
+                Semi | Anti => {
+                    let map_indices = map_joined
+                        .into_iter()
+                        .enumerate()
+                        .filter(|(_, joined)| (P.mode == Semi) ^ !joined)
+                        .map(|(idx, _)| idx as u32)
+                        .collect::<Vec<_>>();
+                    take_cols(self.map.data_batch().columns(), map_indices)?
+                }
+                Existence => {
+                    let exists_col = Arc::new(BooleanArray::from(
+                        map_joined.into_iter().collect::<Vec<_>>(),
+                    ));
+                    [self.map.data_batch().columns().to_vec(), vec![exists_col]].concat()
+                }
+            };
+            self.as_mut().flush(pcols).await?;
+        }
+        Ok(())
+    }
+
+    fn total_send_output_time(&self) -> usize {
+        self.send_output_time.value()
+    }
+
+    fn num_output_rows(&self) -> usize {
+        self.output_rows.load(Relaxed)
+    }
+}
diff --git a/native-engine/datafusion-ext-plans/src/joins/join_hash_map.rs b/native-engine/datafusion-ext-plans/src/joins/join_hash_map.rs new file mode 100644 index 0000000..dbb4d2e --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/join_hash_map.rs
@@ -0,0 +1,328 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + io::{Cursor, Read, Write}, + slice::{from_raw_parts, from_raw_parts_mut}, + sync::Arc, +}; + +use arrow::{ + array::{ArrayRef, AsArray, BinaryBuilder, RecordBatch}, + datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}, +}; +use byteorder::{NativeEndian, ReadBytesExt, WriteBytesExt}; +use datafusion::{common::Result, physical_expr::PhysicalExprRef}; +use datafusion_ext_commons::spark_hash::create_hashes; +use hashbrown::HashMap; +use itertools::Itertools; +use once_cell::sync::OnceCell; + +use crate::common::batch_selection::take_batch; + +pub struct Table { + entry_offsets: Vec<u32>, + entry_lens: Vec<u32>, + item_indices: Vec<u32>, + item_hashes: Vec<u32>, +} + +impl Table { + pub fn new_empty() -> Self { + let num_entries = Self::num_entries_of_rows(0); + Self { + entry_offsets: vec![0; num_entries], + entry_lens: vec![0; num_entries], + item_indices: vec![], + item_hashes: vec![], + } + } + + pub fn try_from_key_columns( + num_rows: usize, + data_batch: RecordBatch, + key_columns: &[ArrayRef], + ) -> Result<(Self, RecordBatch)> { + // returns the new data batch sorted by hashes + + assert!( + num_rows < 1073741824, + "join hash table: number of rows exceeded 2^30: {num_rows}" + ); + + let num_entries = Self::num_entries_of_rows(num_rows) as u32; + let item_hashes = join_create_hashes(num_rows, &key_columns)?; + + // sort record 
batch by hashes for better compression and data locality + let (indices, item_hashes): (Vec<usize>, Vec<u32>) = item_hashes + .into_iter() + .enumerate() + .sorted_unstable_by_key(|(_idx, hash)| *hash) + .unzip(); + let data_batch = take_batch(data_batch, indices)?; + + let mut entries_to_row_indices: HashMap<u32, Vec<u32>> = HashMap::new(); + for (row_idx, hash) in item_hashes.iter().enumerate() { + let entry = hash % num_entries; + entries_to_row_indices + .entry(entry) + .or_default() + .push(row_idx as u32); + } + + let mut entry_offsets = Vec::with_capacity(num_entries as usize); + let mut entry_lens = Vec::with_capacity(num_entries as usize); + let mut item_indices = Vec::with_capacity(num_rows); + for entry in 0..num_entries { + match entries_to_row_indices.get(&entry) { + Some(row_indices) => { + entry_offsets.push(item_indices.len() as u32); + entry_lens.push(row_indices.len() as u32); + item_indices.extend_from_slice(row_indices); + } + None => { + entry_offsets.push(item_indices.len() as u32); + entry_lens.push(0); + } + } + } + let new = Self { + entry_offsets, + entry_lens, + item_indices, + item_hashes, + }; + Ok((new, data_batch)) + } + + pub fn try_from_raw_bytes(raw_bytes: &[u8]) -> Result<Self> { + let mut cursor = Cursor::new(raw_bytes); + let num_rows = cursor.read_u32::<NativeEndian>()? 
as usize; + let num_entries = Self::num_entries_of_rows(num_rows); + + let mut new = Self { + entry_offsets: vec![0; num_entries], + entry_lens: vec![0; num_entries], + item_indices: vec![0; num_rows], + item_hashes: vec![0; num_rows], + }; + + unsafe { + // safety: read integer arrays as raw bytes + cursor.read_exact(from_raw_parts_mut( + new.entry_offsets.as_mut_ptr() as *mut u8, + num_entries * 4, + ))?; + cursor.read_exact(from_raw_parts_mut( + new.entry_lens.as_mut_ptr() as *mut u8, + num_entries * 4, + ))?; + cursor.read_exact(from_raw_parts_mut( + new.item_indices.as_mut_ptr() as *mut u8, + num_rows * 4, + ))?; + cursor.read_exact(from_raw_parts_mut( + new.item_hashes.as_mut_ptr() as *mut u8, + num_rows * 4, + ))?; + } + Ok(new) + } + + pub fn try_into_raw_bytes(self) -> Result<Vec<u8>> { + let num_entries = self.entry_offsets.len(); + let num_rows = self.item_indices.len(); + let mut raw_bytes = Vec::with_capacity(num_entries * 8 + num_rows * 4 + 4); + + raw_bytes.write_u32::<NativeEndian>(num_rows as u32)?; + unsafe { + // safety: write integer arrays as raw bytes + raw_bytes.write_all(from_raw_parts( + self.entry_offsets.as_ptr() as *const u8, + num_entries * 4, + ))?; + raw_bytes.write_all(from_raw_parts( + self.entry_lens.as_ptr() as *const u8, + num_entries * 4, + ))?; + raw_bytes.write_all(from_raw_parts( + self.item_indices.as_ptr() as *const u8, + num_rows * 4, + ))?; + raw_bytes.write_all(from_raw_parts( + self.item_hashes.as_ptr() as *const u8, + num_rows * 4, + ))?; + } + Ok(raw_bytes) + } + + pub fn entry<'a>(&'a self, hash: u32) -> Option<impl Iterator<Item = u32> + 'a> { + let entry = hash % (self.entry_offsets.len() as u32); + let len = self.entry_lens[entry as usize] as usize; + if len > 0 { + let offset = self.entry_offsets[entry as usize] as usize; + Some( + self.item_indices[offset..][..len] + .iter() + .cloned() + .filter(move |&idx| self.item_hashes[idx as usize] == hash), + ) + } else { + None + } + } + + fn 
num_entries_of_rows(num_rows: usize) -> usize { + num_rows * 5 + 1 + } +} + +pub struct JoinHashMap { + data_batch: RecordBatch, + key_columns: Vec<ArrayRef>, + table: Table, +} + +impl JoinHashMap { + pub fn try_from_data_batch( + data_batch: RecordBatch, + key_exprs: &[PhysicalExprRef], + ) -> Result<JoinHashMap> { + let key_columns: Vec<ArrayRef> = key_exprs + .iter() + .map(|expr| { + Ok(expr + .evaluate(&data_batch)? + .into_array(data_batch.num_rows())?) + }) + .collect::<Result<_>>()?; + + let (table, data_batch) = + Table::try_from_key_columns(data_batch.num_rows(), data_batch, &key_columns)?; + Ok(JoinHashMap { + data_batch, + key_columns, + table, + }) + } + + pub fn try_from_hash_map_batch( + hash_map_batch: RecordBatch, + key_exprs: &[PhysicalExprRef], + ) -> Result<Self> { + let mut data_batch = hash_map_batch.clone(); + let table = Table::try_from_raw_bytes( + data_batch + .remove_column(data_batch.num_columns() - 1) + .as_binary::<i32>() + .value(0), + )?; + let key_columns: Vec<ArrayRef> = key_exprs + .iter() + .map(|expr| { + Ok(expr + .evaluate(&data_batch)? + .into_array(data_batch.num_rows())?) + }) + .collect::<Result<_>>()?; + Ok(Self { + data_batch, + key_columns, + table, + }) + } + + pub fn try_new_empty( + hash_map_schema: SchemaRef, + key_exprs: &[PhysicalExprRef], + ) -> Result<Self> { + let table = Table::new_empty(); + let data_batch = RecordBatch::new_empty(hash_map_schema); + let key_columns: Vec<ArrayRef> = key_exprs + .iter() + .map(|expr| { + Ok(expr + .evaluate(&data_batch)? + .into_array(data_batch.num_rows())?) 
+ }) + .collect::<Result<_>>()?; + Ok(Self { + data_batch, + key_columns, + table, + }) + } + + pub fn data_schema(&self) -> SchemaRef { + self.data_batch().schema() + } + + pub fn data_batch(&self) -> &RecordBatch { + &self.data_batch + } + + pub fn key_columns(&self) -> &[ArrayRef] { + &self.key_columns + } + + pub fn entry_indices<'a>(&'a self, hash: u32) -> Option<impl Iterator<Item = u32> + 'a> { + self.table.entry(hash) + } + + pub fn into_hash_map_batch(self) -> Result<RecordBatch> { + let schema = join_hash_map_schema(&self.data_batch.schema()); + if self.data_batch.num_rows() == 0 { + return Ok(RecordBatch::new_empty(schema)); + } + let mut table_col_builder = BinaryBuilder::new(); + table_col_builder.append_value(&self.table.try_into_raw_bytes()?); + for _ in 1..self.data_batch.num_rows() { + table_col_builder.append_null(); + } + let table_col: ArrayRef = Arc::new(table_col_builder.finish()); + Ok(RecordBatch::try_new( + schema, + vec![self.data_batch.columns().to_vec(), vec![table_col]].concat(), + )?) + } +} + +#[inline] +pub fn join_hash_map_schema(data_schema: &SchemaRef) -> SchemaRef { + Arc::new(Schema::new( + data_schema + .fields() + .iter() + .map(|field| Arc::new(field.as_ref().clone().with_nullable(true))) + .chain(std::iter::once(join_table_field())) + .collect::<Vec<_>>(), + )) +} + +#[inline] +pub fn join_create_hashes(num_rows: usize, key_columns: &[ArrayRef]) -> Result<Vec<u32>> { + const JOIN_HASH_RANDOM_SEED: u32 = 0x90ec4058; + let mut hashes = vec![JOIN_HASH_RANDOM_SEED; num_rows]; + create_hashes(key_columns, &mut hashes)?; + Ok(hashes) +} + +#[inline] +fn join_table_field() -> FieldRef { + static BHJ_KEY_FIELD: OnceCell<FieldRef> = OnceCell::new(); + BHJ_KEY_FIELD + .get_or_init(|| Arc::new(Field::new("~TABLE", DataType::Binary, true))) + .clone() +}
diff --git a/native-engine/datafusion-ext-plans/src/joins/join_utils.rs b/native-engine/datafusion-ext-plans/src/joins/join_utils.rs new file mode 100644 index 0000000..076cfa1 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/join_utils.rs
@@ -0,0 +1,64 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use datafusion::common::{DataFusionError, Result}; +use datafusion_ext_commons::df_execution_err; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum JoinType { + Inner, + Left, + Right, + Full, + LeftAnti, + RightAnti, + LeftSemi, + RightSemi, + Existence, +} + +impl TryFrom<JoinType> for datafusion::prelude::JoinType { + type Error = DataFusionError; + + fn try_from(value: JoinType) -> Result<Self> { + match value { + JoinType::Inner => Ok(datafusion::prelude::JoinType::Inner), + JoinType::Left => Ok(datafusion::prelude::JoinType::Left), + JoinType::Right => Ok(datafusion::prelude::JoinType::Right), + JoinType::Full => Ok(datafusion::prelude::JoinType::Full), + JoinType::LeftAnti => Ok(datafusion::prelude::JoinType::LeftAnti), + JoinType::RightAnti => Ok(datafusion::prelude::JoinType::RightAnti), + JoinType::LeftSemi => Ok(datafusion::prelude::JoinType::LeftSemi), + JoinType::RightSemi => Ok(datafusion::prelude::JoinType::RightSemi), + other => df_execution_err!("unsupported join type: {other:?}"), + } + } +} + +impl TryFrom<datafusion::prelude::JoinType> for JoinType { + type Error = DataFusionError; + + fn try_from(value: datafusion::prelude::JoinType) -> Result<Self> { + match value { + datafusion::prelude::JoinType::Inner => Ok(JoinType::Inner), + datafusion::prelude::JoinType::Left => Ok(JoinType::Left), + 
datafusion::prelude::JoinType::Right => Ok(JoinType::Right), + datafusion::prelude::JoinType::Full => Ok(JoinType::Full), + datafusion::prelude::JoinType::LeftAnti => Ok(JoinType::LeftAnti), + datafusion::prelude::JoinType::RightAnti => Ok(JoinType::RightAnti), + datafusion::prelude::JoinType::LeftSemi => Ok(JoinType::LeftSemi), + datafusion::prelude::JoinType::RightSemi => Ok(JoinType::RightSemi), + } + } +}
diff --git a/native-engine/datafusion-ext-plans/src/joins/mod.rs b/native-engine/datafusion-ext-plans/src/joins/mod.rs new file mode 100644 index 0000000..243fc60 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/mod.rs
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use arrow::{
    compute::SortOptions,
    datatypes::{DataType, SchemaRef},
};
use datafusion::physical_expr::PhysicalExprRef;

use crate::joins::{join_utils::JoinType, stream_cursor::StreamCursor};

pub mod join_hash_map;
pub mod join_utils;
pub mod stream_cursor;

// join implementations
pub mod bhj;
pub mod smj;
mod test;

/// Parameters shared by the join implementations in this module.
#[derive(Debug, Clone)]
pub struct JoinParams {
    /// Kind of join to perform.
    pub join_type: JoinType,
    /// Schema of the left input.
    pub left_schema: SchemaRef,
    /// Schema of the right input.
    pub right_schema: SchemaRef,
    /// Schema of the batches emitted by the join.
    pub output_schema: SchemaRef,
    /// Join key expressions of the left input.
    pub left_keys: Vec<PhysicalExprRef>,
    /// Join key expressions of the right input.
    pub right_keys: Vec<PhysicalExprRef>,
    /// Data types of the join keys.
    // NOTE(review): presumably one entry per key expression, in the same
    // order as `left_keys` / `right_keys` -- confirm against the constructor.
    pub key_data_types: Vec<DataType>,
    /// Sort options paired with `key_data_types` when the sort-merge join
    /// builds its key row converter.
    pub sort_options: Vec<SortOptions>,
    /// Number of buffered output rows at which the joiners flush a batch.
    pub batch_size: usize,
}

/// Position of a row inside a stream cursor's buffered batches.
// NOTE(review): the joiners compare `.0` against the cursor's current
// position to detect rows of older batches, so this looks like
// (batch index, row index) -- confirm against `StreamCursor`.
pub type Idx = (usize, usize);

/// The two input cursors of a sort-merge join; the joiners treat `.0` as the
/// left input and `.1` as the right input.
pub type StreamCursors = (StreamCursor, StreamCursor);
diff --git a/native-engine/datafusion-ext-plans/src/joins/smj/existence_join.rs b/native-engine/datafusion-ext-plans/src/joins/smj/existence_join.rs new file mode 100644 index 0000000..194ede1 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/smj/existence_join.rs
@@ -0,0 +1,171 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{cmp::Ordering, pin::Pin, sync::Arc}; + +use arrow::array::{ArrayRef, RecordBatch, RecordBatchOptions}; +use async_trait::async_trait; +use datafusion::{common::Result, physical_plan::metrics::Time}; +use datafusion_ext_commons::suggested_output_batch_mem_size; + +use crate::{ + common::{batch_selection::interleave_batches, output::WrappedRecordBatchSender}, + compare_cursor, cur_forward, + joins::{Idx, JoinParams, StreamCursors}, + sort_merge_join_exec::Joiner, +}; + +pub struct ExistenceJoiner { + join_params: JoinParams, + output_sender: Arc<WrappedRecordBatchSender>, + indices: Vec<Idx>, + exists: Vec<bool>, + send_output_time: Time, + output_rows: usize, +} + +impl ExistenceJoiner { + pub fn new(join_params: JoinParams, output_sender: Arc<WrappedRecordBatchSender>) -> Self { + Self { + join_params, + output_sender, + indices: vec![], + exists: vec![], + send_output_time: Time::new(), + output_rows: 0, + } + } + + fn should_flush(&self, curs: &StreamCursors) -> bool { + if self.indices.len() >= self.join_params.batch_size { + return true; + } + + if curs.0.num_buffered_batches() + curs.1.num_buffered_batches() >= 6 + && curs.0.mem_size() + curs.1.mem_size() > suggested_output_batch_mem_size() + { + if let Some(first_idx) = self.indices.first() { + if first_idx.0 < curs.0.cur_idx.0 { + return true; + } + } + } + false + } + + async fn 
flush(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> { + let indices = std::mem::take(&mut self.indices); + let num_rows = indices.len(); + let cols = interleave_batches(curs.0.batch_schema.clone(), &curs.0.batches, &indices)?; + + let exists = std::mem::take(&mut self.exists); + let exists_col: ArrayRef = Arc::new(arrow::array::BooleanArray::from(exists)); + + let output_batch = RecordBatch::try_new_with_options( + self.join_params.output_schema.clone(), + [cols.columns().to_vec(), vec![exists_col]].concat(), + &RecordBatchOptions::new().with_row_count(Some(num_rows)), + )?; + + if output_batch.num_rows() > 0 { + self.output_rows += output_batch.num_rows(); + + let timer = self.send_output_time.timer(); + self.output_sender.send(Ok(output_batch), None).await; + drop(timer); + } + Ok(()) + } +} + +#[async_trait] +impl Joiner for ExistenceJoiner { + async fn join(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> { + while !curs.0.finished && !curs.1.finished { + let mut lidx = curs.0.cur_idx; + let mut ridx = curs.1.cur_idx; + + match compare_cursor!(curs) { + Ordering::Less => { + self.indices.push(curs.0.cur_idx); + self.exists.push(false); + cur_forward!(curs.0); + if self.should_flush(curs) { + self.as_mut().flush(curs).await?; + } + curs.0 + .set_min_reserved_idx(*self.indices.first().unwrap_or(&curs.0.cur_idx)); + } + Ordering::Greater => { + cur_forward!(curs.1); + curs.1 + .set_min_reserved_idx(*self.indices.first().unwrap_or(&curs.1.cur_idx)); + } + Ordering::Equal => { + loop { + self.indices.push(lidx); + self.exists.push(true); + cur_forward!(curs.0); + if self.should_flush(curs) { + self.as_mut().flush(curs).await?; + } + curs.0 + .set_min_reserved_idx(*self.indices.first().unwrap_or(&lidx)); + + if !curs.0.finished && curs.0.key(curs.0.cur_idx) == curs.0.key(lidx) { + lidx = curs.0.cur_idx; + continue; + } + break; + } + + // skip all right equal rows + loop { + cur_forward!(curs.1); + 
curs.1.set_min_reserved_idx(ridx); + + if !curs.1.finished && curs.1.key(curs.1.cur_idx) == curs.1.key(ridx) { + ridx = curs.1.cur_idx; + continue; + } + break; + } + } + } + } + + while !curs.0.finished { + self.indices.push(curs.0.cur_idx); + self.exists.push(false); + cur_forward!(curs.0); + if self.should_flush(curs) { + self.as_mut().flush(curs).await?; + } + curs.0 + .set_min_reserved_idx(*self.indices.first().unwrap_or(&curs.0.cur_idx)); + } + if !self.indices.is_empty() { + self.flush(curs).await?; + } + Ok(()) + } + + fn total_send_output_time(&self) -> usize { + self.send_output_time.value() + } + + fn num_output_rows(&self) -> usize { + self.output_rows + } +}
diff --git a/native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs b/native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs new file mode 100644 index 0000000..191b6a7 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{cmp::Ordering, pin::Pin, sync::Arc};

use arrow::array::{RecordBatch, RecordBatchOptions};
use async_trait::async_trait;
use datafusion::{common::Result, physical_plan::metrics::Time};
use datafusion_ext_commons::suggested_output_batch_mem_size;
use smallvec::{smallvec, SmallVec};

use crate::{
    common::{batch_selection::interleave_batches, output::WrappedRecordBatchSender},
    compare_cursor, cur_forward,
    joins::{Idx, JoinParams, StreamCursors},
    sort_merge_join_exec::Joiner,
};

/// Sort-merge joiner covering inner, left outer, right outer and full outer
/// joins.
///
/// `L_OUTER` / `R_OUTER` select whether left / right rows without a matching
/// key on the other side are emitted; such rows are paired with
/// `Idx::default()` on the other side.
// NOTE(review): `Idx::default()` is (0, 0), which presumably addresses the
// leading "null batch" that StreamCursor keeps in front of its batches --
// confirm against StreamCursor.
pub struct FullJoiner<const L_OUTER: bool, const R_OUTER: bool> {
    join_params: JoinParams,
    output_sender: Arc<WrappedRecordBatchSender>,
    // pending output rows: lindices[i] / rindices[i] are the left / right
    // source positions of output row i (both vectors always have equal length)
    lindices: Vec<Idx>,
    rindices: Vec<Idx>,
    send_output_time: Time,
    output_rows: usize,
}

pub type InnerJoiner = FullJoiner<false, false>;
pub type LeftOuterJoiner = FullJoiner<true, false>;
pub type RightOuterJoiner = FullJoiner<false, true>;
pub type FullOuterJoiner = FullJoiner<true, true>;

impl<const L_OUTER: bool, const R_OUTER: bool> FullJoiner<L_OUTER, R_OUTER> {
    /// Creates a joiner that sends its output batches to `output_sender`.
    pub fn new(join_params: JoinParams, output_sender: Arc<WrappedRecordBatchSender>) -> Self {
        Self {
            join_params,
            output_sender,
            lindices: vec![],
            rindices: vec![],
            send_output_time: Time::new(),
            output_rows: 0,
        }
    }

    /// Returns true when the pending rows should be emitted: either
    /// `batch_size` rows are pending, or the cursors buffer at least 6
    /// batches exceeding the suggested output batch memory size and the
    /// oldest pending left or right index has a smaller `.0` component than
    /// the corresponding cursor's current index.
    fn should_flush(&self, curs: &StreamCursors) -> bool {
        if self.lindices.len() >= self.join_params.batch_size {
            return true;
        }

        if curs.0.num_buffered_batches() + curs.1.num_buffered_batches() >= 6
            && curs.0.mem_size() + curs.1.mem_size() > suggested_output_batch_mem_size()
        {
            if let Some(first_lidx) = self.lindices.first() {
                if first_lidx.0 < curs.0.cur_idx.0 {
                    return true;
                }
            }
            if let Some(first_ridx) = self.rindices.first() {
                if first_ridx.0 < curs.1.cur_idx.0 {
                    return true;
                }
            }
        }
        false
    }

    /// Takes all pending index pairs, gathers the referenced left and right
    /// rows into one output batch and sends it (empty batches are dropped).
    async fn flush(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
        let lindices = std::mem::take(&mut self.lindices);
        let rindices = std::mem::take(&mut self.rindices);
        let num_rows = lindices.len();
        assert_eq!(lindices.len(), rindices.len());

        let lcols = interleave_batches(curs.0.batch_schema.clone(), &curs.0.batches, &lindices)?;
        let rcols = interleave_batches(curs.1.batch_schema.clone(), &curs.1.batches, &rindices)?;
        let output_batch = RecordBatch::try_new_with_options(
            self.join_params.output_schema.clone(),
            [lcols.columns(), rcols.columns()].concat(),
            &RecordBatchOptions::new().with_row_count(Some(num_rows)),
        )?;

        if output_batch.num_rows() > 0 {
            self.output_rows += output_batch.num_rows();

            // only the time spent handing the batch downstream is measured
            let timer = self.send_output_time.timer();
            self.output_sender.send(Ok(output_batch), None).await;
            drop(timer);
        }
        Ok(())
    }
}

#[async_trait]
impl<const L_OUTER: bool, const R_OUTER: bool> Joiner for FullJoiner<L_OUTER, R_OUTER> {
    /// Merges both cursors in key order, buffering matching index pairs (and
    /// unmatched rows of outer sides) and flushing them as output batches.
    async fn join(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
        while !curs.0.finished && !curs.1.finished {
            let mut lidx = curs.0.cur_idx;
            let mut ridx = curs.1.cur_idx;
            match compare_cursor!(curs) {
                Ordering::Less => {
                    // left row has no match: emit it only for left-outer joins
                    if L_OUTER {
                        self.lindices.push(lidx);
                        self.rindices.push(Idx::default());
                    }
                    cur_forward!(curs.0);
                    if self.should_flush(curs) {
                        self.as_mut().flush(curs).await?;
                    }
                    // keep everything from the oldest pending left row onwards
                    curs.0
                        .set_min_reserved_idx(*self.lindices.first().unwrap_or(&lidx));
                }
                Ordering::Greater => {
                    // right row has no match: emit it only for right-outer joins
                    if R_OUTER {
                        self.lindices.push(Idx::default());
                        self.rindices.push(ridx);
                    }
                    cur_forward!(curs.1);
                    if self.should_flush(curs) {
                        self.as_mut().flush(curs).await?;
                    }
                    curs.1
                        .set_min_reserved_idx(*self.rindices.first().unwrap_or(&ridx));
                }
                Ordering::Equal => {
                    // first matching pair of this key run
                    cur_forward!(curs.0);
                    cur_forward!(curs.1);
                    self.lindices.push(lidx);
                    self.rindices.push(ridx);

                    // rows of the current key run seen so far on each side;
                    // a newly reached equal-key row is paired with all rows
                    // collected for the opposite side
                    let mut equal_lindices: SmallVec<[Idx; 16]> = smallvec![lidx];
                    let mut equal_rindices: SmallVec<[Idx; 16]> = smallvec![ridx];
                    let mut last_lidx = lidx;
                    let mut last_ridx = ridx;
                    lidx = curs.0.cur_idx;
                    ridx = curs.1.cur_idx;
                    // does the row now under each cursor still carry the same key?
                    let mut l_equal = !curs.0.finished && curs.0.key(lidx) == curs.0.key(last_lidx);
                    let mut r_equal = !curs.1.finished && curs.1.key(ridx) == curs.1.key(last_ridx);

                    while l_equal || r_equal {
                        if l_equal {
                            // pair the new left row with every collected right row
                            for &ridx in &equal_rindices {
                                self.lindices.push(lidx);
                                self.rindices.push(ridx);
                            }
                            // the left row only needs to be remembered while
                            // further right rows of this key may still arrive
                            if r_equal {
                                equal_lindices.push(lidx);
                            }
                            cur_forward!(curs.0);
                            last_lidx = lidx;
                            lidx = curs.0.cur_idx;
                        } else {
                            curs.1
                                .set_min_reserved_idx(*self.rindices.first().unwrap_or(&last_ridx));
                        }

                        if r_equal {
                            // pair the new right row with every collected left row
                            for &lidx in &equal_lindices {
                                self.lindices.push(lidx);
                                self.rindices.push(ridx);
                            }
                            if l_equal {
                                equal_rindices.push(ridx);
                            }
                            cur_forward!(curs.1);
                            last_ridx = ridx;
                            ridx = curs.1.cur_idx;
                        } else {
                            curs.0
                                .set_min_reserved_idx(*self.lindices.first().unwrap_or(&last_lidx));
                        }

                        if self.should_flush(curs) {
                            self.as_mut().flush(curs).await?;
                        }
                        // once a side has left the key run it stays out
                        // (the flag is and-ed with its previous value)
                        l_equal = l_equal
                            && !curs.0.finished
                            && curs.0.key(lidx) == curs.0.key(last_lidx);
                        r_equal = r_equal
                            && !curs.1.finished
                            && curs.1.key(ridx) == curs.1.key(last_ridx);
                    }

                    if self.should_flush(curs) {
                        self.as_mut().flush(curs).await?;
                    }
                    curs.0
                        .set_min_reserved_idx(*self.lindices.first().unwrap_or(&curs.0.cur_idx));
                    curs.1
                        .set_min_reserved_idx(*self.rindices.first().unwrap_or(&curs.1.cur_idx));
                }
            }
        }

        // at least one side is finished, consume the other side if it is an outer side
        while L_OUTER && !curs.0.finished {
            let lidx = curs.0.cur_idx;
            self.lindices.push(lidx);
            self.rindices.push(Idx::default());
            cur_forward!(curs.0);
            if self.should_flush(curs) {
                self.as_mut().flush(curs).await?;
            }
            curs.0
                .set_min_reserved_idx(*self.lindices.first().unwrap_or(&lidx));
        }
        while R_OUTER && !curs.1.finished {
            let ridx = curs.1.cur_idx;
            self.lindices.push(Idx::default());
            self.rindices.push(ridx);
            cur_forward!(curs.1);
            if self.should_flush(curs) {
                self.as_mut().flush(curs).await?;
            }
            curs.1
                .set_min_reserved_idx(*self.rindices.first().unwrap_or(&ridx));
        }
        // emit whatever is still pending
        if !self.lindices.is_empty() {
            self.flush(curs).await?;
        }
        Ok(())
    }

    /// Accumulated value of the send-output timer.
    fn total_send_output_time(&self) -> usize {
        self.send_output_time.value()
    }

    /// Total number of rows sent downstream so far.
    fn num_output_rows(&self) -> usize {
        self.output_rows
    }
}
diff --git a/native-engine/datafusion-ext-plans/src/joins/smj/mod.rs b/native-engine/datafusion-ext-plans/src/joins/smj/mod.rs new file mode 100644 index 0000000..8bcdadf --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/smj/mod.rs
@@ -0,0 +1,17 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod existence_join; +pub mod full_join; +pub mod semi_join;
diff --git a/native-engine/datafusion-ext-plans/src/joins/smj/semi_join.rs b/native-engine/datafusion-ext-plans/src/joins/smj/semi_join.rs new file mode 100644 index 0000000..d7b5f87 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/smj/semi_join.rs
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{cmp::Ordering, pin::Pin, sync::Arc};

use arrow::array::{RecordBatch, RecordBatchOptions};
use async_trait::async_trait;
use datafusion::{common::Result, physical_plan::metrics::Time};
use datafusion_ext_commons::suggested_output_batch_mem_size;

use crate::{
    common::{batch_selection::interleave_batches, output::WrappedRecordBatchSender},
    compare_cursor, cur_forward,
    joins::{
        smj::semi_join::SemiJoinSide::{L, R},
        Idx, JoinParams, StreamCursors,
    },
    sort_merge_join_exec::Joiner,
};

/// Side whose rows are emitted by a semi/anti join: `L` reads output rows
/// from the first cursor (`curs.0`), `R` from the second (`curs.1`).
#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)]
pub enum SemiJoinSide {
    L,
    R,
}

/// Compile-time configuration of a [`SemiJoiner`].
#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)]
pub struct JoinerParams {
    // side whose rows form the output
    join_side: SemiJoinSide,
    // true: emit rows whose key has a match on the other side (semi join);
    // false: emit rows whose key has no match (anti join)
    semi: bool,
}

impl JoinerParams {
    const fn new(join_side: SemiJoinSide, semi: bool) -> Self {
        Self { join_side, semi }
    }
}
/// Sort-merge semi/anti joiner, specialized at compile time by `P`.
pub struct SemiJoiner<const P: JoinerParams> {
    join_params: JoinParams,
    output_sender: Arc<WrappedRecordBatchSender>,
    // pending output rows, as positions in the join side's cursor
    indices: Vec<Idx>,
    send_output_time: Time,
    output_rows: usize,
}

const LEFT_SEMI: JoinerParams = JoinerParams::new(L, true);
const LEFT_ANTI: JoinerParams = JoinerParams::new(L, false);
const RIGHT_SEMI: JoinerParams = JoinerParams::new(R, true);
const RIGHT_ANTI: JoinerParams = JoinerParams::new(R, false);

pub type LeftSemiJoiner = SemiJoiner<LEFT_SEMI>;
pub type LeftAntiJoiner = SemiJoiner<LEFT_ANTI>;
pub type RightSemiJoiner = SemiJoiner<RIGHT_SEMI>;
pub type RightAntiJoiner = SemiJoiner<RIGHT_ANTI>;

impl<const P: JoinerParams> SemiJoiner<P> {
    /// Creates a joiner that sends its output batches to `output_sender`.
    pub fn new(join_params: JoinParams, output_sender: Arc<WrappedRecordBatchSender>) -> Self {
        Self {
            join_params,
            output_sender,
            indices: vec![],
            send_output_time: Time::new(),
            output_rows: 0,
        }
    }

    /// Returns true when the pending rows should be emitted: either
    /// `batch_size` rows are pending, or the cursors buffer at least 6
    /// batches exceeding the suggested output batch memory size and the
    /// oldest pending index has a smaller `.0` component than the join
    /// side's current cursor index.
    fn should_flush(&self, curs: &StreamCursors) -> bool {
        if self.indices.len() >= self.join_params.batch_size {
            return true;
        }

        if curs.0.num_buffered_batches() + curs.1.num_buffered_batches() >= 6
            && curs.0.mem_size() + curs.1.mem_size() > suggested_output_batch_mem_size()
        {
            if let Some(first_idx) = self.indices.first() {
                let cur_idx = match P.join_side {
                    L => curs.0.cur_idx,
                    R => curs.1.cur_idx,
                };
                if first_idx.0 < cur_idx.0 {
                    return true;
                }
            }
        }
        false
    }

    /// Takes all pending indices, gathers the referenced rows from the join
    /// side's batches into one output batch and sends it (empty batches are
    /// dropped).
    async fn flush(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
        let indices = std::mem::take(&mut self.indices);
        let num_rows = indices.len();
        let cols = match P.join_side {
            L => interleave_batches(curs.0.batch_schema.clone(), &curs.0.batches, &indices)?,
            R => interleave_batches(curs.1.batch_schema.clone(), &curs.1.batches, &indices)?,
        };
        let output_batch = RecordBatch::try_new_with_options(
            self.join_params.output_schema.clone(),
            cols.columns().to_vec(),
            &RecordBatchOptions::new().with_row_count(Some(num_rows)),
        )?;

        if output_batch.num_rows() > 0 {
            self.output_rows += output_batch.num_rows();

            // only the time spent handing the batch downstream is measured
            let timer = self.send_output_time.timer();
            self.output_sender.send(Ok(output_batch), None).await;
            drop(timer);
        }
        Ok(())
    }
}

#[async_trait]
impl<const P: JoinerParams> Joiner for SemiJoiner<P> {
    /// Merges both cursors in key order and buffers the join side's rows
    /// that have (semi) or lack (anti) an equal key on the other side.
    async fn join(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
        while !curs.0.finished && !curs.1.finished {
            let mut lidx = curs.0.cur_idx;
            let mut ridx = curs.1.cur_idx;

            match compare_cursor!(curs) {
                Ordering::Less => {
                    // left key without match: output only for left anti join
                    if P.join_side == L && !P.semi {
                        self.indices.push(lidx);
                    }
                    cur_forward!(curs.0);
                    if self.should_flush(curs) {
                        self.as_mut().flush(curs).await?;
                    }
                    // pending indices refer to the left cursor only when it
                    // is the join side; otherwise reserve from lidx itself
                    curs.0.set_min_reserved_idx(match P.join_side {
                        L => *self.indices.first().unwrap_or(&lidx),
                        R => lidx,
                    });
                }
                Ordering::Greater => {
                    // right key without match: output only for right anti join
                    if P.join_side == R && !P.semi {
                        self.indices.push(ridx);
                    }
                    cur_forward!(curs.1);
                    if self.should_flush(curs) {
                        self.as_mut().flush(curs).await?;
                    }
                    curs.1.set_min_reserved_idx(match P.join_side {
                        L => ridx,
                        R => *self.indices.first().unwrap_or(&ridx),
                    });
                }
                Ordering::Equal => {
                    // output/skip left equal rows
                    loop {
                        if P.join_side == L && P.semi {
                            self.indices.push(lidx);
                            if self.should_flush(curs) {
                                self.as_mut().flush(curs).await?;
                            }
                        }
                        cur_forward!(curs.0);
                        curs.0.set_min_reserved_idx(match P.join_side {
                            L => *self.indices.first().unwrap_or(&lidx),
                            R => lidx,
                        });

                        // continue while the next left row carries the same key
                        if !curs.0.finished && curs.0.key(curs.0.cur_idx) == curs.0.key(lidx) {
                            lidx = curs.0.cur_idx;
                            continue;
                        }
                        break;
                    }

                    // output/skip right equal rows
                    loop {
                        if P.join_side == R && P.semi {
                            self.indices.push(ridx);
                            if self.should_flush(curs) {
                                self.as_mut().flush(curs).await?;
                            }
                        }
                        cur_forward!(curs.1);
                        curs.1.set_min_reserved_idx(match P.join_side {
                            L => ridx,
                            R => *self.indices.first().unwrap_or(&ridx),
                        });

                        // continue while the next right row carries the same key
                        if !curs.1.finished && curs.1.key(curs.1.cur_idx) == curs.1.key(ridx) {
                            ridx = curs.1.cur_idx;
                            continue;
                        }
                        break;
                    }
                }
            }
        }

        // at least one side is finished, consume the other side if it is an anti side
        if !P.semi {
            while P.join_side == L && !P.semi && !curs.0.finished {
                let lidx = curs.0.cur_idx;
                self.indices.push(lidx);
                cur_forward!(curs.0);
                if self.should_flush(curs) {
                    self.as_mut().flush(curs).await?;
                }
                curs.0.set_min_reserved_idx(match P.join_side {
                    L => *self.indices.first().unwrap_or(&lidx),
                    R => lidx,
                });
            }
            while P.join_side == R && !P.semi && !curs.1.finished {
                let ridx = curs.1.cur_idx;
                self.indices.push(ridx);
                cur_forward!(curs.1);
                if self.should_flush(curs) {
                    self.as_mut().flush(curs).await?;
                }
                curs.1.set_min_reserved_idx(match P.join_side {
                    L => ridx,
                    R => *self.indices.first().unwrap_or(&ridx),
                });
            }
        }
        // emit whatever is still pending
        if !self.indices.is_empty() {
            self.flush(curs).await?;
        }
        Ok(())
    }

    /// Accumulated value of the send-output timer.
    fn total_send_output_time(&self) -> usize {
        self.send_output_time.value()
    }

    /// Total number of rows sent downstream so far.
    fn num_output_rows(&self) -> usize {
        self.output_rows
    }
}
diff --git a/native-engine/datafusion-ext-plans/src/joins/stream_cursor.rs b/native-engine/datafusion-ext-plans/src/joins/stream_cursor.rs new file mode 100644 index 0000000..81a1317 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/stream_cursor.rs
@@ -0,0 +1,224 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use arrow::{ + array::RecordBatch, + buffer::NullBuffer, + datatypes::{Schema, SchemaRef}, + row::{Row, RowConverter, Rows, SortField}, +}; +use datafusion::{ + common::{JoinSide, Result}, + execution::SendableRecordBatchStream, + physical_expr::PhysicalExprRef, + physical_plan::metrics::Time, +}; +use datafusion_ext_commons::array_size::ArraySize; +use futures::{Future, StreamExt}; +use parking_lot::Mutex; + +use crate::{ + common::batch_selection::take_batch_opt, + joins::{Idx, JoinParams}, +}; + +pub struct StreamCursor { + stream: SendableRecordBatchStream, + key_converter: Arc<Mutex<RowConverter>>, + key_exprs: Vec<PhysicalExprRef>, + poll_time: Time, + + // IMPORTANT: + // batches/rows/null_buffers always contains a `null batch` in the front + pub batch_schema: SchemaRef, // stream nullable schema + pub batches: Vec<RecordBatch>, + pub cur_idx: Idx, + min_reserved_idx: Idx, + keys: Vec<Arc<Rows>>, + key_has_nulls: Vec<Option<NullBuffer>>, + num_null_batches: usize, + mem_size: usize, + pub finished: bool, +} + +impl StreamCursor { + pub fn try_new( + stream: SendableRecordBatchStream, + join_params: &JoinParams, + join_side: JoinSide, + ) -> Result<Self> { + let key_converter = Arc::new(Mutex::new(RowConverter::new( + join_params + .key_data_types + .iter() + .cloned() + .zip(&join_params.sort_options) + .map(|(dt, options)| 
SortField::new_with_options(dt, *options)) + .collect(), + )?)); + let key_exprs = match join_side { + JoinSide::Left => join_params.left_keys.clone(), + JoinSide::Right => join_params.right_keys.clone(), + }; + + let empty_batch = RecordBatch::new_empty(Arc::new(Schema::new( + stream + .schema() + .fields() + .iter() + .map(|f| f.as_ref().clone().with_nullable(true)) + .collect::<Vec<_>>(), + ))); + let empty_keys = Arc::new( + key_converter.lock().convert_columns( + &key_exprs + .iter() + .map(|key| Ok(key.evaluate(&empty_batch)?.into_array(0)?)) + .collect::<Result<Vec<_>>>()?, + )?, + ); + let empty_batch_schema = empty_batch.schema(); + let null_batch = take_batch_opt(empty_batch, [Option::<usize>::None])?; + let null_nb = NullBuffer::new_null(1); + + Ok(Self { + stream, + key_exprs, + key_converter, + poll_time: Time::new(), + batch_schema: empty_batch_schema, + batches: vec![null_batch], + cur_idx: (0, 0), + min_reserved_idx: (0, 0), + keys: vec![empty_keys], + key_has_nulls: vec![Some(null_nb)], + num_null_batches: 1, + mem_size: 0, + finished: false, + }) + } + + pub fn next(&mut self) -> Option<impl Future<Output = Result<()>> + '_> { + self.cur_idx.1 += 1; + if self.cur_idx.1 >= self.batches[self.cur_idx.0].num_rows() { + self.cur_idx.0 += 1; + self.cur_idx.1 = 0; + } + + let should_load_next_batch = self.cur_idx.0 >= self.batches.len(); + if should_load_next_batch { + Some(async move { + while let Some(batch) = { + let timer = self.poll_time.timer(); + let batch = self.stream.next().await.transpose()?; + drop(timer); + batch + } { + if batch.num_rows() == 0 { + continue; + } + let key_columns = self + .key_exprs + .iter() + .map(|key| Ok(key.evaluate(&batch)?.into_array(batch.num_rows())?)) + .collect::<Result<Vec<_>>>()?; + let key_has_nulls = key_columns + .iter() + .map(|c| c.nulls().cloned()) + .reduce(|lhs, rhs| NullBuffer::union(lhs.as_ref(), rhs.as_ref())) + .unwrap_or(None); + let keys = 
Arc::new(self.key_converter.lock().convert_columns(&key_columns)?); + + self.mem_size += batch.get_array_mem_size(); + self.mem_size += key_has_nulls + .as_ref() + .map(|nb| nb.buffer().len()) + .unwrap_or_default(); + self.mem_size += keys.size(); + + self.batches.push(batch); + self.key_has_nulls.push(key_has_nulls); + self.keys.push(keys); + + // fill out-dated batches with null batches + if self.num_null_batches < self.min_reserved_idx.0 { + for i in self.num_null_batches..self.min_reserved_idx.0 { + self.mem_size -= self.batches[i].get_array_mem_size(); + self.mem_size -= self.key_has_nulls[i] + .as_ref() + .map(|nb| nb.buffer().len()) + .unwrap_or_default(); + self.mem_size -= self.keys[i].size(); + + self.batches[i] = self.batches[0].clone(); + self.keys[i] = self.keys[0].clone(); + self.key_has_nulls[i] = self.key_has_nulls[0].clone(); + self.num_null_batches += 1; + } + } + return Ok(()); + } + self.finished = true; + return Ok(()); + }) + } else { + None + } + } + + #[inline] + pub fn is_null_key(&self, idx: Idx) -> bool { + self.key_has_nulls[idx.0] + .as_ref() + .map(|nb| nb.is_null(idx.1)) + .unwrap_or(false) + } + + #[inline] + pub fn key<'a>(&'a self, idx: Idx) -> Row<'a> { + let keys = &self.keys[idx.0]; + keys.row(idx.1) + } + + #[inline] + pub fn num_buffered_batches(&self) -> usize { + self.batches.len() - self.num_null_batches + } + + #[inline] + pub fn mem_size(&self) -> usize { + self.mem_size + } + + #[inline] + pub fn set_min_reserved_idx(&mut self, idx: Idx) { + self.min_reserved_idx = idx; + } + + #[inline] + pub fn total_poll_time(&self) -> usize { + self.poll_time.value() + } +} + +#[macro_export] +macro_rules! cur_forward { + ($cur:expr) => {{ + if let Some(fut) = $cur.next() { + fut.await?; + } + }}; +}
diff --git a/native-engine/datafusion-ext-plans/src/joins/test.rs b/native-engine/datafusion-ext-plans/src/joins/test.rs new file mode 100644 index 0000000..e0826e7 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/test.rs
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/// Join tests that run every scenario against the sort-merge join and against
/// the broadcast join with either side used as the build side.
///
/// Results are checked with `assert_batches_sorted_eq!`, which compares the
/// formatted rows irrespective of their order, so these tests verify the
/// produced rows but not the output ordering.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{
        self,
        array::*,
        compute::SortOptions,
        datatypes::{DataType, Field, Schema, SchemaRef},
        record_batch::RecordBatch,
    };
    use datafusion::{
        assert_batches_sorted_eq,
        common::JoinSide,
        error::Result,
        physical_expr::expressions::Column,
        physical_plan::{common, joins::utils::*, memory::MemoryExec, ExecutionPlan},
        prelude::SessionContext,
    };
    use TestType::*;

    use crate::{
        broadcast_join_build_hash_map_exec::BroadcastJoinBuildHashMapExec,
        broadcast_join_exec::BroadcastJoinExec,
        joins::join_utils::{JoinType, JoinType::*},
        sort_merge_join_exec::SortMergeJoinExec,
    };

    /// Join implementation exercised by a test iteration.
    #[derive(Clone, Copy)]
    enum TestType {
        /// `SortMergeJoinExec` with default sort options.
        SMJ,
        /// `BroadcastJoinExec` with the hash map built from the right input.
        BHJLeftProbed,
        /// `BroadcastJoinExec` with the hash map built from the left input.
        BHJRightProbed,
    }

    /// Returns the field names of `schema` in order.
    fn columns(schema: &Schema) -> Vec<String> {
        schema.fields().iter().map(|f| f.name().clone()).collect()
    }

    /// Builds a batch of three non-nullable `Int32` columns from
    /// `(name, values)` pairs.
    fn build_table_i32(
        a: (&str, &Vec<i32>),
        b: (&str, &Vec<i32>),
        c: (&str, &Vec<i32>),
    ) -> RecordBatch {
        let schema = Schema::new(vec![
            Field::new(a.0, DataType::Int32, false),
            Field::new(b.0, DataType::Int32, false),
            Field::new(c.0, DataType::Int32, false),
        ]);

        RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(Int32Array::from(a.1.clone())),
                Arc::new(Int32Array::from(b.1.clone())),
                Arc::new(Int32Array::from(c.1.clone())),
            ],
        )
        .unwrap()
    }

    /// Wraps a single three-column `Int32` batch in a one-partition
    /// `MemoryExec`.
    fn build_table(
        a: (&str, &Vec<i32>),
        b: (&str, &Vec<i32>),
        c: (&str, &Vec<i32>),
    ) -> Arc<dyn ExecutionPlan> {
        let batch = build_table_i32(a, b, c);
        let schema = batch.schema();
        Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
    }

    /// Wraps several batches in a one-partition `MemoryExec`; the schema is
    /// taken from the first batch, so `batches` must not be empty.
    fn build_table_from_batches(batches: Vec<RecordBatch>) -> Arc<dyn ExecutionPlan> {
        let schema = batches.first().unwrap().schema();
        Arc::new(MemoryExec::try_new(&[batches], schema, None).unwrap())
    }

    /// Builds a one-partition table of three non-nullable `Date32` columns.
    fn build_date_table(
        a: (&str, &Vec<i32>),
        b: (&str, &Vec<i32>),
        c: (&str, &Vec<i32>),
    ) -> Arc<dyn ExecutionPlan> {
        let schema = Schema::new(vec![
            Field::new(a.0, DataType::Date32, false),
            Field::new(b.0, DataType::Date32, false),
            Field::new(c.0, DataType::Date32, false),
        ]);

        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(Date32Array::from(a.1.clone())),
                Arc::new(Date32Array::from(b.1.clone())),
                Arc::new(Date32Array::from(c.1.clone())),
            ],
        )
        .unwrap();

        let schema = batch.schema();
        Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
    }

    /// Builds a one-partition table of three non-nullable `Date64` columns.
    fn build_date64_table(
        a: (&str, &Vec<i64>),
        b: (&str, &Vec<i64>),
        c: (&str, &Vec<i64>),
    ) -> Arc<dyn ExecutionPlan> {
        let schema = Schema::new(vec![
            Field::new(a.0, DataType::Date64, false),
            Field::new(b.0, DataType::Date64, false),
            Field::new(c.0, DataType::Date64, false),
        ]);

        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(Date64Array::from(a.1.clone())),
                Arc::new(Date64Array::from(b.1.clone())),
                Arc::new(Date64Array::from(c.1.clone())),
            ],
        )
        .unwrap();

        let schema = batch.schema();
        Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
    }

    /// returns a table with 3 columns of i32 in memory
    pub fn build_table_i32_nullable(
        a: (&str, &Vec<Option<i32>>),
        b: (&str, &Vec<Option<i32>>),
        c: (&str, &Vec<Option<i32>>),
    ) -> Arc<dyn ExecutionPlan> {
        let schema = Arc::new(Schema::new(vec![
            Field::new(a.0, DataType::Int32, true),
            Field::new(b.0, DataType::Int32, true),
            Field::new(c.0, DataType::Int32, true),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(a.1.clone())),
                Arc::new(Int32Array::from(b.1.clone())),
                Arc::new(Int32Array::from(c.1.clone())),
            ],
        )
        .unwrap();
        Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
    }

    /// Computes the output schema the join under test is constructed with.
    ///
    /// `Existence` joins output the left fields followed by a non-nullable
    /// boolean `exists#0` column; every other join type is delegated to
    /// datafusion's `build_join_schema`.
    fn build_join_schema_for_test(
        left: &Schema,
        right: &Schema,
        join_type: JoinType,
    ) -> Result<SchemaRef> {
        if join_type == Existence {
            let exists_field = Arc::new(Field::new("exists#0", DataType::Boolean, false));
            return Ok(Arc::new(Schema::new(
                [left.fields().to_vec(), vec![exists_field]].concat(),
            )));
        }
        Ok(Arc::new(
            build_join_schema(left, right, &join_type.try_into()?).0,
        ))
    }

    /// Builds the join selected by `test_type`, executes partition 0 and
    /// returns the output column names together with all output batches.
    async fn join_collect(
        test_type: TestType,
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        on: JoinOn,
        join_type: JoinType,
    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let schema = build_join_schema_for_test(&left.schema(), &right.schema(), join_type)?;

        let join: Arc<dyn ExecutionPlan> = match test_type {
            SMJ => {
                let sort_options = vec![SortOptions::default(); on.len()];
                Arc::new(SortMergeJoinExec::try_new(
                    schema,
                    left,
                    right,
                    on,
                    join_type,
                    sort_options,
                )?)
            }
            BHJLeftProbed => {
                // the right input is wrapped as the hash-map (build) side
                let right = Arc::new(BroadcastJoinBuildHashMapExec::new(
                    right,
                    on.iter().map(|(_, right_key)| right_key.clone()).collect(),
                ));
                Arc::new(BroadcastJoinExec::try_new(
                    schema,
                    left,
                    right,
                    on,
                    join_type,
                    JoinSide::Right,
                    None,
                )?)
            }
            BHJRightProbed => {
                // the left input is wrapped as the hash-map (build) side
                let left = Arc::new(BroadcastJoinBuildHashMapExec::new(
                    left,
                    on.iter().map(|(left_key, _)| left_key.clone()).collect(),
                ));
                Arc::new(BroadcastJoinExec::try_new(
                    schema,
                    left,
                    right,
                    on,
                    join_type,
                    JoinSide::Left,
                    None,
                )?)
            }
        };
        let columns = columns(&join.schema());
        let stream = join.execute(0, task_ctx)?;
        let batches = common::collect(stream).await?;
        Ok((columns, batches))
    }

    #[tokio::test]
    async fn join_inner_one() -> Result<()> {
        for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
            let left = build_table(
                ("a1", &vec![1, 2, 3]),
                ("b1", &vec![4, 5, 5]), // this has a repetition
                ("c1", &vec![7, 8, 9]),
            );
            let right = build_table(
                ("a2", &vec![10, 20, 30]),
                ("b1", &vec![4, 5, 6]),
                ("c2", &vec![70, 80, 90]),
            );

            let on: JoinOn = vec![(
                Arc::new(Column::new_with_schema("b1", &left.schema())?),
                Arc::new(Column::new_with_schema("b1", &right.schema())?),
            )];

            let (_, batches) = join_collect(test_type, left, right, on, Inner).await?;
            let expected = vec![
                "+----+----+----+----+----+----+",
                "| a1 | b1 | c1 | a2 | b1 | c2 |",
                "+----+----+----+----+----+----+",
                "| 1  | 4  | 7  | 10 | 4  | 70 |",
                "| 2  | 5  | 8  | 20 | 5  | 80 |",
                "| 3  | 5  | 9  | 20 | 5  | 80 |",
                "+----+----+----+----+----+----+",
            ];
            assert_batches_sorted_eq!(expected, &batches);
        }
        Ok(())
    }

    #[tokio::test]
    async fn join_inner_two() -> Result<()> {
        for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
            let left = build_table(
                ("a1", &vec![1, 2, 2]),
                ("b2", &vec![1, 2, 2]),
                ("c1", &vec![7, 8, 9]),
            );
            let right = build_table(
                ("a1", &vec![1, 2, 3]),
                ("b2", &vec![1, 2, 2]),
                ("c2", &vec![70, 80, 90]),
            );
            let on: JoinOn = vec![
                (
                    Arc::new(Column::new_with_schema("a1", &left.schema())?),
                    Arc::new(Column::new_with_schema("a1", &right.schema())?),
                ),
                (
                    Arc::new(Column::new_with_schema("b2", &left.schema())?),
                    Arc::new(Column::new_with_schema("b2", &right.schema())?),
                ),
            ];

            let (_columns, batches) = join_collect(test_type, left, right, on, Inner).await?;
            let expected = vec![
                "+----+----+----+----+----+----+",
                "| a1 | b2 | c1 | a1 | b2 | c2 |",
                "+----+----+----+----+----+----+",
                "| 1  | 1  | 7  | 1  | 1  | 70 |",
                "| 2  | 2  | 8  | 2  | 2  | 80 |",
                "| 2  | 2  | 9  | 2  | 2  | 80 |",
                "+----+----+----+----+----+----+",
            ];
            assert_batches_sorted_eq!(expected, &batches);
        }
        Ok(())
    }

    #[tokio::test]
    async fn join_inner_two_two() -> Result<()> {
        for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
            let left = build_table(
                ("a1", &vec![1, 1, 2]),
                ("b2", &vec![1, 1, 2]),
                ("c1", &vec![7, 8, 9]),
            );
            let right = build_table(
                ("a1", &vec![1, 1, 3]),
                ("b2", &vec![1, 1, 2]),
                ("c2", &vec![70, 80, 90]),
            );
            let on: JoinOn = vec![
                (
                    Arc::new(Column::new_with_schema("a1", &left.schema())?),
                    Arc::new(Column::new_with_schema("a1", &right.schema())?),
                ),
                (
                    Arc::new(Column::new_with_schema("b2", &left.schema())?),
                    Arc::new(Column::new_with_schema("b2", &right.schema())?),
                ),
            ];

            let (_columns, batches) = join_collect(test_type, left, right, on, Inner).await?;
            let expected = vec![
                "+----+----+----+----+----+----+",
                "| a1 | b2 | c1 | a1 | b2 | c2 |",
                "+----+----+----+----+----+----+",
                "| 1  | 1  | 7  | 1  | 1  | 70 |",
                "| 1  | 1  | 7  | 1  | 1  | 80 |",
                "| 1  | 1  | 8  | 1  | 1  | 70 |",
                "| 1  | 1  | 8  | 1  | 1  | 80 |",
                "+----+----+----+----+----+----+",
            ];
            assert_batches_sorted_eq!(expected, &batches);
        }
        Ok(())
    }

    #[tokio::test]
    async fn join_inner_with_nulls() -> Result<()> {
        for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
            let left = build_table_i32_nullable(
                ("a1", &vec![Some(1), Some(1), Some(2), Some(2)]),
                ("b2", &vec![None, Some(1), Some(2), Some(2)]), // null in key field
                ("c1", &vec![Some(1), None, Some(8), Some(9)]), // null in non-key field
            );
            let right = build_table_i32_nullable(
                ("a1", &vec![Some(1), Some(1), Some(2), Some(3)]),
                ("b2", &vec![None, Some(1), Some(2), Some(2)]),
                ("c2", &vec![Some(10), Some(70), Some(80), Some(90)]),
            );
            let on: JoinOn = vec![
                (
                    Arc::new(Column::new_with_schema("a1", &left.schema())?),
                    Arc::new(Column::new_with_schema("a1", &right.schema())?),
                ),
                (
                    Arc::new(Column::new_with_schema("b2", &left.schema())?),
                    Arc::new(Column::new_with_schema("b2", &right.schema())?),
                ),
            ];

            let (_, batches) = join_collect(test_type, left, right, on, Inner).await?;
            let expected = vec![
                "+----+----+----+----+----+----+",
                "| a1 | b2 | c1 | a1 | b2 | c2 |",
                "+----+----+----+----+----+----+",
                "| 1  | 1  |    | 1  | 1  | 70 |",
                "| 2  | 2  | 8  | 2  | 2  | 80 |",
                "| 2  | 2  | 9  | 2  | 2  | 80 |",
                "+----+----+----+----+----+----+",
            ];
            assert_batches_sorted_eq!(expected, &batches);
        }
        Ok(())
    }

    #[tokio::test]
    async fn join_left_one() -> Result<()> {
        for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
            let left = build_table(
                ("a1", &vec![1, 2, 3]),
                ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
                ("c1", &vec![7, 8, 9]),
            );
            let right = build_table(
                ("a2", &vec![10, 20, 30]),
                ("b1", &vec![4, 5, 6]),
                ("c2", &vec![70, 80, 90]),
            );
            let on: JoinOn = vec![(
                Arc::new(Column::new_with_schema("b1", &left.schema())?),
                Arc::new(Column::new_with_schema("b1", &right.schema())?),
            )];

            let (_, batches) = join_collect(test_type, left, right, on, Left).await?;
            let expected = vec![
                "+----+----+----+----+----+----+",
                "| a1 | b1 | c1 | a2 | b1 | c2 |",
                "+----+----+----+----+----+----+",
                "| 1  | 4  | 7  | 10 | 4  | 70 |",
                "| 2  | 5  | 8  | 20 | 5  | 80 |",
                "| 3  | 7  | 9  |    |    |    |",
                "+----+----+----+----+----+----+",
            ];
            assert_batches_sorted_eq!(expected, &batches);
        }
        Ok(())
    }

    #[tokio::test]
    async fn join_right_one() -> Result<()> {
        for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
            let left = build_table(
                ("a1", &vec![1, 2, 3]),
                ("b1", &vec![4, 5, 7]),
                ("c1", &vec![7, 8, 9]),
            );
            let right = build_table(
                ("a2", &vec![10, 20, 30]),
                ("b1", &vec![4, 5, 6]), // 6 does not exist on the left
                ("c2", &vec![70, 80, 90]),
            );
            let on: JoinOn = vec![(
                Arc::new(Column::new_with_schema("b1", &left.schema())?),
                Arc::new(Column::new_with_schema("b1", &right.schema())?),
            )];

            let (_, batches) = join_collect(test_type, left, right, on, Right).await?;
            let expected = vec![
                "+----+----+----+----+----+----+",
                "| a1 | b1 | c1 | a2 | b1 | c2 |",
                "+----+----+----+----+----+----+",
                "| 1  | 4  | 7  | 10 | 4  | 70 |",
                "| 2  | 5  | 8  | 20 | 5  | 80 |",
                "|    |    |    | 30 | 6  | 90 |",
                "+----+----+----+----+----+----+",
            ];
            assert_batches_sorted_eq!(expected, &batches);
        }
        Ok(())
    }

    #[tokio::test]
    async fn join_full_one() -> Result<()> {
        for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
            let left = build_table(
                ("a1", &vec![1, 2, 3]),
                ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
                ("c1", &vec![7, 8, 9]),
            );
            let right = build_table(
                ("a2", &vec![10, 20, 30]),
                ("b2", &vec![4, 5, 6]),
                ("c2", &vec![70, 80, 90]),
            );
            let on: JoinOn = vec![(
                Arc::new(Column::new_with_schema("b1", &left.schema())?),
                Arc::new(Column::new_with_schema("b2", &right.schema())?),
            )];

            let (_, batches) = join_collect(test_type, left, right, on, Full).await?;
            let expected = vec![
                "+----+----+----+----+----+----+",
                "| a1 | b1 | c1 | a2 | b2 | c2 |",
                "+----+----+----+----+----+----+",
                "|    |    |    | 30 | 6  | 90 |",
                "| 1  | 4  | 7  | 10 | 4  | 70 |",
                "| 2  | 5  | 8  | 20 | 5  | 80 |",
                "| 3  | 7  | 9  |    |    |    |",
                "+----+----+----+----+----+----+",
            ];
            assert_batches_sorted_eq!(expected, &batches);
        }
        Ok(())
    }

    #[tokio::test]
    async fn join_anti() -> Result<()> {
        for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
            let left = build_table(
                ("a1", &vec![1, 2, 2, 3, 5]),
                ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right
                ("c1", &vec![7, 8, 8, 9, 11]),
            );
            let right = build_table(
                ("a2", &vec![10, 20, 30]),
                ("b1", &vec![4, 5, 6]),
                ("c2", &vec![70, 80, 90]),
            );
            let on: JoinOn = vec![(
                Arc::new(Column::new_with_schema("b1", &left.schema())?),
                Arc::new(Column::new_with_schema("b1", &right.schema())?),
            )];

            let (_, batches) = join_collect(test_type, left, right, on, LeftAnti).await?;
            let expected = vec![
                "+----+----+----+",
                "| a1 | b1 | c1 |",
                "+----+----+----+",
                "| 3  | 7  | 9  |",
                "| 5  | 7  | 11 |",
                "+----+----+----+",
            ];
            assert_batches_sorted_eq!(expected, &batches);
        }
        Ok(())
    }

    #[tokio::test]
    async fn join_semi() -> Result<()> {
        for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
            let left = build_table(
                ("a1", &vec![1, 2, 2, 3]),
                ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right
                ("c1", &vec![7, 8, 8, 9]),
            );
            let right = build_table(
                ("a2", &vec![10, 20, 30]),
                ("b1", &vec![4, 5, 6]), // 6 does not exist on the left
                ("c2", &vec![70, 80, 90]),
            );
            let on: JoinOn = vec![(
                Arc::new(Column::new_with_schema("b1", &left.schema())?),
                Arc::new(Column::new_with_schema("b1", &right.schema())?),
            )];

            let (_, batches) = join_collect(test_type, left, right, on, LeftSemi).await?;
            let expected = vec![
                "+----+----+----+",
                "| a1 | b1 | c1 |",
                "+----+----+----+",
                "| 1  | 4  | 7  |",
                "| 2  | 5  | 8  |",
                "| 2  | 5  | 8  |",
                "+----+----+----+",
            ];
            assert_batches_sorted_eq!(expected, &batches);
        }
        Ok(())
    }

    #[tokio::test]
    async fn join_with_duplicated_column_names() -> Result<()> {
        for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
            let left = build_table(
                ("a", &vec![1, 2, 3]),
                ("b", &vec![4, 5, 7]),
                ("c", &vec![7, 8, 9]),
            );
            let right = build_table(
                ("a", &vec![10, 20, 30]),
                ("b", &vec![1, 2, 7]),
                ("c", &vec![70, 80, 90]),
            );
            let on: JoinOn = vec![(
                // join on a=b so there are duplicate column names on unjoined columns
                Arc::new(Column::new_with_schema("a", &left.schema())?),
                Arc::new(Column::new_with_schema("b", &right.schema())?),
            )];

            let (_, batches) = join_collect(test_type, left, right, on, Inner).await?;
            let expected = vec![
                "+---+---+---+----+---+----+",
                "| a | b | c | a  | b | c  |",
                "+---+---+---+----+---+----+",
                "| 1 | 4 | 7 | 10 | 1 | 70 |",
                "| 2 | 5 | 8 | 20 | 2 | 80 |",
                "+---+---+---+----+---+----+",
            ];
            assert_batches_sorted_eq!(expected, &batches);
        }
        Ok(())
    }

    #[tokio::test]
    async fn join_date32() -> Result<()> {
        for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
            let left = build_date_table(
                ("a1", &vec![1, 2, 3]),
                ("b1", &vec![19107, 19108, 19108]), // this has a repetition
                ("c1", &vec![7, 8, 9]),
            );
            let right = build_date_table(
                ("a2", &vec![10, 20, 30]),
                ("b1", &vec![19107, 19108, 19109]),
                ("c2", &vec![70, 80, 90]),
            );

            let on: JoinOn = vec![(
                Arc::new(Column::new_with_schema("b1", &left.schema())?),
                Arc::new(Column::new_with_schema("b1", &right.schema())?),
            )];

            let (_, batches) = join_collect(test_type, left, right, on, Inner).await?;

            let expected = vec![
                "+------------+------------+------------+------------+------------+------------+",
                "| a1         | b1         | c1         | a2         | b1         | c2         |",
                "+------------+------------+------------+------------+------------+------------+",
                "| 1970-01-02 | 2022-04-25 | 1970-01-08 | 1970-01-11 | 2022-04-25 | 1970-03-12 |",
                "| 1970-01-03 | 2022-04-26 | 1970-01-09 | 1970-01-21 | 2022-04-26 | 1970-03-22 |",
                "| 1970-01-04 | 2022-04-26 | 1970-01-10 | 1970-01-21 | 2022-04-26 | 1970-03-22 |",
                "+------------+------------+------------+------------+------------+------------+",
            ];
            assert_batches_sorted_eq!(expected, &batches);
        }
        Ok(())
    }

    #[tokio::test]
    async fn join_date64() -> Result<()> {
        for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
            let left = build_date64_table(
                ("a1", &vec![1, 2, 3]),
                // this has a repetition
                ("b1", &vec![1650703441000, 1650903441000, 1650903441000]),
                ("c1", &vec![7, 8, 9]),
            );
            // NOTE(review): the right-side keys are not in ascending order
            // although the SMJ variant is built with default (ascending) sort
            // options; the expected rows below are unaffected because the
            // out-of-order key has no match on the left -- confirm whether a
            // sorted input was intended.
            let right = build_date64_table(
                ("a2", &vec![10, 20, 30]),
                ("b1", &vec![1650703441000, 1650503441000, 1650903441000]),
                ("c2", &vec![70, 80, 90]),
            );

            let on: JoinOn = vec![(
                Arc::new(Column::new_with_schema("b1", &left.schema())?),
                Arc::new(Column::new_with_schema("b1", &right.schema())?),
            )];

            let (_, batches) = join_collect(test_type, left, right, on, Inner).await?;
            let expected = vec![
                "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+",
                "| a1                      | b1                  | c1                      | a2                      | b1                  | c2                      |",
                "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+",
                "| 1970-01-01T00:00:00.001 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.007 | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.070 |",
                "| 1970-01-01T00:00:00.002 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |",
                "| 1970-01-01T00:00:00.003 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |",
                "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+",
            ];
            assert_batches_sorted_eq!(expected, &batches);
        }
        Ok(())
    }

    #[tokio::test]
    async fn join_left_sort_order() -> Result<()> {
        for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
            let left = build_table(
                ("a1", &vec![0, 1, 2, 3, 4, 5]),
                ("b1", &vec![3, 4, 5, 6, 6, 7]),
                ("c1", &vec![4, 5, 6, 7, 8, 9]),
            );
            let right = build_table(
                ("a2", &vec![0, 10, 20, 30, 40]),
                ("b2", &vec![2, 4, 6, 6, 8]),
                ("c2", &vec![50, 60, 70, 80, 90]),
            );
            let on: JoinOn = vec![(
                Arc::new(Column::new_with_schema("b1", &left.schema())?),
                Arc::new(Column::new_with_schema("b2", &right.schema())?),
            )];

            let (_, batches) = join_collect(test_type, left, right, on, Left).await?;
            let expected = vec![
                "+----+----+----+----+----+----+",
                "| a1 | b1 | c1 | a2 | b2 | c2 |",
                "+----+----+----+----+----+----+",
                "| 0  | 3  | 4  |    |    |    |",
                "| 1  | 4  | 5  | 10 | 4  | 60 |",
                "| 2  | 5  | 6  |    |    |    |",
                "| 3  | 6  | 7  | 20 | 6  | 70 |",
                "| 3  | 6  | 7  | 30 | 6  | 80 |",
                "| 4  | 6  | 8  | 20 | 6  | 70 |",
                "| 4  | 6  | 8  | 30 | 6  | 80 |",
                "| 5  | 7  | 9  |    |    |    |",
                "+----+----+----+----+----+----+",
            ];
            assert_batches_sorted_eq!(expected, &batches);
        }
        Ok(())
    }

    #[tokio::test]
    async fn join_right_sort_order() -> Result<()> {
        for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
            let left = build_table(
                ("a1", &vec![0, 1, 2, 3]),
                ("b1", &vec![3, 4, 5, 7]),
                ("c1", &vec![6, 7, 8, 9]),
            );
            let right = build_table(
                ("a2", &vec![0, 10, 20, 30]),
                ("b2", &vec![2, 4, 5, 6]),
                ("c2", &vec![60, 70, 80, 90]),
            );
            let on: JoinOn = vec![(
                Arc::new(Column::new_with_schema("b1", &left.schema())?),
                Arc::new(Column::new_with_schema("b2", &right.schema())?),
            )];

            let (_, batches) = join_collect(test_type, left, right, on, Right).await?;
            let expected = vec![
                "+----+----+----+----+----+----+",
                "| a1 | b1 | c1 | a2 | b2 | c2 |",
                "+----+----+----+----+----+----+",
                "|    |    |    | 0  | 2  | 60 |",
                "| 1  | 4  | 7  | 10 | 4  | 70 |",
                "| 2  | 5  | 8  | 20 | 5  | 80 |",
                "|    |    |    | 30 | 6  | 90 |",
                "+----+----+----+----+----+----+",
            ];
            assert_batches_sorted_eq!(expected, &batches);
        }
        Ok(())
    }

    #[tokio::test]
    async fn join_left_multiple_batches() -> Result<()> {
        for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
            let left_batch_1 = build_table_i32(
                ("a1", &vec![0, 1, 2]),
                ("b1", &vec![3, 4, 5]),
                ("c1", &vec![4, 5, 6]),
            );
            let left_batch_2 = build_table_i32(
                ("a1", &vec![3, 4, 5, 6]),
                ("b1", &vec![6, 6, 7, 9]),
                ("c1", &vec![7, 8, 9, 9]),
            );
            let right_batch_1 = build_table_i32(
                ("a2", &vec![0, 10, 20]),
                ("b2", &vec![2, 4, 6]),
                ("c2", &vec![50, 60, 70]),
            );
            let right_batch_2 = build_table_i32(
                ("a2", &vec![30, 40]),
                ("b2", &vec![6, 8]),
                ("c2", &vec![80, 90]),
            );
            let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
            let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
            let on: JoinOn = vec![(
                Arc::new(Column::new_with_schema("b1", &left.schema())?),
                Arc::new(Column::new_with_schema("b2", &right.schema())?),
            )];

            let (_, batches) = join_collect(test_type, left, right, on, Left).await?;
            let expected = vec![
                "+----+----+----+----+----+----+",
                "| a1 | b1 | c1 | a2 | b2 | c2 |",
                "+----+----+----+----+----+----+",
                "| 0  | 3  | 4  |    |    |    |",
                "| 1  | 4  | 5  | 10 | 4  | 60 |",
                "| 2  | 5  | 6  |    |    |    |",
                "| 3  | 6  | 7  | 20 | 6  | 70 |",
                "| 3  | 6  | 7  | 30 | 6  | 80 |",
                "| 4  | 6  | 8  | 20 | 6  | 70 |",
                "| 4  | 6  | 8  | 30 | 6  | 80 |",
                "| 5  | 7  | 9  |    |    |    |",
                "| 6  | 9  | 9  |    |    |    |",
                "+----+----+----+----+----+----+",
            ];
            assert_batches_sorted_eq!(expected, &batches);
        }
        Ok(())
    }

    #[tokio::test]
    async fn join_right_multiple_batches() -> Result<()> {
        for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
            let right_batch_1 = build_table_i32(
                ("a2", &vec![0, 1, 2]),
                ("b2", &vec![3, 4, 5]),
                ("c2", &vec![4, 5, 6]),
            );
            let right_batch_2 = build_table_i32(
                ("a2", &vec![3, 4, 5, 6]),
                ("b2", &vec![6, 6, 7, 9]),
                ("c2", &vec![7, 8, 9, 9]),
            );
            let left_batch_1 = build_table_i32(
                ("a1", &vec![0, 10, 20]),
                ("b1", &vec![2, 4, 6]),
                ("c1", &vec![50, 60, 70]),
            );
            let left_batch_2 = build_table_i32(
                ("a1", &vec![30, 40]),
                ("b1", &vec![6, 8]),
                ("c1", &vec![80, 90]),
            );
            let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
            let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
            let on: JoinOn = vec![(
                Arc::new(Column::new_with_schema("b1", &left.schema())?),
                Arc::new(Column::new_with_schema("b2", &right.schema())?),
            )];

            let (_, batches) = join_collect(test_type, left, right, on, Right).await?;
            let expected = vec![
                "+----+----+----+----+----+----+",
                "| a1 | b1 | c1 | a2 | b2 | c2 |",
                "+----+----+----+----+----+----+",
                "|    |    |    | 0  | 3  | 4  |",
                "| 10 | 4  | 60 | 1  | 4  | 5  |",
                "|    |    |    | 2  | 5  | 6  |",
                "| 20 | 6  | 70 | 3  | 6  | 7  |",
                "| 30 | 6  | 80 | 3  | 6  | 7  |",
                "| 20 | 6  | 70 | 4  | 6  | 8  |",
                "| 30 | 6  | 80 | 4  | 6  | 8  |",
                "|    |    |    | 5  | 7  | 9  |",
                "|    |    |    | 6  | 9  | 9  |",
                "+----+----+----+----+----+----+",
            ];
            assert_batches_sorted_eq!(expected, &batches);
        }
        Ok(())
    }

    #[tokio::test]
    async fn join_full_multiple_batches() -> Result<()> {
        for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
            let left_batch_1 = build_table_i32(
                ("a1", &vec![0, 1, 2]),
                ("b1", &vec![3, 4, 5]),
                ("c1", &vec![4, 5, 6]),
            );
            let left_batch_2 = build_table_i32(
                ("a1", &vec![3, 4, 5, 6]),
                ("b1", &vec![6, 6, 7, 9]),
                ("c1", &vec![7, 8, 9, 9]),
            );
            let right_batch_1 = build_table_i32(
                ("a2", &vec![0, 10, 20]),
                ("b2", &vec![2, 4, 6]),
                ("c2", &vec![50, 60, 70]),
            );
            let right_batch_2 = build_table_i32(
                ("a2", &vec![30, 40]),
                ("b2", &vec![6, 8]),
                ("c2", &vec![80, 90]),
            );
            let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
            let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
            let on: JoinOn = vec![(
                Arc::new(Column::new_with_schema("b1", &left.schema())?),
                Arc::new(Column::new_with_schema("b2", &right.schema())?),
            )];

            let (_, batches) = join_collect(test_type, left, right, on, Full).await?;
            let expected = vec![
                "+----+----+----+----+----+----+",
                "| a1 | b1 | c1 | a2 | b2 | c2 |",
                "+----+----+----+----+----+----+",
                "|    |    |    | 0  | 2  | 50 |",
                "|    |    |    | 40 | 8  | 90 |",
                "| 0  | 3  | 4  |    |    |    |",
                "| 1  | 4  | 5  | 10 | 4  | 60 |",
                "| 2  | 5  | 6  |    |    |    |",
                "| 3  | 6  | 7  | 20 | 6  | 70 |",
                "| 3  | 6  | 7  | 30 | 6  | 80 |",
                "| 4  | 6  | 8  | 20 | 6  | 70 |",
                "| 4  | 6  | 8  | 30 | 6  | 80 |",
                "| 5  | 7  | 9  |    |    |    |",
                "| 6  | 9  | 9  |    |    |    |",
                "+----+----+----+----+----+----+",
            ];
            assert_batches_sorted_eq!(expected, &batches);
        }
        Ok(())
    }

    #[tokio::test]
    async fn join_existence_multiple_batches() -> Result<()> {
        for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] {
            let left_batch_1 = build_table_i32(
                ("a1", &vec![0, 1, 2]),
                ("b1", &vec![3, 4, 5]),
                ("c1", &vec![4, 5, 6]),
            );
            let left_batch_2 = build_table_i32(
                ("a1", &vec![3, 4, 5, 6]),
                ("b1", &vec![6, 6, 7, 9]),
                ("c1", &vec![7, 8, 9, 9]),
            );
            let right_batch_1 = build_table_i32(
                ("a2", &vec![0, 10, 20]),
                ("b2", &vec![2, 4, 6]),
                ("c2", &vec![50, 60, 70]),
            );
            let right_batch_2 = build_table_i32(
                ("a2", &vec![30, 40]),
                ("b2", &vec![6, 8]),
                ("c2", &vec![80, 90]),
            );
            let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
            let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
            let on: JoinOn = vec![(
                Arc::new(Column::new_with_schema("b1", &left.schema())?),
                Arc::new(Column::new_with_schema("b2", &right.schema())?),
            )];

            let (_, batches) = join_collect(test_type, left, right, on, Existence).await?;
            let expected = vec![
                "+----+----+----+----------+",
                "| a1 | b1 | c1 | exists#0 |",
                "+----+----+----+----------+",
                "| 0  | 3  | 4  | false    |",
                "| 1  | 4  | 5  | true     |",
                "| 2  | 5  | 6  | false    |",
                "| 3  | 6  | 7  | true     |",
                "| 4  | 6  | 8  | true     |",
                "| 5  | 7  | 9  | false    |",
                "| 6  | 9  | 9  | false    |",
                "+----+----+----+----------+",
            ];
            assert_batches_sorted_eq!(expected, &batches);
        }
        Ok(())
    }
}
diff --git a/native-engine/datafusion-ext-plans/src/lib.rs b/native-engine/datafusion-ext-plans/src/lib.rs index a0797fb..744cbb4 100644 --- a/native-engine/datafusion-ext-plans/src/lib.rs +++ b/native-engine/datafusion-ext-plans/src/lib.rs
@@ -14,31 +14,39 @@ #![feature(get_mut_unchecked)] #![feature(io_error_other)] +#![feature(adt_const_params)] -pub mod agg; +// execution plan implementations pub mod agg_exec; +pub mod broadcast_join_build_hash_map_exec; pub mod broadcast_join_exec; pub mod broadcast_nested_loop_join_exec; -pub mod common; pub mod debug_exec; pub mod empty_partitions_exec; pub mod expand_exec; pub mod ffi_reader_exec; pub mod filter_exec; -pub mod generate; pub mod generate_exec; pub mod ipc_reader_exec; pub mod ipc_writer_exec; pub mod limit_exec; -pub mod memmgr; pub mod parquet_exec; pub mod parquet_sink_exec; pub mod project_exec; pub mod rename_columns_exec; pub mod rss_shuffle_writer_exec; -mod shuffle; pub mod shuffle_writer_exec; pub mod sort_exec; pub mod sort_merge_join_exec; -pub mod window; pub mod window_exec; + +// memory management +pub mod memmgr; + +// helper modules +pub mod agg; +pub mod common; +pub mod generate; +pub mod joins; +mod shuffle; +pub mod window;
diff --git a/native-engine/datafusion-ext-plans/src/parquet_exec.rs b/native-engine/datafusion-ext-plans/src/parquet_exec.rs index 8fd5f57..0341e31 100644 --- a/native-engine/datafusion-ext-plans/src/parquet_exec.rs +++ b/native-engine/datafusion-ext-plans/src/parquet_exec.rs
@@ -20,7 +20,7 @@ use std::{any::Any, fmt, fmt::Formatter, ops::Range, sync::Arc}; use arrow::{ - array::ArrayRef, + array::{Array, ArrayRef, AsArray, ListArray}, datatypes::{DataType, SchemaRef}, }; use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine}; @@ -71,7 +71,61 @@ col: &ArrayRef, data_type: &DataType, ) -> Result<ArrayRef, DataFusionError> { - datafusion_ext_commons::cast::cast_scan_input_array(col.as_ref(), data_type) + macro_rules! handle_decimal { + ($s:ident, $t:ident, $tnative:ty, $prec:expr, $scale:expr) => {{ + use arrow::{array::*, datatypes::*}; + type DecimalBuilder = paste::paste! {[<$t Builder>]}; + type IntType = paste::paste! {[<$s Type>]}; + + let col = col.as_primitive::<IntType>(); + let mut decimal_builder = DecimalBuilder::new(); + for i in 0..col.len() { + if col.is_valid(i) { + decimal_builder.append_value(col.value(i) as $tnative); + } else { + decimal_builder.append_null(); + } + } + Ok(Arc::new( + decimal_builder + .finish() + .with_precision_and_scale($prec, $scale)?, + )) + }}; + } + match data_type { + DataType::Decimal128(prec, scale) => match col.data_type() { + DataType::Int8 => handle_decimal!(Int8, Decimal128, i128, *prec, *scale), + DataType::Int16 => handle_decimal!(Int16, Decimal128, i128, *prec, *scale), + DataType::Int32 => handle_decimal!(Int32, Decimal128, i128, *prec, *scale), + DataType::Int64 => handle_decimal!(Int64, Decimal128, i128, *prec, *scale), + DataType::Decimal128(p, s) if p == prec && s == scale => Ok(col.clone()), + _ => df_execution_err!( + "schema_adapter_cast_column unsupported type: {:?} => {:?}", + col.data_type(), + data_type, + ), + }, + DataType::List(to_field) => match col.data_type() { + DataType::List(_from_field) => { + let col = col.as_list::<i32>(); + let from_inner = col.values(); + let to_inner = schema_adapter_cast_column(from_inner, to_field.data_type())?; + Ok(Arc::new(ListArray::try_new( + to_field.clone(), + col.offsets().clone(), + to_inner, + col.nulls().cloned(), + )?)) + } + _ 
=> df_execution_err!( + "schema_adapter_cast_column unsupported type: {:?} => {:?}", + col.data_type(), + data_type, + ), + }, + _ => datafusion_ext_commons::cast::cast_scan_input_array(col.as_ref(), data_type), + } } /// Execution plan for scanning one or more Parquet partitions @@ -231,6 +285,9 @@ None => (0..self.base_config.file_schema.fields().len()).collect(), }; + let page_filtering_enabled = conf::PARQUET_ENABLE_PAGE_FILTERING.value()?; + let bloom_filter_enabled = conf::PARQUET_ENABLE_BLOOM_FILTER.value()?; + let opener = ParquetOpener { partition_index, projection: Arc::from(projection), @@ -243,10 +300,10 @@ metadata_size_hint: None, metrics: self.metrics.clone(), parquet_file_reader_factory: Arc::new(FsReaderFactory::new(fs_provider)), - pushdown_filters: false, - reorder_filters: false, - enable_page_index: false, - enable_bloom_filter: false, + pushdown_filters: page_filtering_enabled, + reorder_filters: page_filtering_enabled, + enable_page_index: page_filtering_enabled, + enable_bloom_filter: bloom_filter_enabled, }; let baseline_metrics_cloned = baseline_metrics.clone();
diff --git a/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs b/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs index f2dff1d..69b46cf 100644 --- a/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs +++ b/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs
@@ -35,7 +35,6 @@ SendableRecordBatchStream, Statistics, }, }; -use datafusion_ext_commons::df_execution_err; use futures::{Stream, StreamExt}; use crate::agg::AGG_BUF_COLUMN_NAME; @@ -56,7 +55,12 @@ let input_schema = input.schema(); let mut new_names = vec![]; - for (i, field) in input_schema.fields().iter().enumerate() { + for (i, field) in input_schema + .fields() + .iter() + .take(renamed_column_names.len()) + .enumerate() + { if field.name() != AGG_BUF_COLUMN_NAME { new_names.push(renamed_column_names[i].clone()); } else { @@ -64,11 +68,9 @@ break; } } - if new_names.len() != input_schema.fields().len() { - df_execution_err!( - "renamed_column_names length not matched with input schema, \ - renames: {renamed_column_names:?}, input schema: {input_schema}", - )?; + + while new_names.len() < input_schema.fields().len() { + new_names.push(input_schema.field(new_names.len()).name().clone()); } let renamed_column_names = new_names; let renamed_schema = Arc::new(Schema::new(
diff --git a/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs b/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs index 8459d47..127e0db 100644 --- a/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs +++ b/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs
@@ -12,135 +12,118 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{any::Any, cmp::Ordering, fmt::Formatter, sync::Arc}; - -use arrow::{ - array::*, - buffer::NullBuffer, - compute::{prep_null_mask_filter, SortOptions}, - datatypes::{DataType, Schema, SchemaRef}, - record_batch::{RecordBatch, RecordBatchOptions}, - row::{Row, RowConverter, Rows, SortField}, +use std::{ + any::Any, + fmt::Formatter, + pin::Pin, + sync::Arc, + time::{Duration, Instant}, }; + +use arrow::{compute::SortOptions, datatypes::SchemaRef}; +use async_trait::async_trait; use datafusion::{ - common::JoinSide, + common::{DataFusionError, JoinSide}, error::Result, execution::context::TaskContext, - logical_expr::{JoinType, JoinType::*}, - physical_expr::{expressions::Column, PhysicalSortExpr}, + physical_expr::{PhysicalExprRef, PhysicalSortExpr}, physical_plan::{ - joins::utils::{build_join_schema, check_join_is_valid, ColumnIndex, JoinFilter, JoinOn}, - metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, ScopedTimerGuard}, + joins::utils::JoinOn, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }, }; use datafusion_ext_commons::{ - array_size::ArraySize, batch_size, df_execution_err, downcast_any, - streams::coalesce_stream::CoalesceInput, suggested_output_batch_mem_size, + batch_size, df_execution_err, streams::coalesce_stream::CoalesceInput, }; -use futures::{StreamExt, TryStreamExt}; -use parking_lot::Mutex as SyncMutex; +use futures::TryStreamExt; -use crate::common::{ - batch_selection::{interleave_batches, take_batch_opt}, - column_pruning::ExecuteWithColumnPruning, - output::{TaskOutputter, WrappedRecordBatchSender}, +use crate::{ + common::output::{TaskOutputter, WrappedRecordBatchSender}, + cur_forward, + joins::{ + join_utils::{JoinType, JoinType::*}, + 
smj::{ + existence_join::ExistenceJoiner, + full_join::{FullOuterJoiner, InnerJoiner, LeftOuterJoiner, RightOuterJoiner}, + semi_join::{LeftAntiJoiner, LeftSemiJoiner, RightAntiJoiner, RightSemiJoiner}, + }, + stream_cursor::StreamCursor, + JoinParams, StreamCursors, + }, }; #[derive(Debug)] pub struct SortMergeJoinExec { - /// Left sorted joining execution plan left: Arc<dyn ExecutionPlan>, - /// Right sorting joining execution plan right: Arc<dyn ExecutionPlan>, - /// Set of common columns used to join on on: JoinOn, - /// How the join is performed join_type: JoinType, - /// Optional filter before outputting - join_filter: Option<JoinFilter>, - /// The schema once the join is applied - schema: SchemaRef, - /// Execution metrics - metrics: ExecutionPlanMetricsSet, - /// Sort options of join columns used in sorting left and right execution - /// plans sort_options: Vec<SortOptions>, + schema: SchemaRef, + metrics: ExecutionPlanMetricsSet, } impl SortMergeJoinExec { pub fn try_new( + schema: SchemaRef, left: Arc<dyn ExecutionPlan>, right: Arc<dyn ExecutionPlan>, on: JoinOn, join_type: JoinType, - join_filter: Option<JoinFilter>, sort_options: Vec<SortOptions>, ) -> Result<Self> { - let left_schema = left.schema(); - let right_schema = right.schema(); - - if matches!(join_type, LeftSemi | LeftAnti | RightSemi | RightAnti,) { - if join_filter.is_some() { - df_execution_err!("Semi/Anti join with filter is not supported yet")?; - } - } - - check_join_is_valid(&left_schema, &right_schema, &on)?; - if sort_options.len() != on.len() { - df_execution_err!( - "Expected number of sort options: {}, actual: {}", - on.len(), - sort_options.len(), - )?; - } - - let schema = Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0); Ok(Self { + schema, left, right, on, join_type, - join_filter, - schema, - metrics: ExecutionPlanMetricsSet::new(), sort_options, + metrics: ExecutionPlanMetricsSet::new(), }) } - fn create_join_params(&self, batch_size: usize) -> 
JoinParams { - let on_left: Vec<usize> = self + fn create_join_params(&self) -> Result<JoinParams> { + let left_schema = self.left.schema(); + let right_schema = self.right.schema(); + let (left_keys, right_keys): (Vec<PhysicalExprRef>, Vec<PhysicalExprRef>) = + self.on.iter().cloned().unzip(); + let key_data_types = self .on .iter() - .map(|on| downcast_any!(on.0, Column).unwrap().index()) - .collect(); - let on_right: Vec<usize> = self - .on - .iter() - .map(|on| downcast_any!(on.1, Column).unwrap().index()) - .collect(); - let on_data_types = on_left - .iter() - .map(|&i| self.left.schema().field(i).data_type().clone()) - .collect::<Vec<_>>(); + .map(|(left_key, right_key)| { + Ok({ + let left_dt = left_key.data_type(&left_schema)?; + let right_dt = right_key.data_type(&right_schema)?; + if left_dt != right_dt { + df_execution_err!( + "join key data type differs {left_dt:?} <-> {right_dt:?}" + )?; + } + left_dt + }) + }) + .collect::<Result<_>>()?; + + let batch_size = batch_size(); let sub_batch_size = batch_size / batch_size.ilog10() as usize; // use smaller batch size and coalesce batches at the end, to avoid buffer // overflowing - JoinParams { + Ok(JoinParams { join_type: self.join_type, + left_schema, + right_schema, output_schema: self.schema(), - on_left, - on_right, - on_data_types, - join_filter: self.join_filter.clone(), + left_keys, + right_keys, + key_data_types, sort_options: self.sort_options.clone(), batch_size: sub_batch_size, - left_output_projection: (0..self.left.schema().fields().len()).collect(), - right_output_projection: (0..self.right.schema().fields().len()).collect(), - } + }) } } @@ -169,7 +152,7 @@ fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { match self.join_type { - Left | LeftSemi | LeftAnti => self.left.output_ordering(), + Left | LeftSemi | LeftAnti | Existence => self.left.output_ordering(), Right | RightSemi | RightAnti => self.right.output_ordering(), Inner => self.left.output_ordering(), Full => None, @@ -185,11 
+168,11 @@ children: Vec<Arc<dyn ExecutionPlan>>, ) -> Result<Arc<dyn ExecutionPlan>> { Ok(Arc::new(SortMergeJoinExec::try_new( + self.schema(), children[0].clone(), children[1].clone(), self.on.clone(), self.join_type, - self.join_filter.clone(), self.sort_options.clone(), )?)) } @@ -200,11 +183,23 @@ context: Arc<TaskContext>, ) -> Result<SendableRecordBatchStream> { let metrics = Arc::new(BaselineMetrics::new(&self.metrics, partition)); - let batch_size = batch_size(); - let join_params = self.create_join_params(batch_size); + let join_params = self.create_join_params()?; let left = self.left.execute(partition, context.clone())?; let right = self.right.execute(partition, context.clone())?; - execute_with_join_params(context, join_params, left, right, metrics) + let output_schema = self.schema(); + + let metrics_cloned = metrics.clone(); + let context_cloned = context.clone(); + let output_stream = Box::pin(RecordBatchStreamAdapter::new( + output_schema.clone(), + futures::stream::once(async move { + context_cloned.output_with_sender("SortMergeJoin", output_schema, move |sender| { + execute_join(left, right, join_params, metrics_cloned, sender) + }) + }) + .try_flatten(), + )); + Ok(context.coalesce_with_default_batch_size(output_stream, &metrics)?) 
} fn metrics(&self) -> Option<MetricsSet> { @@ -216,1549 +211,66 @@ } } -impl ExecuteWithColumnPruning for SortMergeJoinExec { - fn execute_projected( - &self, - partition: usize, - context: Arc<TaskContext>, - projection: &[usize], - ) -> Result<SendableRecordBatchStream> { - let metrics = Arc::new(BaselineMetrics::new(&self.metrics, partition)); - let batch_size = batch_size(); - - let (join_params, left_projection, right_projection) = - self.create_join_params(batch_size).project(projection)?; - let left = self - .left - .execute_projected(partition, context.clone(), &left_projection)?; - let right = self - .right - .execute_projected(partition, context.clone(), &right_projection)?; - execute_with_join_params(context, join_params, left, right, metrics) - } -} - -#[derive(Clone)] -struct JoinParams { - join_type: JoinType, - output_schema: SchemaRef, - on_left: Vec<usize>, - on_right: Vec<usize>, - on_data_types: Vec<DataType>, - sort_options: Vec<SortOptions>, - join_filter: Option<JoinFilter>, - left_output_projection: Vec<usize>, - right_output_projection: Vec<usize>, - batch_size: usize, -} - -impl JoinParams { - fn project(&self, projection: &[usize]) -> Result<(Self, Vec<usize>, Vec<usize>)> { - let num_left_fields = self.left_output_projection.len(); - let mut left_projection = vec![]; - let mut right_projection = vec![]; - - for &i in projection { - match self.join_type { - Inner | Left | Right | Full => { - if i < num_left_fields { - left_projection.push(i); - } else { - right_projection.push(i - num_left_fields); - } - } - LeftSemi | LeftAnti => { - left_projection.push(i); - } - RightSemi | RightAnti => { - right_projection.push(i); - } - } - } - let num_left_output_columns = left_projection.len(); - let num_right_output_columns = right_projection.len(); - - let mut on_left_projected = vec![]; - let mut on_right_projected = vec![]; - for &l in &self.on_left { - on_left_projected.push(left_projection.iter().position(|&i| i == l).unwrap_or_else( - || { - 
left_projection.push(l); - left_projection.len() - 1 - }, - )); - } - for &r in &self.on_right { - on_right_projected.push( - right_projection - .iter() - .position(|&i| i == r) - .unwrap_or_else(|| { - right_projection.push(r); - right_projection.len() - 1 - }), - ); - } - - let mut join_filter_projected = None; - if let Some(join_filter) = &self.join_filter { - join_filter_projected = Some(JoinFilter::new( - join_filter.expression().clone(), - join_filter - .column_indices() - .iter() - .map(|ci| { - let projected_index = match ci.side { - JoinSide::Left => left_projection - .iter() - .position(|&i| i == ci.index) - .unwrap_or_else(|| { - left_projection.push(ci.index); - left_projection.len() - 1 - }), - JoinSide::Right => right_projection - .iter() - .position(|&i| i == ci.index) - .unwrap_or_else(|| { - right_projection.push(ci.index); - right_projection.len() - 1 - }), - }; - ColumnIndex { - index: projected_index, - side: ci.side, - } - }) - .collect(), - join_filter.schema().clone(), - )); - } - - let projected = Self { - join_type: self.join_type, - output_schema: Arc::new(self.output_schema.project(projection)?), - on_left: on_left_projected, - on_right: on_right_projected, - on_data_types: self.on_data_types.clone(), - sort_options: self.sort_options.clone(), - join_filter: join_filter_projected, - batch_size: self.batch_size, - left_output_projection: (0..num_left_output_columns).collect(), - right_output_projection: (0..num_right_output_columns).collect(), - }; - Ok((projected, left_projection, right_projection)) - } -} - -fn execute_with_join_params( - context: Arc<TaskContext>, - join_params: JoinParams, - left: SendableRecordBatchStream, - right: SendableRecordBatchStream, - metrics: Arc<BaselineMetrics>, -) -> Result<SendableRecordBatchStream> { - let metrics_cloned = metrics.clone(); - let context_cloned = context.clone(); - let output_schema = join_params.output_schema.clone(); - let output_stream = Box::pin(RecordBatchStreamAdapter::new( - 
join_params.output_schema.clone(), - futures::stream::once(async move { - context_cloned.output_with_sender("SortMergeJoin", output_schema, move |sender| { - execute_join(left, right, join_params, metrics_cloned, sender) - }) - }) - .try_flatten(), - )); - Ok(context.coalesce_with_default_batch_size(output_stream, &metrics)?) -} - -async fn execute_join( +pub async fn execute_join( lstream: SendableRecordBatchStream, rstream: SendableRecordBatchStream, join_params: JoinParams, metrics: Arc<BaselineMetrics>, sender: Arc<WrappedRecordBatchSender>, ) -> Result<()> { - let elapsed_time = metrics.elapsed_compute().clone(); - let mut timer = elapsed_time.timer(); + let start_time = Instant::now(); - let on_row_converter = Arc::new(SyncMutex::new(RowConverter::new( - join_params - .on_data_types - .iter() - .zip(&join_params.sort_options) - .map(|(data_type, sort_option)| { - SortField::new_with_options(data_type.clone(), *sort_option) - }) - .collect(), - )?)); + let mut curs = ( + StreamCursor::try_new(lstream, &join_params, JoinSide::Left)?, + StreamCursor::try_new(rstream, &join_params, JoinSide::Right)?, + ); - let mut lcur = StreamCursor::try_new( - lstream, - on_row_converter.clone(), - join_params.on_left.clone(), - join_params.left_output_projection.clone(), + // start first batches of both side asynchronously + tokio::try_join!( + async { Ok::<_, DataFusionError>(cur_forward!(curs.0)) }, + async { Ok::<_, DataFusionError>(cur_forward!(curs.1)) }, )?; - let mut rcur = StreamCursor::try_new( - rstream, - on_row_converter.clone(), - join_params.on_right.clone(), - join_params.right_output_projection.clone(), - )?; - - macro_rules! 
forward { - ($cur:expr) => {{ - if $cur.next() == NextAction::LoadNextBatch { - $cur.next_batch(&mut timer).await?; - } - }}; - } - - // load first record - forward!(lcur); - forward!(rcur); let join_type = join_params.join_type; - let mut joiner = Joiner::new(); - let mut leqs = vec![]; - let mut reqs = vec![]; + let mut joiner: Pin<Box<dyn Joiner + Send>> = match join_type { + Inner => Box::pin(InnerJoiner::new(join_params, sender)), + Left => Box::pin(LeftOuterJoiner::new(join_params, sender)), + Right => Box::pin(RightOuterJoiner::new(join_params, sender)), + Full => Box::pin(FullOuterJoiner::new(join_params, sender)), + LeftSemi => Box::pin(LeftSemiJoiner::new(join_params, sender)), + RightSemi => Box::pin(RightSemiJoiner::new(join_params, sender)), + LeftAnti => Box::pin(LeftAntiJoiner::new(join_params, sender)), + RightAnti => Box::pin(RightAntiJoiner::new(join_params, sender)), + Existence => Box::pin(ExistenceJoiner::new(join_params, sender)), + }; + joiner.as_mut().join(&mut curs).await?; + metrics.record_output(joiner.num_output_rows()); - macro_rules! 
joiner_accept_pair { - ($lidx:expr, $ridx:expr) => {{ - let lidx = $lidx; - let ridx = $ridx; - let r = joiner.accept_pair(&join_params, &mut lcur, &mut rcur, lidx, ridx)?; - if let Some(batch) = r { - metrics.record_output(batch.num_rows()); - sender.send(Ok(batch), Some(&mut timer)).await; - } - }}; - } - - // process records until one side is exhausted - while !lcur.finished && !rcur.finished { - let r = compare_cursor(&lcur, lcur.cur_idx, &rcur, rcur.cur_idx); - match r { - Ordering::Less => { - if matches!(join_type, Left | LeftAnti | Full) { - joiner_accept_pair!(Some(lcur.cur_idx), None); - } - forward!(lcur); - lcur.clear_outdated(joiner.l_min_reserved_bidx); - } - Ordering::Greater => { - if matches!(join_type, Right | RightAnti | Full) { - joiner_accept_pair!(None, Some(rcur.cur_idx)); - } - forward!(rcur); - rcur.clear_outdated(joiner.r_min_reserved_bidx); - } - Ordering::Equal => { - let lidx0 = lcur.cur_idx; - let ridx0 = rcur.cur_idx; - leqs.push(lidx0); - reqs.push(ridx0); - forward!(lcur); - forward!(rcur); - - let mut leq = true; - let mut req = true; - while leq && req { - if leq && !lcur.finished && lcur.row(lcur.cur_idx) == lcur.row(lidx0) { - leqs.push(lcur.cur_idx); - forward!(lcur); - } else { - leq = false; - } - if req && !rcur.finished && rcur.row(rcur.cur_idx) == rcur.row(ridx0) { - reqs.push(rcur.cur_idx); - forward!(rcur); - } else { - req = false; - } - } - - match join_type { - Inner | Left | Right | Full => { - for &l in &leqs { - for &r in &reqs { - joiner_accept_pair!(Some(l), Some(r)); - } - } - } - LeftSemi => { - for &l in &leqs { - joiner_accept_pair!(Some(l), None); - } - } - RightSemi => { - for &r in &reqs { - joiner_accept_pair!(None, Some(r)); - } - } - LeftAnti | RightAnti => {} - } - - if leq { - while !lcur.finished && lcur.row(lcur.cur_idx) == rcur.row(ridx0) { - match join_type { - Inner | Left | Right | Full => { - for &r in &reqs { - joiner_accept_pair!(Some(lcur.cur_idx), Some(r)); - } - } - LeftSemi => { - 
joiner_accept_pair!(Some(lcur.cur_idx), None); - } - RightSemi | LeftAnti | RightAnti => {} - } - forward!(lcur); - lcur.clear_outdated(joiner.l_min_reserved_bidx); - } - } - if req { - while !rcur.finished && rcur.row(rcur.cur_idx) == lcur.row(lidx0) { - match join_type { - Inner | Left | Right | Full => { - for &l in &leqs { - joiner_accept_pair!(Some(l), Some(rcur.cur_idx)); - } - } - RightSemi => { - joiner_accept_pair!(None, Some(rcur.cur_idx)); - } - LeftSemi | LeftAnti | RightAnti => {} - } - forward!(rcur); - rcur.clear_outdated(joiner.r_min_reserved_bidx); - } - } - leqs.clear(); - reqs.clear(); - lcur.clear_outdated(joiner.l_min_reserved_bidx); - rcur.clear_outdated(joiner.r_min_reserved_bidx); - } - } - - // flush joiner if cursors buffered too many batches - if !joiner.is_empty() && (lcur.num_buffered_batches() + rcur.num_buffered_batches() > 5) - || (lcur.mem_size() + rcur.mem_size() > suggested_output_batch_mem_size() - && lcur.num_buffered_batches() > 1 - && rcur.num_buffered_batches() > 1) - { - if let Some(batch) = joiner.flush_pairs(&join_params, &mut lcur, &mut rcur)? { - metrics.record_output(batch.num_rows()); - sender.send(Ok(batch), Some(&mut timer)).await; - } - } - } - - // process rest records in inexhausted side - if matches!(join_type, Left | LeftAnti | Full) { - while !lcur.finished { - joiner_accept_pair!(Some(lcur.cur_idx), None); - forward!(lcur); - lcur.clear_outdated(joiner.l_min_reserved_bidx); - } - } - if matches!(join_type, Right | RightAnti | Full) { - while !rcur.finished { - joiner_accept_pair!(None, Some(rcur.cur_idx)); - forward!(rcur); - rcur.clear_outdated(joiner.r_min_reserved_bidx); - } - } - - // flush joiner - if !joiner.is_empty() { - if let Some(batch) = joiner.flush_pairs(&join_params, &mut lcur, &mut rcur)? 
{ - metrics.record_output(batch.num_rows()); - sender.send(Ok(batch), Some(&mut timer)).await; - } - } + // discount poll input and send output batch time + let mut join_time_ns = (Instant::now() - start_time).as_nanos() as u64; + join_time_ns -= joiner.total_send_output_time() as u64; + join_time_ns -= curs.0.total_poll_time() as u64; + join_time_ns -= curs.1.total_poll_time() as u64; + metrics + .elapsed_compute() + .add_duration(Duration::from_nanos(join_time_ns)); Ok(()) } -struct StreamCursor { - stream: SendableRecordBatchStream, - on_row_converter: Arc<SyncMutex<RowConverter>>, - on_columns: Vec<usize>, - - // IMPORTANT: - // batches/rows/null_buffers always contains a `null batch` in the front - batches: Vec<RecordBatch>, - projected_batches: Vec<RecordBatch>, - projection: Vec<usize>, - on_rows: Vec<Arc<Rows>>, - on_row_null_buffers: Vec<Option<NullBuffer>>, - cur_idx: (usize, usize), - num_null_batches: usize, - mem_size: usize, - finished: bool, +#[macro_export] +macro_rules! 
compare_cursor { + ($curs:expr) => {{ + match ($curs.0.cur_idx, $curs.1.cur_idx) { + (lidx, _) if $curs.0.is_null_key(lidx) => Ordering::Less, + (_, ridx) if $curs.1.is_null_key(ridx) => Ordering::Greater, + (lidx, ridx) => $curs.0.key(lidx).cmp(&$curs.1.key(ridx)), + } + }}; } -#[derive(Clone, Copy, PartialEq, Eq)] -enum NextAction { - None, - LoadNextBatch, -} - -impl StreamCursor { - fn try_new( - stream: SendableRecordBatchStream, - on_row_converter: Arc<SyncMutex<RowConverter>>, - on_columns: Vec<usize>, - projection: Vec<usize>, - ) -> Result<Self> { - let empty_batch = RecordBatch::new_empty(Arc::new(Schema::new( - stream - .schema() - .fields() - .iter() - .map(|f| f.as_ref().clone().with_nullable(true)) - .collect::<Vec<_>>(), - ))); - let null_batch = take_batch_opt(empty_batch, [Option::<usize>::None])?; - let null_on_rows = Arc::new( - on_row_converter - .lock() - .convert_columns(null_batch.project(&on_columns)?.columns())?, - ); - let null_nb = NullBuffer::new_null(1); - - Ok(Self { - stream, - on_row_converter, - on_columns, - projected_batches: vec![null_batch.project(&projection)?], - batches: vec![null_batch], - projection, - on_rows: vec![null_on_rows], - on_row_null_buffers: vec![Some(null_nb)], - cur_idx: (0, 0), - num_null_batches: 1, - mem_size: 0, - finished: false, - }) - } - - fn next(&mut self) -> NextAction { - let mut next_action = NextAction::None; - let mut cur_idx = self.cur_idx; - - if cur_idx.1 + 1 < self.batches[cur_idx.0].num_rows() { - cur_idx.1 += 1; - } else { - cur_idx.0 += 1; - cur_idx.1 = 0; - next_action = NextAction::LoadNextBatch; - } - self.cur_idx = cur_idx; - next_action - } - - async fn next_batch(&mut self, stop_timer: &mut ScopedTimerGuard<'_>) -> Result<bool> { - stop_timer.stop(); - if let Some(batch) = self.stream.next().await.transpose()? 
{ - stop_timer.restart(); - let on_columns = batch.project(&self.on_columns)?.columns().to_vec(); - let on_row_null_buffer = on_columns - .iter() - .map(|c| c.nulls().cloned()) - .reduce(|lhs, rhs| NullBuffer::union(lhs.as_ref(), rhs.as_ref())) - .unwrap_or(None); - let on_rows = Arc::new(self.on_row_converter.lock().convert_columns(&on_columns)?); - - self.mem_size += batch.get_array_mem_size(); - self.mem_size += on_row_null_buffer - .as_ref() - .map(|nb| nb.buffer().len()) - .unwrap_or_default(); - self.mem_size += on_rows.size(); - - self.projected_batches - .push(batch.project(&self.projection)?); - self.batches.push(batch); - self.on_row_null_buffers.push(on_row_null_buffer); - self.on_rows.push(on_rows); - return Ok(true); - } else { - stop_timer.restart(); - } - self.finished = true; - Ok(false) - } - - #[inline] - fn row<'a>(&'a self, idx: (usize, usize)) -> Row<'a> { - let bidx = idx.0; - let ridx = idx.1; - self.on_rows[bidx].row(ridx) - } - - #[inline] - fn num_buffered_batches(&self) -> usize { - self.batches.len() - self.num_null_batches - } - - #[inline] - fn mem_size(&self) -> usize { - self.mem_size - } - - #[inline] - fn clear_outdated(&mut self, min_reserved_bidx: usize) { - // fill out-dated batches with null batches - for i in self.num_null_batches..min_reserved_bidx.min(self.cur_idx.0) { - self.mem_size -= self.batches[i].get_array_mem_size(); - self.mem_size -= self.on_row_null_buffers[i] - .as_ref() - .map(|nb| nb.buffer().len()) - .unwrap_or_default(); - self.mem_size -= self.on_rows[i].size(); - - self.projected_batches[i] = self.projected_batches[0].clone(); - self.batches[i] = self.batches[0].clone(); - self.on_rows[i] = self.on_rows[0].clone(); - self.on_row_null_buffers[i] = self.on_row_null_buffers[0].clone(); - self.num_null_batches += 1; - } - } -} - -#[derive(Default)] -struct Joiner { - ljoins: Vec<(usize, usize)>, - rjoins: Vec<(usize, usize)>, - l_min_reserved_bidx: usize, - r_min_reserved_bidx: usize, -} - -impl Joiner { - fn 
new() -> Self { - Self { - ljoins: vec![], - rjoins: vec![], - l_min_reserved_bidx: usize::MAX, - r_min_reserved_bidx: usize::MAX, - } - } - - fn accept_pair( - &mut self, - join_params: &JoinParams, - lcur: &mut StreamCursor, - rcur: &mut StreamCursor, - l: Option<(usize, usize)>, - r: Option<(usize, usize)>, - ) -> Result<Option<RecordBatch>> { - if let Some((bidx, ridx)) = l { - self.ljoins.push((bidx, ridx)); - self.l_min_reserved_bidx = self.l_min_reserved_bidx.min(bidx); - } else { - self.ljoins.push((0, 0)); - } - - if let Some((bidx, ridx)) = r { - self.rjoins.push((bidx, ridx)); - self.r_min_reserved_bidx = self.r_min_reserved_bidx.min(bidx); - } else { - self.rjoins.push((0, 0)); - } - - let batch_size = join_params.batch_size; - if self.ljoins.len() >= batch_size || self.rjoins.len() >= batch_size { - return self.flush_pairs(join_params, lcur, rcur); - } - Ok(None) - } - - fn is_empty(&self) -> bool { - self.ljoins.is_empty() && self.rjoins.is_empty() - } - - fn flush_pairs( - &mut self, - join_params: &JoinParams, - lcur: &mut StreamCursor, - rcur: &mut StreamCursor, - ) -> Result<Option<RecordBatch>> { - self.l_min_reserved_bidx = usize::MAX; - self.r_min_reserved_bidx = usize::MAX; - - if let Some(join_filter) = &join_params.join_filter { - let num_intermediate_rows = std::cmp::max(self.ljoins.len(), self.rjoins.len()); - - // get intermediate batch - let intermediate_columns = join_filter - .column_indices() - .iter() - .map(|ci| { - let (cur, joins) = match ci.side { - JoinSide::Left => (&lcur, &self.ljoins), - JoinSide::Right => (&rcur, &self.rjoins), - }; - let arrays = cur - .batches - .iter() - .map(|b| b.column(ci.index).as_ref()) - .collect::<Vec<_>>(); - Ok(arrow::compute::interleave(&arrays, joins)?) 
- }) - .collect::<Result<Vec<_>>>()?; - - let intermediate_batch = RecordBatch::try_new_with_options( - Arc::new(join_filter.schema().clone()), - intermediate_columns, - &RecordBatchOptions::new().with_row_count(Some(num_intermediate_rows)), - )?; - - // evalute filter - let filtered_array = join_filter - .expression() - .evaluate(&intermediate_batch)? - .into_array(intermediate_batch.num_rows())?; - let filtered = as_boolean_array(&filtered_array); - let filtered = if filtered.null_count() > 0 { - prep_null_mask_filter(filtered) - } else { - filtered.clone() - }; - - // apply filter - let mut retained = 0; - for (i, selected) in filtered.values().iter().enumerate() { - if selected { - self.ljoins[retained] = self.ljoins[i]; - self.rjoins[retained] = self.rjoins[i]; - retained += 1; - } - } - self.ljoins.truncate(retained); - self.rjoins.truncate(retained); - if retained == 0 { - return Ok(None); - } - } - - let lcols = || -> Result<Vec<ArrayRef>> { - Ok(if !lcur.projection.is_empty() { - interleave_batches( - lcur.projected_batches[0].schema(), - &lcur.projected_batches, - &self.ljoins, - )? - .columns() - .to_vec() - } else { - vec![] - }) - }; - let rcols = || -> Result<Vec<ArrayRef>> { - Ok(if !rcur.projection.is_empty() { - interleave_batches( - rcur.projected_batches[0].schema(), - &rcur.projected_batches, - &self.rjoins, - )? 
- .columns() - .to_vec() - } else { - vec![] - }) - }; - - let output_columns = match join_params.join_type { - LeftSemi | LeftAnti => lcols()?, - RightSemi | RightAnti => rcols()?, - _ => [lcols()?, rcols()?].concat(), - }; - let num_output_records = std::cmp::max(self.ljoins.len(), self.rjoins.len()); - self.ljoins.clear(); - self.rjoins.clear(); - let batch = RecordBatch::try_new_with_options( - join_params.output_schema.clone(), - output_columns, - &RecordBatchOptions::new().with_row_count(Some(num_output_records)), - )?; - Ok(Some(batch)) - } -} - -fn compare_cursor( - lcur: &StreamCursor, - lidx: (usize, usize), - rcur: &StreamCursor, - ridx: (usize, usize), -) -> Ordering { - match (&lcur.on_rows.get(lidx.0), &rcur.on_rows.get(ridx.0)) { - (None, _) => Ordering::Greater, - (_, None) => Ordering::Less, - (Some(lrows), Some(rrows)) => { - let lkey = &lrows.row(lidx.1); - let rkey = &rrows.row(ridx.1); - match lkey.cmp(rkey) { - Ordering::Greater => Ordering::Greater, - Ordering::Less => Ordering::Less, - _ => { - if let Some(nb) = &lcur.on_row_null_buffers[lidx.0] { - if nb.is_null(lidx.1) { - return Ordering::Less; - } - } - Ordering::Equal - } - } - } - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use arrow::{ - self, - array::*, - compute::SortOptions, - datatypes::{DataType, Field, Schema}, - record_batch::RecordBatch, - }; - use datafusion::{ - assert_batches_sorted_eq, - error::Result, - logical_expr::{JoinType, JoinType::*}, - physical_expr::expressions::Column, - physical_plan::{common, joins::utils::*, memory::MemoryExec, ExecutionPlan}, - prelude::SessionContext, - }; - - use crate::sort_merge_join_exec::SortMergeJoinExec; - - fn columns(schema: &Schema) -> Vec<String> { - schema.fields().iter().map(|f| f.name().clone()).collect() - } - - fn build_table_i32( - a: (&str, &Vec<i32>), - b: (&str, &Vec<i32>), - c: (&str, &Vec<i32>), - ) -> RecordBatch { - let schema = Schema::new(vec![ - Field::new(a.0, DataType::Int32, false), - 
Field::new(b.0, DataType::Int32, false), - Field::new(c.0, DataType::Int32, false), - ]); - - RecordBatch::try_new( - Arc::new(schema), - vec![ - Arc::new(Int32Array::from(a.1.clone())), - Arc::new(Int32Array::from(b.1.clone())), - Arc::new(Int32Array::from(c.1.clone())), - ], - ) - .unwrap() - } - - fn build_table( - a: (&str, &Vec<i32>), - b: (&str, &Vec<i32>), - c: (&str, &Vec<i32>), - ) -> Arc<dyn ExecutionPlan> { - let batch = build_table_i32(a, b, c); - let schema = batch.schema(); - Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) - } - - fn build_table_from_batches(batches: Vec<RecordBatch>) -> Arc<dyn ExecutionPlan> { - let schema = batches.first().unwrap().schema(); - Arc::new(MemoryExec::try_new(&[batches], schema, None).unwrap()) - } - - fn build_date_table( - a: (&str, &Vec<i32>), - b: (&str, &Vec<i32>), - c: (&str, &Vec<i32>), - ) -> Arc<dyn ExecutionPlan> { - let schema = Schema::new(vec![ - Field::new(a.0, DataType::Date32, false), - Field::new(b.0, DataType::Date32, false), - Field::new(c.0, DataType::Date32, false), - ]); - - let batch = RecordBatch::try_new( - Arc::new(schema), - vec![ - Arc::new(Date32Array::from(a.1.clone())), - Arc::new(Date32Array::from(b.1.clone())), - Arc::new(Date32Array::from(c.1.clone())), - ], - ) - .unwrap(); - - let schema = batch.schema(); - Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) - } - - fn build_date64_table( - a: (&str, &Vec<i64>), - b: (&str, &Vec<i64>), - c: (&str, &Vec<i64>), - ) -> Arc<dyn ExecutionPlan> { - let schema = Schema::new(vec![ - Field::new(a.0, DataType::Date64, false), - Field::new(b.0, DataType::Date64, false), - Field::new(c.0, DataType::Date64, false), - ]); - - let batch = RecordBatch::try_new( - Arc::new(schema), - vec![ - Arc::new(Date64Array::from(a.1.clone())), - Arc::new(Date64Array::from(b.1.clone())), - Arc::new(Date64Array::from(c.1.clone())), - ], - ) - .unwrap(); - - let schema = batch.schema(); - 
Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) - } - - /// returns a table with 3 columns of i32 in memory - pub fn build_table_i32_nullable( - a: (&str, &Vec<Option<i32>>), - b: (&str, &Vec<Option<i32>>), - c: (&str, &Vec<Option<i32>>), - ) -> Arc<dyn ExecutionPlan> { - let schema = Arc::new(Schema::new(vec![ - Field::new(a.0, DataType::Int32, true), - Field::new(b.0, DataType::Int32, true), - Field::new(c.0, DataType::Int32, true), - ])); - let batch = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(a.1.clone())), - Arc::new(Int32Array::from(b.1.clone())), - Arc::new(Int32Array::from(c.1.clone())), - ], - ) - .unwrap(); - Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) - } - - fn join_with_options( - left: Arc<dyn ExecutionPlan>, - right: Arc<dyn ExecutionPlan>, - on: JoinOn, - join_type: JoinType, - sort_options: Vec<SortOptions>, - ) -> Result<SortMergeJoinExec> { - SortMergeJoinExec::try_new(left, right, on, join_type, None, sort_options) - } - - async fn join_collect( - left: Arc<dyn ExecutionPlan>, - right: Arc<dyn ExecutionPlan>, - on: JoinOn, - join_type: JoinType, - ) -> Result<(Vec<String>, Vec<RecordBatch>)> { - let sort_options = vec![SortOptions::default(); on.len()]; - join_collect_with_options(left, right, on, join_type, sort_options).await - } - - async fn join_collect_with_options( - left: Arc<dyn ExecutionPlan>, - right: Arc<dyn ExecutionPlan>, - on: JoinOn, - join_type: JoinType, - sort_options: Vec<SortOptions>, - ) -> Result<(Vec<String>, Vec<RecordBatch>)> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); - let join = join_with_options(left, right, on, join_type, sort_options)?; - let columns = columns(&join.schema()); - - let stream = join.execute(0, task_ctx)?; - let batches = common::collect(stream).await?; - Ok((columns, batches)) - } - - #[tokio::test] - async fn join_inner_one() -> Result<()> { - let left = build_table( - ("a1", 
&vec![1, 2, 3]), - ("b1", &vec![4, 5, 5]), // this has a repetition - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), - ("c2", &vec![70, 80, 90]), - ); - - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b1", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Inner).await?; - - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b1 | c2 |", - "+----+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 4 | 70 |", - "| 2 | 5 | 8 | 20 | 5 | 80 |", - "| 3 | 5 | 9 | 20 | 5 | 80 |", - "+----+----+----+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_inner_two() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 2]), - ("b2", &vec![1, 2, 2]), - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a1", &vec![1, 2, 3]), - ("b2", &vec![1, 2, 2]), - ("c2", &vec![70, 80, 90]), - ); - let on: JoinOn = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?), - Arc::new(Column::new_with_schema("a1", &right.schema())?), - ), - ( - Arc::new(Column::new_with_schema("b2", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - ), - ]; - - let (_columns, batches) = join_collect(left, right, on, Inner).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b2 | c1 | a1 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| 1 | 1 | 7 | 1 | 1 | 70 |", - "| 2 | 2 | 8 | 2 | 2 | 80 |", - "| 2 | 2 | 9 | 2 | 2 | 80 |", - "+----+----+----+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_inner_two_two() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 
1, 2]), - ("b2", &vec![1, 1, 2]), - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a1", &vec![1, 1, 3]), - ("b2", &vec![1, 1, 2]), - ("c2", &vec![70, 80, 90]), - ); - let on: JoinOn = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?), - Arc::new(Column::new_with_schema("a1", &right.schema())?), - ), - ( - Arc::new(Column::new_with_schema("b2", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - ), - ]; - - let (_columns, batches) = join_collect(left, right, on, Inner).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b2 | c1 | a1 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| 1 | 1 | 7 | 1 | 1 | 70 |", - "| 1 | 1 | 7 | 1 | 1 | 80 |", - "| 1 | 1 | 8 | 1 | 1 | 70 |", - "| 1 | 1 | 8 | 1 | 1 | 80 |", - "+----+----+----+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_inner_with_nulls() -> Result<()> { - let left = build_table_i32_nullable( - ("a1", &vec![Some(1), Some(1), Some(2), Some(2)]), - ("b2", &vec![None, Some(1), Some(2), Some(2)]), // null in key field - ("c1", &vec![Some(1), None, Some(8), Some(9)]), // null in non-key field - ); - let right = build_table_i32_nullable( - ("a1", &vec![Some(1), Some(1), Some(2), Some(3)]), - ("b2", &vec![None, Some(1), Some(2), Some(2)]), - ("c2", &vec![Some(10), Some(70), Some(80), Some(90)]), - ); - let on: JoinOn = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?), - Arc::new(Column::new_with_schema("a1", &right.schema())?), - ), - ( - Arc::new(Column::new_with_schema("b2", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - ), - ]; - - let (_, batches) = join_collect(left, right, on, Inner).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b2 | c1 | a1 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| 1 | 1 | | 1 
| 1 | 70 |", - "| 2 | 2 | 8 | 2 | 2 | 80 |", - "| 2 | 2 | 9 | 2 | 2 | 80 |", - "+----+----+----+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_inner_with_nulls_with_options() -> Result<()> { - let left = build_table_i32_nullable( - ("a1", &vec![Some(2), Some(2), Some(1), Some(1)]), - ("b2", &vec![Some(2), Some(2), Some(1), None]), // null in key field - ("c1", &vec![Some(9), Some(8), None, Some(1)]), // null in non-key field - ); - let right = build_table_i32_nullable( - ("a1", &vec![Some(3), Some(2), Some(1), Some(1)]), - ("b2", &vec![Some(2), Some(2), Some(1), None]), - ("c2", &vec![Some(90), Some(80), Some(70), Some(10)]), - ); - let on: JoinOn = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?), - Arc::new(Column::new_with_schema("a1", &right.schema())?), - ), - ( - Arc::new(Column::new_with_schema("b2", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - ), - ]; - let (_, batches) = join_collect_with_options( - left, - right, - on, - Inner, - vec![ - SortOptions { - descending: true, - nulls_first: false - }; - 2 - ], - // null_equals_null=false - ) - .await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b2 | c1 | a1 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| 2 | 2 | 9 | 2 | 2 | 80 |", - "| 2 | 2 | 8 | 2 | 2 | 80 |", - "| 1 | 1 | | 1 | 1 | 70 |", - "+----+----+----+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_left_one() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![4, 5, 7]), // 7 does not exist on the right - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), - ("c2", &vec![70, 80, 90]), - ); - let on: JoinOn 
= vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b1", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Left).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b1 | c2 |", - "+----+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 4 | 70 |", - "| 2 | 5 | 8 | 20 | 5 | 80 |", - "| 3 | 7 | 9 | | | |", - "+----+----+----+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_right_one() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![4, 5, 7]), - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), // 6 does not exist on the left - ("c2", &vec![70, 80, 90]), - ); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b1", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Right).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b1 | c2 |", - "+----+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 4 | 70 |", - "| 2 | 5 | 8 | 20 | 5 | 80 |", - "| | | | 30 | 6 | 90 |", - "+----+----+----+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_full_one() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![4, 5, 7]), // 7 does not exist on the right - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b2", &vec![4, 5, 6]), - ("c2", &vec![70, 80, 90]), - ); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - 
)]; - - let (_, batches) = join_collect(left, right, on, Full).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| | | | 30 | 6 | 90 |", - "| 1 | 4 | 7 | 10 | 4 | 70 |", - "| 2 | 5 | 8 | 20 | 5 | 80 |", - "| 3 | 7 | 9 | | | |", - "+----+----+----+----+----+----+", - ]; - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_anti() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 2, 3, 5]), - ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right - ("c1", &vec![7, 8, 8, 9, 11]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), - ("c2", &vec![70, 80, 90]), - ); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b1", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, LeftAnti).await?; - let expected = vec![ - "+----+----+----+", - "| a1 | b1 | c1 |", - "+----+----+----+", - "| 3 | 7 | 9 |", - "| 5 | 7 | 11 |", - "+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_semi() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 2, 3]), - ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right - ("c1", &vec![7, 8, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), // 5 is double on the right - ("c2", &vec![70, 80, 90]), - ); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b1", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, LeftSemi).await?; - let expected = vec![ - "+----+----+----+", - "| a1 | b1 | c1 |", - "+----+----+----+", - "| 1 | 4 | 7 |", - "| 2 | 5 | 8 |", - "| 2 | 5 | 8 |", - 
"+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_with_duplicated_column_names() -> Result<()> { - let left = build_table( - ("a", &vec![1, 2, 3]), - ("b", &vec![4, 5, 7]), - ("c", &vec![7, 8, 9]), - ); - let right = build_table( - ("a", &vec![10, 20, 30]), - ("b", &vec![1, 2, 7]), - ("c", &vec![70, 80, 90]), - ); - let on: JoinOn = vec![( - // join on a=b so there are duplicate column names on unjoined columns - Arc::new(Column::new_with_schema("a", &left.schema())?), - Arc::new(Column::new_with_schema("b", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Inner).await?; - let expected = vec![ - "+---+---+---+----+---+----+", - "| a | b | c | a | b | c |", - "+---+---+---+----+---+----+", - "| 1 | 4 | 7 | 10 | 1 | 70 |", - "| 2 | 5 | 8 | 20 | 2 | 80 |", - "+---+---+---+----+---+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_date32() -> Result<()> { - let left = build_date_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![19107, 19108, 19108]), // this has a repetition - ("c1", &vec![7, 8, 9]), - ); - let right = build_date_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![19107, 19108, 19109]), - ("c2", &vec![70, 80, 90]), - ); - - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b1", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Inner).await?; - - let expected = vec![ - "+------------+------------+------------+------------+------------+------------+", - "| a1 | b1 | c1 | a2 | b1 | c2 |", - "+------------+------------+------------+------------+------------+------------+", - "| 1970-01-02 | 2022-04-25 | 1970-01-08 | 1970-01-11 | 2022-04-25 | 1970-03-12 |", - "| 1970-01-03 | 
2022-04-26 | 1970-01-09 | 1970-01-21 | 2022-04-26 | 1970-03-22 |", - "| 1970-01-04 | 2022-04-26 | 1970-01-10 | 1970-01-21 | 2022-04-26 | 1970-03-22 |", - "+------------+------------+------------+------------+------------+------------+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_date64() -> Result<()> { - let left = build_date64_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![1650703441000, 1650903441000, 1650903441000]), // this has a repetition - ("c1", &vec![7, 8, 9]), - ); - let right = build_date64_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![1650703441000, 1650503441000, 1650903441000]), - ("c2", &vec![70, 80, 90]), - ); - - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b1", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Inner).await?; - let expected = vec![ - "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", - "| a1 | b1 | c1 | a2 | b1 | c2 |", - "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", - "| 1970-01-01T00:00:00.001 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.007 | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.070 |", - "| 1970-01-01T00:00:00.002 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |", - "| 1970-01-01T00:00:00.003 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |", - "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", - ]; - - // The output 
order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_left_sort_order() -> Result<()> { - let left = build_table( - ("a1", &vec![0, 1, 2, 3, 4, 5]), - ("b1", &vec![3, 4, 5, 6, 6, 7]), - ("c1", &vec![4, 5, 6, 7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![0, 10, 20, 30, 40]), - ("b2", &vec![2, 4, 6, 6, 8]), - ("c2", &vec![50, 60, 70, 80, 90]), - ); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Left).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| 0 | 3 | 4 | | | |", - "| 1 | 4 | 5 | 10 | 4 | 60 |", - "| 2 | 5 | 6 | | | |", - "| 3 | 6 | 7 | 20 | 6 | 70 |", - "| 3 | 6 | 7 | 30 | 6 | 80 |", - "| 4 | 6 | 8 | 20 | 6 | 70 |", - "| 4 | 6 | 8 | 30 | 6 | 80 |", - "| 5 | 7 | 9 | | | |", - "+----+----+----+----+----+----+", - ]; - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_right_sort_order() -> Result<()> { - let left = build_table( - ("a1", &vec![0, 1, 2, 3]), - ("b1", &vec![3, 4, 5, 7]), - ("c1", &vec![6, 7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![0, 10, 20, 30]), - ("b2", &vec![2, 4, 5, 6]), - ("c2", &vec![60, 70, 80, 90]), - ); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Right).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| | | | 0 | 2 | 60 |", - "| 1 | 4 | 7 | 10 | 4 | 70 |", - "| 2 | 5 | 8 | 20 | 5 | 80 |", - "| | | | 30 | 6 | 90 |", - "+----+----+----+----+----+----+", - ]; - 
assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_left_multiple_batches() -> Result<()> { - let left_batch_1 = build_table_i32( - ("a1", &vec![0, 1, 2]), - ("b1", &vec![3, 4, 5]), - ("c1", &vec![4, 5, 6]), - ); - let left_batch_2 = build_table_i32( - ("a1", &vec![3, 4, 5, 6]), - ("b1", &vec![6, 6, 7, 9]), - ("c1", &vec![7, 8, 9, 9]), - ); - let right_batch_1 = build_table_i32( - ("a2", &vec![0, 10, 20]), - ("b2", &vec![2, 4, 6]), - ("c2", &vec![50, 60, 70]), - ); - let right_batch_2 = build_table_i32( - ("a2", &vec![30, 40]), - ("b2", &vec![6, 8]), - ("c2", &vec![80, 90]), - ); - let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); - let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Left).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| 0 | 3 | 4 | | | |", - "| 1 | 4 | 5 | 10 | 4 | 60 |", - "| 2 | 5 | 6 | | | |", - "| 3 | 6 | 7 | 20 | 6 | 70 |", - "| 3 | 6 | 7 | 30 | 6 | 80 |", - "| 4 | 6 | 8 | 20 | 6 | 70 |", - "| 4 | 6 | 8 | 30 | 6 | 80 |", - "| 5 | 7 | 9 | | | |", - "| 6 | 9 | 9 | | | |", - "+----+----+----+----+----+----+", - ]; - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_right_multiple_batches() -> Result<()> { - let right_batch_1 = build_table_i32( - ("a2", &vec![0, 1, 2]), - ("b2", &vec![3, 4, 5]), - ("c2", &vec![4, 5, 6]), - ); - let right_batch_2 = build_table_i32( - ("a2", &vec![3, 4, 5, 6]), - ("b2", &vec![6, 6, 7, 9]), - ("c2", &vec![7, 8, 9, 9]), - ); - let left_batch_1 = build_table_i32( - ("a1", &vec![0, 10, 20]), - ("b1", &vec![2, 4, 6]), - ("c1", &vec![50, 60, 70]), - ); - let left_batch_2 = build_table_i32( - 
("a1", &vec![30, 40]), - ("b1", &vec![6, 8]), - ("c1", &vec![80, 90]), - ); - let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); - let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Right).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| | | | 0 | 3 | 4 |", - "| 10 | 4 | 60 | 1 | 4 | 5 |", - "| | | | 2 | 5 | 6 |", - "| 20 | 6 | 70 | 3 | 6 | 7 |", - "| 30 | 6 | 80 | 3 | 6 | 7 |", - "| 20 | 6 | 70 | 4 | 6 | 8 |", - "| 30 | 6 | 80 | 4 | 6 | 8 |", - "| | | | 5 | 7 | 9 |", - "| | | | 6 | 9 | 9 |", - "+----+----+----+----+----+----+", - ]; - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_full_multiple_batches() -> Result<()> { - let left_batch_1 = build_table_i32( - ("a1", &vec![0, 1, 2]), - ("b1", &vec![3, 4, 5]), - ("c1", &vec![4, 5, 6]), - ); - let left_batch_2 = build_table_i32( - ("a1", &vec![3, 4, 5, 6]), - ("b1", &vec![6, 6, 7, 9]), - ("c1", &vec![7, 8, 9, 9]), - ); - let right_batch_1 = build_table_i32( - ("a2", &vec![0, 10, 20]), - ("b2", &vec![2, 4, 6]), - ("c2", &vec![50, 60, 70]), - ); - let right_batch_2 = build_table_i32( - ("a2", &vec![30, 40]), - ("b2", &vec![6, 8]), - ("c2", &vec![80, 90]), - ); - let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); - let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Full).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b2 | c2 |", - 
"+----+----+----+----+----+----+", - "| | | | 0 | 2 | 50 |", - "| | | | 40 | 8 | 90 |", - "| 0 | 3 | 4 | | | |", - "| 1 | 4 | 5 | 10 | 4 | 60 |", - "| 2 | 5 | 6 | | | |", - "| 3 | 6 | 7 | 20 | 6 | 70 |", - "| 3 | 6 | 7 | 30 | 6 | 80 |", - "| 4 | 6 | 8 | 20 | 6 | 70 |", - "| 4 | 6 | 8 | 30 | 6 | 80 |", - "| 5 | 7 | 9 | | | |", - "| 6 | 9 | 9 | | | |", - "+----+----+----+----+----+----+", - ]; - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } +#[async_trait] +pub trait Joiner { + async fn join(self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()>; + fn total_send_output_time(&self) -> usize; + fn num_output_rows(&self) -> usize; }
diff --git a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala index 1d867a5..ba127ef 100644 --- a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala +++ b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
@@ -150,7 +150,7 @@ leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, - condition: Option[Expression]): NativeBroadcastJoinBase = + broadcastSide: BroadcastSide): NativeBroadcastJoinBase = NativeBroadcastJoinExec( left, right, @@ -158,7 +158,7 @@ leftKeys, rightKeys, joinType, - condition) + broadcastSide) override def createNativeBroadcastNestedLoopJoinExec( left: SparkPlan,
diff --git a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeBlockStoreShuffleReader.scala b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeBlockStoreShuffleReader.scala index 292f233..fdd3a24 100644 --- a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeBlockStoreShuffleReader.scala +++ b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeBlockStoreShuffleReader.scala
@@ -20,7 +20,6 @@ import org.apache.spark.MapOutputTracker import org.apache.spark.SparkEnv import org.apache.spark.TaskContext - import org.apache.spark.internal.Logging import org.apache.spark.internal.config import org.apache.spark.io.CompressionCodec @@ -28,30 +27,21 @@ import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.storage.BlockId import org.apache.spark.storage.BlockManager +import org.apache.spark.storage.BlockManagerId import org.apache.spark.storage.ShuffleBlockFetcherIterator class BlazeBlockStoreShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], - startPartition: Int, - endPartition: Int, + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], context: TaskContext, readMetrics: ShuffleReadMetricsReporter, blockManager: BlockManager = SparkEnv.get.blockManager, mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, - startMapId: Option[Int] = None, - endMapId: Option[Int] = None, shouldBatchFetch: Boolean = false) extends BlazeBlockStoreShuffleReaderBase[K, C](handle, context) with Logging { override def readBlocks(): Iterator[(BlockId, InputStream)] = { - val blocksByAddress = mapOutputTracker.getMapSizesByExecutorId( - handle.shuffleId, - startMapId.getOrElse(0), - endMapId.getOrElse(Int.MaxValue), - startPartition, - endPartition) - new ShuffleBlockFetcherIterator( context, blockManager.blockStoreClient,
diff --git a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleManager.scala b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleManager.scala index a7390ee..83decb3 100644 --- a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleManager.scala +++ b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleManager.scala
@@ -22,6 +22,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.shuffle._ import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch import org.apache.spark.sql.execution.blaze.shuffle.BlazeShuffleDependency.isArrowShuffle class BlazeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { @@ -54,16 +55,27 @@ metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { if (isArrowShuffle(handle)) { + val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, _, C]] + val (blocksByAddress, canEnableBatchFetch) = + if (baseShuffleHandle.dependency.isShuffleMergeFinalizedMarked) { + val res = SparkEnv.get.mapOutputTracker.getPushBasedShuffleMapSizesByExecutorId( + handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) + (res.iter, res.enableBatchFetch) + } else { + val address = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( + handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) + (address, true) + } + new BlazeBlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], - startPartition, - endPartition, + blocksByAddress, context, metrics, SparkEnv.get.blockManager, SparkEnv.get.mapOutputTracker, - startMapId = Some(startMapIndex), - endMapId = Some(endMapIndex)) + shouldBatchFetch = + canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context)) } else { sortShuffleManager.getReader( handle,
diff --git a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala index 3fc6649..de3f5f8 100644 --- a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala +++ b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala
@@ -21,12 +21,16 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.optimizer.BuildLeft +import org.apache.spark.sql.catalyst.optimizer.BuildRight import org.apache.spark.sql.catalyst.optimizer.BuildSide import org.apache.spark.sql.catalyst.plans.physical.BroadcastDistribution import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.blaze.plan.BroadcastLeft +import org.apache.spark.sql.execution.blaze.plan.BroadcastRight +import org.apache.spark.sql.execution.blaze.plan.BroadcastSide import org.apache.spark.sql.execution.blaze.plan.NativeBroadcastJoinBase import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode import org.apache.spark.sql.execution.joins.HashedRelationInfo @@ -39,7 +43,7 @@ override val leftKeys: Seq[Expression], override val rightKeys: Seq[Expression], override val joinType: JoinType, - override val condition: Option[Expression]) + broadcastSide: BroadcastSide) extends NativeBroadcastJoinBase( left, right, @@ -47,9 +51,11 @@ leftKeys, rightKeys, joinType, - condition) + broadcastSide) with HashJoin { + override def condition: Option[Expression] = None + override def requiredChildDistribution: Seq[Distribution] = { val mode = HashedRelationBroadcastMode(buildBoundKeys, isNullAware = false) BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil @@ -65,7 +71,10 @@ throw new NotImplementedError("NativeBroadcastJoin dose not support codegen") } - override def buildSide: BuildSide = BuildLeft + override def buildSide: BuildSide = broadcastSide match { + case BroadcastLeft => BuildLeft + case BroadcastRight => BuildRight + } override protected def withNewChildrenInternal( newLeft: SparkPlan,
diff --git a/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java b/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java index 31c3b9a..5131478 100644 --- a/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java +++ b/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java
@@ -27,10 +27,6 @@ /// actual off-heap memory usage is expected to be spark.executor.memoryOverhead * fraction. MEMORY_FRACTION("spark.blaze.memoryFraction", 0.6), - /// translates inequality smj to native. improves performance in most cases, however some - /// issues are found in special cases, like tpcds q72. - SMJ_INEQUALITY_JOIN_ENABLE("spark.blaze.enable.smjInequalityJoin", false), - /// fallbacks to SortMergeJoin when executing BroadcastHashJoin with big broadcasted table. BHJ_FALLBACKS_TO_SMJ_ENABLE("spark.blaze.enable.bhjFallbacksToSmj", true), @@ -64,6 +60,12 @@ /// mininum number of rows to trigger partial aggregate skipping PARTIAL_AGG_SKIPPING_MIN_ROWS("spark.blaze.partialAggSkipping.minRows", BATCH_SIZE.intConf() * 2), + + // parquet enable page filtering + PARQUET_ENABLE_PAGE_FILTERING("spark.blaze.parquet.enable.pageFiltering", false), + + // parquet enable bloom filter + PARQUET_ENABLE_BLOOM_FILTER("spark.blaze.parquet.enable.bloomFilter", false), ; private String key;
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala index 09bf85e..c24888d 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala
@@ -28,7 +28,6 @@ import org.apache.arrow.c.ArrowSchema import org.apache.arrow.c.CDataDictionaryProvider import org.apache.arrow.c.Data -import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.types.pojo.Schema import org.apache.spark.Partition
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConvertStrategy.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConvertStrategy.scala index 9f7eb61..a7ab81d 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConvertStrategy.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConvertStrategy.scala
@@ -46,7 +46,8 @@ val convertibleTag: TreeNodeTag[Boolean] = TreeNodeTag("blaze.convertible") val convertStrategyTag: TreeNodeTag[ConvertStrategy] = TreeNodeTag("blaze.convert.strategy") - val childOrderingRequiredTag: TreeNodeTag[Boolean] = TreeNodeTag("blaze.child.ordering.required") + val childOrderingRequiredTag: TreeNodeTag[Boolean] = TreeNodeTag( + "blaze.child.ordering.required") def apply(exec: SparkPlan): Unit = { exec.foreach(_.setTagValue(convertibleTag, true))
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala index 99d0172..4c22319 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala
@@ -92,7 +92,7 @@ val enableBhj: Boolean = SparkEnv.get.conf.getBoolean("spark.blaze.enable.bhj", defaultValue = true) val enableBnlj: Boolean = - SparkEnv.get.conf.getBoolean("spark.blaze.enable.bnlj", defaultValue = true) + SparkEnv.get.conf.getBoolean("spark.blaze.enable.bnlj", defaultValue = false) val enableLocalLimit: Boolean = SparkEnv.get.conf.getBoolean("spark.blaze.enable.local.limit", defaultValue = true) val enableGlobalLimit: Boolean = @@ -128,7 +128,9 @@ var newExec = exec.withNewChildren(newChildren) exec.getTagValue(convertibleTag).foreach(newExec.setTagValue(convertibleTag, _)) exec.getTagValue(convertStrategyTag).foreach(newExec.setTagValue(convertStrategyTag, _)) - exec.getTagValue(childOrderingRequiredTag).foreach(newExec.setTagValue(childOrderingRequiredTag, _)) + exec + .getTagValue(childOrderingRequiredTag) + .foreach(newExec.setTagValue(childOrderingRequiredTag, _)) if (!isNeverConvert(newExec)) { newExec = convertSparkPlan(newExec) } @@ -333,45 +335,14 @@ val (leftKeys, rightKeys, joinType, condition, left, right) = (exec.leftKeys, exec.rightKeys, exec.joinType, exec.condition, exec.left, exec.right) logDebug(s"Converting SortMergeJoinExec: ${Shims.get.simpleStringWithNodeId(exec)}") - var nativeLeft = convertToNative(left) - var nativeRight = convertToNative(right) - var modifiedLeftKeys = leftKeys - var modifiedRightKeys = rightKeys - var needPostProject = false - if (leftKeys.exists(!_.isInstanceOf[AttributeReference])) { - val (keys, exec) = buildJoinColumnsProject(nativeLeft, leftKeys) - modifiedLeftKeys = keys - nativeLeft = exec - needPostProject = true - } - if (rightKeys.exists(!_.isInstanceOf[AttributeReference])) { - val (keys, exec) = buildJoinColumnsProject(nativeRight, rightKeys) - modifiedRightKeys = keys - nativeRight = exec - needPostProject = true - } - - val smjOrig = SortMergeJoinExec( - modifiedLeftKeys, - modifiedRightKeys, + Shims.get.createNativeSortMergeJoinExec( + addRenameColumnsExec(convertToNative(left)), + 
addRenameColumnsExec(convertToNative(right)), + leftKeys, + rightKeys, joinType, - condition, - addRenameColumnsExec(nativeLeft), - addRenameColumnsExec(nativeRight)) - val smj = Shims.get.createNativeSortMergeJoinExec( - smjOrig.left, - smjOrig.right, - smjOrig.leftKeys, - smjOrig.rightKeys, - smjOrig.joinType, - smjOrig.condition) - - if (needPostProject) { - buildPostJoinProject(smj, exec.output) - } else { - smj - } + condition) } def convertBroadcastHashJoinExec(exec: BroadcastHashJoinExec): SparkPlan = { @@ -385,84 +356,33 @@ exec.left, exec.right) logDebug(s"Converting BroadcastHashJoinExec: ${Shims.get.simpleStringWithNodeId(exec)}") - logDebug(s" leftKeys: ${exec.leftKeys}") - logDebug(s" rightKeys: ${exec.rightKeys}") - logDebug(s" joinType: ${exec.joinType}") - logDebug(s" buildSide: ${exec.buildSide}") - logDebug(s" condition: ${exec.condition}") - var (hashed, hashedKeys, nativeProbed, probedKeys) = buildSide match { + logDebug(s" leftKeys: $leftKeys") + logDebug(s" rightKeys: $rightKeys") + logDebug(s" joinType: $joinType") + logDebug(s" buildSide: $buildSide") + logDebug(s" condition: $condition") + assert(condition.isEmpty, "join condition is not supported") + + // verify build side is native + buildSide match { case BuildRight => assert(NativeHelper.isNative(right), "broadcast join build side is not native") - val convertedLeft = convertToNative(left) - (right, rightKeys, convertedLeft, leftKeys) - case BuildLeft => assert(NativeHelper.isNative(left), "broadcast join build side is not native") - val convertedRight = convertToNative(right) - (left, leftKeys, convertedRight, rightKeys) - - case _ => - // scalastyle:off throwerror - throw new NotImplementedError( - "Ignore BroadcastHashJoin with unsupported children structure") } - var modifiedHashedKeys = hashedKeys - var modifiedProbedKeys = probedKeys - var needPostProject = false + Shims.get.createNativeBroadcastJoinExec( + addRenameColumnsExec(convertToNative(left)), + 
addRenameColumnsExec(convertToNative(right)), + exec.outputPartitioning, + leftKeys, + rightKeys, + joinType, + buildSide match { + case BuildLeft => BroadcastLeft + case BuildRight => BroadcastRight + }) - if (hashedKeys.exists(!_.isInstanceOf[AttributeReference])) { - val (keys, exec) = buildJoinColumnsProject(hashed, hashedKeys) - modifiedHashedKeys = keys - hashed = exec - needPostProject = true - } - if (probedKeys.exists(!_.isInstanceOf[AttributeReference])) { - val (keys, exec) = buildJoinColumnsProject(nativeProbed, probedKeys) - modifiedProbedKeys = keys - nativeProbed = exec - needPostProject = true - } - - val modifiedJoinType = buildSide match { - case BuildLeft => joinType - case BuildRight => - needPostProject = true - val modifiedJoinType = joinType match { // reverse join type - case Inner => Inner - case FullOuter => FullOuter - case LeftOuter => RightOuter - case RightOuter => LeftOuter - case _ => - throw new NotImplementedError( - "BHJ Semi/Anti join with BuildRight is not yet supported") - } - modifiedJoinType - } - - val bhjOrig = BroadcastHashJoinExec( - modifiedHashedKeys, - modifiedProbedKeys, - modifiedJoinType, - BuildLeft, - condition, - addRenameColumnsExec(hashed), - addRenameColumnsExec(nativeProbed)) - - val bhj = Shims.get.createNativeBroadcastJoinExec( - bhjOrig.left, - bhjOrig.right, - bhjOrig.outputPartitioning, - bhjOrig.leftKeys, - bhjOrig.rightKeys, - bhjOrig.joinType, - bhjOrig.condition) - - if (needPostProject) { - buildPostJoinProject(bhj, exec.output) - } else { - bhj - } } catch { case e @ (_: NotImplementedError | _: Exception) => val underlyingBroadcast = exec.buildSide match {
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala index 1cbfcc8..355444a 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala
@@ -52,12 +52,11 @@ import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.LeafExpression +import org.apache.spark.sql.catalyst.plans.ExistenceJoin import org.apache.spark.sql.execution.blaze.plan.Util import org.apache.spark.sql.execution.ScalarSubquery -import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.hive.blaze.HiveUDFUtil import org.apache.spark.sql.hive.blaze.HiveUDFUtil.getFunctionClassName -import org.apache.spark.sql.hive.blaze.HiveUDFUtil.isHiveSimpleUDF import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.ArrayType import org.apache.spark.sql.types.AtomicType @@ -1110,6 +1109,7 @@ case FullOuter => pb.JoinType.FULL case LeftSemi => pb.JoinType.SEMI case LeftAnti => pb.JoinType.ANTI + case _: ExistenceJoin => pb.JoinType.EXISTENCE case _ => throw new NotImplementedError(s"unsupported join type: ${joinType}") } }
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala index a8aaad2..fe2bd2c 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala
@@ -79,7 +79,7 @@ leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, - condition: Option[Expression]): NativeBroadcastJoinBase + broadcastSide: BroadcastSide): NativeBroadcastJoinBase def createNativeBroadcastNestedLoopJoinExec( left: SparkPlan,
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/util/Using.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/util/Using.scala index b78eb08..b103969 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/util/Using.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/util/Using.scala
@@ -19,15 +19,14 @@ import scala.util.Try /** - * A utility for performing automatic resource management. It can be used to perform an - * operation using resources, after which it releases the resources in reverse order - * of their creation. + * A utility for performing automatic resource management. It can be used to perform an operation + * using resources, after which it releases the resources in reverse order of their creation. * * ==Usage== * - * There are multiple ways to automatically manage resources with `Using`. If you only need - * to manage a single resource, the [[Using.apply `apply`]] method is easiest; it wraps the - * resource opening, operation, and resource releasing in a `Try`. + * There are multiple ways to automatically manage resources with `Using`. If you only need to + * manage a single resource, the [[Using.apply `apply`]] method is easiest; it wraps the resource + * opening, operation, and resource releasing in a `Try`. * * Example: * {{{ @@ -37,9 +36,9 @@ * } * }}} * - * If you need to manage multiple resources, [[Using.Manager$.apply `Using.Manager`]] should - * be used. It allows the managing of arbitrarily many resources, whose creation, use, and - * release are all wrapped in a `Try`. + * If you need to manage multiple resources, [[Using.Manager$.apply `Using.Manager`]] should be + * used. It allows the managing of arbitrarily many resources, whose creation, use, and release + * are all wrapped in a `Try`. * * Example: * {{{ @@ -70,43 +69,44 @@ * * ==Suppression Behavior== * - * If two exceptions are thrown (e.g., by an operation and closing a resource), - * one of them is re-thrown, and the other is - * [[java.lang.Throwable.addSuppressed(Throwable) added to it as a suppressed exception]]. - * If the two exceptions are of different 'severities' (see below), the one of a higher - * severity is re-thrown, and the one of a lower severity is added to it as a suppressed - * exception. 
If the two exceptions are of the same severity, the one thrown first is - * re-thrown, and the one thrown second is added to it as a suppressed exception. - * If an exception is a [[scala.util.control.ControlThrowable `ControlThrowable`]], or - * if it does not support suppression (see - * [[java.lang.Throwable `Throwable`'s constructor with an `enableSuppression` parameter]]), - * an exception that would have been suppressed is instead discarded. + * If two exceptions are thrown (e.g., by an operation and closing a resource), one of them is + * re-thrown, and the other is + * [[java.lang.Throwable.addSuppressed(Throwable) added to it as a suppressed exception]]. If the + * two exceptions are of different 'severities' (see below), the one of a higher severity is + * re-thrown, and the one of a lower severity is added to it as a suppressed exception. If the two + * exceptions are of the same severity, the one thrown first is re-thrown, and the one thrown + * second is added to it as a suppressed exception. If an exception is a + * [[scala.util.control.ControlThrowable `ControlThrowable`]], or if it does not support + * suppression (see + * [[java.lang.Throwable `Throwable`'s constructor with an `enableSuppression` parameter]]), an + * exception that would have been suppressed is instead discarded. * * Exceptions are ranked from highest to lowest severity as follows: * - `java.lang.VirtualMachineError` * - `java.lang.LinkageError` * - `java.lang.InterruptedException` and `java.lang.ThreadDeath` - * - [[scala.util.control.NonFatal fatal exceptions]], excluding `scala.util.control.ControlThrowable` + * - [[scala.util.control.NonFatal fatal exceptions]], excluding + * `scala.util.control.ControlThrowable` * - `scala.util.control.ControlThrowable` * - all other exceptions * - * When more than two exceptions are thrown, the first two are combined and - * re-thrown as described above, and each successive exception thrown is combined - * as it is thrown. 
+ * When more than two exceptions are thrown, the first two are combined and re-thrown as described + * above, and each successive exception thrown is combined as it is thrown. * - * @define suppressionBehavior See the main doc for [[Using `Using`]] for full details of - * suppression behavior. + * @define suppressionBehavior + * See the main doc for [[Using `Using`]] for full details of suppression behavior. */ object Using { /** - * Performs an operation using a resource, and then releases the resource, - * even if the operation throws an exception. + * Performs an operation using a resource, and then releases the resource, even if the operation + * throws an exception. * * $suppressionBehavior * - * @return a [[Try]] containing an exception if one or more were thrown, - * or the result of the operation if no exceptions were thrown + * @return + * a [[Try]] containing an exception if one or more were thrown, or the result of the + * operation if no exceptions were thrown */ def apply[R: Releasable, A](resource: => R)(f: R => A): Try[A] = Try { Using.resource(resource)(f) @@ -115,20 +115,20 @@ /** * A resource manager. * - * Resources can be registered with the manager by calling [[acquire `acquire`]]; - * such resources will be released in reverse order of their acquisition - * when the manager is closed, regardless of any exceptions thrown - * during use. + * Resources can be registered with the manager by calling [[acquire `acquire`]]; such resources + * will be released in reverse order of their acquisition when the manager is closed, regardless + * of any exceptions thrown during use. * * $suppressionBehavior * - * @note It is recommended for API designers to require an implicit `Manager` - * for the creation of custom resources, and to call `acquire` during those - * resources' construction. Doing so guarantees that the resource ''must'' be - * automatically managed, and makes it impossible to forget to do so. 
+ * @note + * It is recommended for API designers to require an implicit `Manager` for the creation of + * custom resources, and to call `acquire` during those resources' construction. Doing so + * guarantees that the resource ''must'' be automatically managed, and makes it impossible to + * forget to do so. * - * Example: - * {{{ + * Example: + * {{{ * class SafeFileReader(file: File)(implicit manager: Using.Manager) * extends BufferedReader(new FileReader(file)) { * @@ -136,7 +136,7 @@ * * manager.acquire(this) * } - * }}} + * }}} */ final class Manager private { import Manager._ @@ -145,9 +145,8 @@ private[this] var resources: List[Resource[_]] = Nil /** - * Registers the specified resource with this manager, so that - * the resource is released when the manager is closed, and then - * returns the (unmodified) resource. + * Registers the specified resource with this manager, so that the resource is released when + * the manager is closed, and then returns the (unmodified) resource. */ def apply[R: Releasable](resource: R): R = { acquire(resource) @@ -155,8 +154,8 @@ } /** - * Registers the specified resource with this manager, so that - * the resource is released when the manager is closed. + * Registers the specified resource with this manager, so that the resource is released when + * the manager is closed. */ def acquire[R: Releasable](resource: R): Unit = { if (resource == null) throw new NullPointerException("null resource") @@ -194,8 +193,8 @@ object Manager { /** - * Performs an operation using a `Manager`, then closes the `Manager`, - * releasing its resources (in reverse order of acquisition). + * Performs an operation using a `Manager`, then closes the `Manager`, releasing its resources + * (in reverse order of acquisition). 
* * Example: * {{{ @@ -204,9 +203,8 @@ * } * }}} * - * If using resources which require an implicit `Manager` as a parameter, - * this method should be invoked with an `implicit` modifier before the function - * parameter: + * If using resources which require an implicit `Manager` as a parameter, this method should + * be invoked with an `implicit` modifier before the function parameter: * * Example: * {{{ @@ -217,10 +215,13 @@ * * See the main doc for [[Using `Using`]] for full details of suppression behavior. * - * @param op the operation to perform using the manager - * @tparam A the return type of the operation - * @return a [[Try]] containing an exception if one or more were thrown, - * or the result of the operation if no exceptions were thrown + * @param op + * the operation to perform using the manager + * @tparam A + * the return type of the operation + * @return + * a [[Try]] containing an exception if one or more were thrown, or the result of the + * operation if no exceptions were thrown */ def apply[A](op: Manager => A): Try[A] = Try { (new Manager).manage(op) } @@ -247,18 +248,21 @@ } /** - * Performs an operation using a resource, and then releases the resource, - * even if the operation throws an exception. This method behaves similarly - * to Java's try-with-resources. + * Performs an operation using a resource, and then releases the resource, even if the operation + * throws an exception. This method behaves similarly to Java's try-with-resources. 
* * $suppressionBehavior * - * @param resource the resource - * @param body the operation to perform with the resource - * @tparam R the type of the resource - * @tparam A the return type of the operation - * @return the result of the operation, if neither the operation nor - * releasing the resource throws + * @param resource + * the resource + * @param body + * the operation to perform with the resource + * @tparam R + * the type of the resource + * @tparam A + * the return type of the operation + * @return + * the result of the operation, if neither the operation nor releasing the resource throws */ def resource[R, A](resource: R)(body: R => A)(implicit releasable: Releasable[R]): A = { if (resource == null) throw new NullPointerException("null resource") @@ -281,20 +285,26 @@ } /** - * Performs an operation using two resources, and then releases the resources - * in reverse order, even if the operation throws an exception. This method - * behaves similarly to Java's try-with-resources. + * Performs an operation using two resources, and then releases the resources in reverse order, + * even if the operation throws an exception. This method behaves similarly to Java's + * try-with-resources. 
* * $suppressionBehavior * - * @param resource1 the first resource - * @param resource2 the second resource - * @param body the operation to perform using the resources - * @tparam R1 the type of the first resource - * @tparam R2 the type of the second resource - * @tparam A the return type of the operation - * @return the result of the operation, if neither the operation nor - * releasing the resources throws + * @param resource1 + * the first resource + * @param resource2 + * the second resource + * @param body + * the operation to perform using the resources + * @tparam R1 + * the type of the first resource + * @tparam R2 + * the type of the second resource + * @tparam A + * the return type of the operation + * @return + * the result of the operation, if neither the operation nor releasing the resources throws */ def resources[R1: Releasable, R2: Releasable, A](resource1: R1, resource2: => R2)( body: (R1, R2) => A): A = @@ -305,22 +315,30 @@ } /** - * Performs an operation using three resources, and then releases the resources - * in reverse order, even if the operation throws an exception. This method - * behaves similarly to Java's try-with-resources. + * Performs an operation using three resources, and then releases the resources in reverse + * order, even if the operation throws an exception. This method behaves similarly to Java's + * try-with-resources. 
* * $suppressionBehavior * - * @param resource1 the first resource - * @param resource2 the second resource - * @param resource3 the third resource - * @param body the operation to perform using the resources - * @tparam R1 the type of the first resource - * @tparam R2 the type of the second resource - * @tparam R3 the type of the third resource - * @tparam A the return type of the operation - * @return the result of the operation, if neither the operation nor - * releasing the resources throws + * @param resource1 + * the first resource + * @param resource2 + * the second resource + * @param resource3 + * the third resource + * @param body + * the operation to perform using the resources + * @tparam R1 + * the type of the first resource + * @tparam R2 + * the type of the second resource + * @tparam R3 + * the type of the third resource + * @tparam A + * the return type of the operation + * @return + * the result of the operation, if neither the operation nor releasing the resources throws */ def resources[R1: Releasable, R2: Releasable, R3: Releasable, A]( resource1: R1, @@ -335,24 +353,34 @@ } /** - * Performs an operation using four resources, and then releases the resources - * in reverse order, even if the operation throws an exception. This method - * behaves similarly to Java's try-with-resources. + * Performs an operation using four resources, and then releases the resources in reverse order, + * even if the operation throws an exception. This method behaves similarly to Java's + * try-with-resources. 
* * $suppressionBehavior * - * @param resource1 the first resource - * @param resource2 the second resource - * @param resource3 the third resource - * @param resource4 the fourth resource - * @param body the operation to perform using the resources - * @tparam R1 the type of the first resource - * @tparam R2 the type of the second resource - * @tparam R3 the type of the third resource - * @tparam R4 the type of the fourth resource - * @tparam A the return type of the operation - * @return the result of the operation, if neither the operation nor - * releasing the resources throws + * @param resource1 + * the first resource + * @param resource2 + * the second resource + * @param resource3 + * the third resource + * @param resource4 + * the fourth resource + * @param body + * the operation to perform using the resources + * @tparam R1 + * the type of the first resource + * @tparam R2 + * the type of the second resource + * @tparam R3 + * the type of the third resource + * @tparam R4 + * the type of the fourth resource + * @tparam A + * the return type of the operation + * @return + * the result of the operation, if neither the operation nor releasing the resources throws */ def resources[R1: Releasable, R2: Releasable, R3: Releasable, R4: Releasable, A]( resource1: R1, @@ -372,17 +400,18 @@ /** * A typeclass describing how to release a particular type of resource. * - * A resource is anything which needs to be released, closed, or otherwise cleaned up - * in some way after it is finished being used, and for which waiting for the object's - * garbage collection to be cleaned up would be unacceptable. For example, an instance of - * [[java.io.OutputStream]] would be considered a resource, because it is important to close - * the stream after it is finished being used. 
+ * A resource is anything which needs to be released, closed, or otherwise cleaned up in some + * way after it is finished being used, and for which waiting for the object's garbage + * collection to be cleaned up would be unacceptable. For example, an instance of + * [[java.io.OutputStream]] would be considered a resource, because it is important to close the + * stream after it is finished being used. * - * An instance of `Releasable` is needed in order to automatically manage a resource - * with [[Using `Using`]]. An implicit instance is provided for all types extending + * An instance of `Releasable` is needed in order to automatically manage a resource with + * [[Using `Using`]]. An implicit instance is provided for all types extending * [[java.lang.AutoCloseable]]. * - * @tparam R the type of the resource + * @tparam R + * the type of the resource */ trait Releasable[-R] {
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeBase.scala index 0522623..852e533 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeBase.scala
@@ -34,7 +34,6 @@ import org.apache.spark.sql.execution.blaze.arrowio.ArrowFFIExportIterator import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.OneToOneDependency -import org.apache.spark.sql.blaze.BlazeConf import org.blaze.protobuf.FFIReaderExecNode import org.blaze.protobuf.PhysicalPlanNode import org.blaze.protobuf.Schema
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeBase.scala index 5525fb4..2947ca9 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeBase.scala
@@ -24,22 +24,19 @@ import java.util.concurrent.TimeoutException import java.util.concurrent.TimeUnit -import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConverters._ import scala.collection.immutable.SortedMap import scala.concurrent.Promise +import org.apache.commons.lang3.reflect.MethodUtils import org.apache.spark.OneToOneDependency import org.apache.spark.Partition import org.apache.spark.SparkException import org.apache.spark.TaskContext import org.apache.spark.broadcast -import org.blaze.{protobuf => pb} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.blaze.BlazeCallNativeWrapper -import org.apache.spark.sql.blaze.BlazeConf import org.apache.spark.sql.blaze.JniBridge import org.apache.spark.sql.blaze.MetricNode import org.apache.spark.sql.blaze.NativeConverters @@ -49,7 +46,10 @@ import org.apache.spark.sql.blaze.Shims import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.BoundReference import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.InterpretedUnsafeProjection import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode import org.apache.spark.sql.catalyst.plans.physical.BroadcastPartitioning import org.apache.spark.sql.catalyst.plans.physical.IdentityBroadcastMode @@ -63,6 +63,8 @@ import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.types.BinaryType +import org.blaze.{protobuf => pb} abstract class NativeBroadcastExchangeBase(mode: BroadcastMode, override val child: SparkPlan) extends BroadcastExchangeLike @@ -71,10 +73,15 @@ override def output: 
Seq[Attribute] = child.output override def outputPartitioning: Partitioning = BroadcastPartitioning(mode) + def broadcastMode: BroadcastMode = this.mode + + protected val hashMapOutput: Seq[Attribute] = output + .map(_.withNullability(true)) :+ AttributeReference("~TABLE", BinaryType, nullable = true)() + protected val nativeSchema: pb.Schema = Util.getNativeSchema(output) + protected val nativeHashMapSchema: pb.Schema = Util.getNativeSchema(hashMapOutput) def getRunId: UUID - override lazy val metrics: Map[String, SQLMetric] = SortedMap[String, SQLMetric]() ++ Map( NativeHelper .getDefaultNativeMetrics(sparkContext) @@ -93,9 +100,6 @@ override def doPrepare(): Unit = { // Materialize the future. relationFuture - relationFuture - relationFuture - relationFuture } override def doExecuteBroadcast[T](): Broadcast[T] = { @@ -103,17 +107,31 @@ override def index: Int = 0 } val broadcastReadNativePlan = doExecuteNative().nativePlan(singlePartition, null) - val rows = NativeHelper.executeNativePlan( + val rowsIter = NativeHelper.executeNativePlan( broadcastReadNativePlan, MetricNode(Map(), Nil, None), singlePartition, None) - val v = mode.transform(rows.toArray) + val pruneKeyField = new InterpretedUnsafeProjection(output + .zipWithIndex + .map(v => BoundReference(v._2, v._1.dataType, v._1.nullable)) + .toArray) + val dataRows = rowsIter + .map(pruneKeyField) + .map(_.copy()) + .toArray + + val broadcast = relationFuture.get // broadcast must be resolved + val v = mode.transform(dataRows) val dummyBroadcasted = new Broadcast[Any](-1) { override protected def getValue(): Any = v - override protected def doUnpersist(blocking: Boolean): Unit = {} - override protected def doDestroy(blocking: Boolean): Unit = {} + override protected def doUnpersist(blocking: Boolean): Unit = { + MethodUtils.invokeMethod(broadcast, true, "doUnpersist", Array(blocking)) + } + override protected def doDestroy(blocking: Boolean): Unit = { + MethodUtils.invokeMethod(broadcast, true, "doDestroy", 
Array(blocking)) + } } dummyBroadcasted.asInstanceOf[Broadcast[T]] } @@ -154,13 +172,14 @@ Channels.newChannel(new ByteArrayInputStream(bytes)) }) } + JniBridge.resourcesMap.put(resourceId, () => provideIpcIterator()) pb.PhysicalPlanNode .newBuilder() .setIpcReader( pb.IpcReaderExecNode .newBuilder() - .setSchema(nativeSchema) + .setSchema(nativeHashMapSchema) .setNumPartitions(1) .setIpcProviderResourceId(resourceId) .build()) @@ -267,39 +286,21 @@ keys: Seq[Expression], nativeSchema: pb.Schema): Array[Array[Byte]] = { - if (!BlazeConf.BHJ_FALLBACKS_TO_SMJ_ENABLE.booleanConf() || keys.isEmpty) { - return collectedData // no need to sort data in driver side - } - val readerIpcProviderResourceId = s"BuildBroadcastDataReader:${UUID.randomUUID()}" val readerExec = pb.IpcReaderExecNode .newBuilder() .setSchema(nativeSchema) .setIpcProviderResourceId(readerIpcProviderResourceId) - val sortExec = pb.SortExecNode + val buildHashMapExec = pb.BroadcastJoinBuildHashMapExecNode .newBuilder() .setInput(pb.PhysicalPlanNode.newBuilder().setIpcReader(readerExec)) - .addAllExpr( - keys - .map(key => { - pb.PhysicalExprNode - .newBuilder() - .setSort( - pb.PhysicalSortExprNode - .newBuilder() - .setExpr(NativeConverters.convertExpr(key)) - .setAsc(true) - .setNullsFirst(true) - .build()) - .build() - }) - .asJava) + .addAllKeys(keys.map(key => NativeConverters.convertExpr(key)).asJava) val writerIpcProviderResourceId = s"BuildBroadcastDataWriter:${UUID.randomUUID()}" val writerExec = pb.IpcWriterExecNode .newBuilder() - .setInput(pb.PhysicalPlanNode.newBuilder().setSort(sortExec)) + .setInput(pb.PhysicalPlanNode.newBuilder().setBroadcastJoinBuildHashMap(buildHashMapExec)) .setIpcConsumerResourceId(writerIpcProviderResourceId) // build native sorter
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala index ec13b8f..d8f6e62 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala
@@ -20,21 +20,24 @@ import org.apache.spark.OneToOneDependency import org.apache.spark.Partition -import org.apache.spark.sql.blaze.BlazeConf import org.apache.spark.sql.blaze.MetricNode import org.apache.spark.sql.blaze.NativeConverters import org.apache.spark.sql.blaze.NativeHelper import org.apache.spark.sql.blaze.NativeRDD import org.apache.spark.sql.blaze.NativeSupports +import org.apache.spark.sql.blaze.Shims +import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.LeftAnti -import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.BinaryExecNode +import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec +import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode +import org.apache.spark.sql.types.LongType import org.blaze.{protobuf => pb} +import org.blaze.protobuf.JoinOn abstract class NativeBroadcastJoinBase( override val left: SparkPlan, @@ -43,82 +46,112 @@ leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, - condition: Option[Expression]) + broadcastSide: BroadcastSide) extends BinaryExecNode with NativeSupports { - assert( - (joinType != LeftSemi && joinType != LeftAnti) || condition.isEmpty, - "Semi/Anti join with filter is not supported yet") - - assert( - !BlazeConf.BHJ_FALLBACKS_TO_SMJ_ENABLE.booleanConf() || BlazeConf.SMJ_INEQUALITY_JOIN_ENABLE - .booleanConf() || condition.isEmpty, - "Join filter is not supported when BhjFallbacksToSmj and SmjInequalityJoin both enabled") - override lazy val metrics: Map[String, SQLMetric] = SortedMap[String, SQLMetric]() ++ Map( NativeHelper .getDefaultNativeMetrics(sparkContext) .toSeq: _*) + private val 
isLongHashRelation = { + val baseBroadcast = broadcastSide match { + case BroadcastLeft => Shims.get.getUnderlyingBroadcast(left) + case BroadcastRight => Shims.get.getUnderlyingBroadcast(right) + } + val mode = baseBroadcast match { + case b: BroadcastExchangeExec => b.mode + case b: NativeBroadcastExchangeBase => b.broadcastMode + } + mode match { + case HashedRelationBroadcastMode(Seq(key), _) if key.dataType == LongType => true + case _ => false + } + } + + private def nativeSchema = Util.getNativeSchema(output) + private def nativeJoinOn = leftKeys.zip(rightKeys).map { case (leftKey, rightKey) => - val leftColumn = NativeConverters.convertExpr(leftKey).getColumn match { - case column if column.getName.isEmpty => - throw new NotImplementedError(s"BHJ leftKey is not column: ${leftKey}") - case column => column + val leftKeyExpr = leftKey match { + case k if !isLongHashRelation || k.dataType == LongType => k + case k => Cast(k, LongType) } - val rightColumn = NativeConverters.convertExpr(rightKey).getColumn match { - case column if column.getName.isEmpty => - throw new NotImplementedError(s"BHJ rightKey is not column: ${rightKey}") - case column => column + val rightKeyExpr = rightKey match { + case k if !isLongHashRelation || k.dataType == LongType => k + case k => Cast(k, LongType) } - pb.JoinOn + JoinOn .newBuilder() - .setLeft(leftColumn) - .setRight(rightColumn) + .setLeft(NativeConverters.convertExpr(leftKeyExpr)) + .setRight(NativeConverters.convertExpr(rightKeyExpr)) .build() } private def nativeJoinType = NativeConverters.convertJoinType(joinType) - private def nativeJoinFilter = - condition.map(NativeConverters.convertJoinFilter(_, left.output, right.output)) + private def nativeBroadcastSide = broadcastSide match { + case BroadcastLeft => pb.JoinSide.LEFT_SIDE + case BroadcastRight => pb.JoinSide.RIGHT_SIDE + } // check whether native converting is supported + nativeSchema nativeJoinType - nativeJoinFilter + nativeJoinOn + nativeBroadcastSide override 
def doExecuteNative(): NativeRDD = { val leftRDD = NativeHelper.executeNative(left) val rightRDD = NativeHelper.executeNative(right) val nativeMetrics = MetricNode(metrics, leftRDD.metrics :: rightRDD.metrics :: Nil) + val nativeSchema = this.nativeSchema val nativeJoinType = this.nativeJoinType val nativeJoinOn = this.nativeJoinOn - val nativeJoinFilter = this.nativeJoinFilter - val partitions = rightRDD.partitions + + val (probedRDD, builtRDD) = broadcastSide match { + case BroadcastLeft => (rightRDD, leftRDD) + case BroadcastRight => (leftRDD, rightRDD) + } new NativeRDD( sparkContext, nativeMetrics, - partitions, - rddDependencies = new OneToOneDependency(rightRDD) :: Nil, - rightRDD.isShuffleReadFull, + probedRDD.partitions, + rddDependencies = new OneToOneDependency(probedRDD) :: Nil, + probedRDD.isShuffleReadFull, (partition, context) => { val partition0 = new Partition() { override def index: Int = 0 } - val leftChild = leftRDD.nativePlan(partition0, context) - val rightChild = rightRDD.nativePlan(rightRDD.partitions(partition.index), context) + val (leftChild, rightChild) = broadcastSide match { + case BroadcastLeft => ( + leftRDD.nativePlan(partition0, context), + rightRDD.nativePlan(rightRDD.partitions(partition.index), context), + ) + case BroadcastRight => ( + leftRDD.nativePlan(leftRDD.partitions(partition.index), context), + rightRDD.nativePlan(partition0, context), + ) + } + val cachedBuildHashMapId = s"bhm_stage${context.stageId}_rdd${builtRDD.id}" + val broadcastJoinExec = pb.BroadcastJoinExecNode .newBuilder() + .setSchema(nativeSchema) .setLeft(leftChild) .setRight(rightChild) .setJoinType(nativeJoinType) + .setBroadcastSide(nativeBroadcastSide) + .setCachedBuildHashMapId(cachedBuildHashMapId) .addAllOn(nativeJoinOn.asJava) - nativeJoinFilter.foreach(joinFilter => broadcastJoinExec.setJoinFilter(joinFilter)) pb.PhysicalPlanNode.newBuilder().setBroadcastJoin(broadcastJoinExec).build() }, friendlyName = "NativeRDD.BroadcastJoin") } } + +class 
BroadcastSide {} +case object BroadcastLeft extends BroadcastSide {} +case object BroadcastRight extends BroadcastSide {}
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGenerateBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGenerateBase.scala index 2349cc9..dc0e371 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGenerateBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGenerateBase.scala
@@ -22,7 +22,6 @@ import org.apache.spark.OneToOneDependency import org.apache.spark.sql.blaze.MetricNode import org.apache.spark.sql.blaze.NativeConverters -import org.apache.spark.sql.blaze.NativeConverters.convertExprWithFallback import org.apache.spark.sql.blaze.NativeHelper import org.apache.spark.sql.blaze.NativeRDD import org.apache.spark.sql.blaze.NativeSupports
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortMergeJoinBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortMergeJoinBase.scala index 52efbcd..831211b 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortMergeJoinBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortMergeJoinBase.scala
@@ -22,7 +22,6 @@ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.OneToOneDependency -import org.apache.spark.sql.blaze.BlazeConf import org.apache.spark.sql.blaze.MetricNode import org.apache.spark.sql.blaze.NativeConverters import org.apache.spark.sql.blaze.NativeHelper @@ -52,13 +51,7 @@ extends BinaryExecNode with NativeSupports { - assert( - (joinType != LeftSemi && joinType != LeftAnti) || condition.isEmpty, - "Semi/Anti join with filter is not supported yet") - - assert( - BlazeConf.SMJ_INEQUALITY_JOIN_ENABLE.booleanConf() || condition.isEmpty, - "inequality sort-merge join is not enabled") + assert(condition.isEmpty, "inequality join is not supported") override lazy val metrics: Map[String, SQLMetric] = SortedMap[String, SQLMetric]() ++ Map( NativeHelper @@ -81,21 +74,15 @@ keys.map(SortOrder(_, Ascending)) } + private def nativeSchema = Util.getNativeSchema(output) + private def nativeJoinOn = leftKeys.zip(rightKeys).map { case (leftKey, rightKey) => - val leftColumn = NativeConverters.convertExpr(leftKey).getColumn match { - case column if column.getName.isEmpty => - throw new NotImplementedError(s"SMJ leftKey is not column: ${leftKey}") - case column => column - } - val rightColumn = NativeConverters.convertExpr(rightKey).getColumn match { - case column if column.getName.isEmpty => - throw new NotImplementedError(s"SMJ rightKey is not column: ${rightKey}") - case column => column - } + val leftKeyExpr = NativeConverters.convertExpr(leftKey) + val rightKeyExpr = NativeConverters.convertExpr(rightKey) JoinOn .newBuilder() - .setLeft(leftColumn) - .setRight(rightColumn) + .setLeft(leftKeyExpr) + .setRight(rightKeyExpr) .build() } @@ -109,14 +96,11 @@ private def nativeJoinType = NativeConverters.convertJoinType(joinType) - private def nativeJoinFilter = - condition.map(NativeConverters.convertJoinFilter(_, left.output, right.output)) - // check whether native converting is 
supported + nativeSchema nativeSortOptions nativeJoinOn nativeJoinType - nativeJoinFilter override def doExecuteNative(): NativeRDD = { val leftRDD = NativeHelper.executeNative(left) @@ -125,7 +109,6 @@ val nativeSortOptions = this.nativeSortOptions val nativeJoinOn = this.nativeJoinOn val nativeJoinType = this.nativeJoinType - val nativeJoinFilter = this.nativeJoinFilter val partitions = if (joinType != RightOuter) { leftRDD.partitions @@ -161,13 +144,12 @@ val sortMergeJoinExec = SortMergeJoinExecNode .newBuilder() + .setSchema(nativeSchema) .setLeft(leftChild) .setRight(rightChild) .setJoinType(nativeJoinType) .addAllOn(nativeJoinOn.asJava) .addAllSortOptions(nativeSortOptions.asJava) - - nativeJoinFilter.foreach(joinFilter => sortMergeJoinExec.setJoinFilter(joinFilter)) PhysicalPlanNode.newBuilder().setSortMergeJoin(sortMergeJoinExec).build() }, friendlyName = "NativeRDD.SortMergeJoin")