blob: 7bf28cdb54faabb1e13260b620e4a4d1ec871f2e [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::{fmt::Formatter, sync::Arc};
use arrow_schema::SchemaRef;
use datafusion_common::{project_schema, DataFusionError, JoinSide, Result};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::{JoinType, Operator};
use datafusion_physical_expr::{
equivalence::{join_equivalence_properties, ProjectionMapping},
expressions::{BinaryExpr, Column},
PhysicalExpr,
};
use datafusion_physical_plan::{
execution_plan::EmissionType,
joins::utils::{build_join_schema, check_join_is_valid, ColumnIndex, JoinFilter},
metrics::{ExecutionPlanMetricsSet, MetricsSet},
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning,
PlanProperties,
};
use parking_lot::Mutex;
use crate::{
build_index::build_index,
index::SpatialIndex,
spatial_predicate::{KNNPredicate, SpatialPredicate},
stream::{SpatialJoinProbeMetrics, SpatialJoinStream},
utils::join_utils::{asymmetric_join_output_partitioning, boundedness_from_children},
utils::once_fut::OnceAsync,
SedonaOptions,
};
/// Type alias for a `(build, probe)` pair of execution plans, as returned by
/// [`determine_knn_build_probe_plans`]: the first element is indexed, the
/// second drives the probe phase.
type BuildProbePlans<'a> = (&'a Arc<dyn ExecutionPlan>, &'a Arc<dyn ExecutionPlan>);
/// Extract equality join conditions from a [`JoinFilter`].
///
/// Only a single top-level `column = column` expression is recognized; any
/// other filter shape (non-binary expression, a different operator, or
/// non-column operands) yields an empty vector. The matched pair is returned
/// as `PhysicalExpr`s in `(left, right)` order.
fn extract_equality_conditions(
    filter: &JoinFilter,
) -> Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> {
    let Some(binary_expr) = filter.expression().as_any().downcast_ref::<BinaryExpr>() else {
        return Vec::new();
    };
    if binary_expr.op() != &Operator::Eq {
        return Vec::new();
    }
    // Both operands must be plain column references for this to count as an
    // equi-join condition.
    let left_is_column = binary_expr.left().as_any().downcast_ref::<Column>().is_some();
    let right_is_column = binary_expr.right().as_any().downcast_ref::<Column>().is_some();
    if left_is_column && right_is_column {
        vec![(binary_expr.left().clone(), binary_expr.right().clone())]
    } else {
        Vec::new()
    }
}
/// Determine the correct build/probe execution plan assignment for KNN joins.
///
/// The optimizer records which side of the join issues the KNN queries in
/// `KNNPredicate::probe_side`; the opposite side is the one that gets indexed.
/// This helper maps that designation onto the two child execution plans and
/// returns them as a `(build, probe)` pair.
///
/// Returns an internal error if the optimizer left `probe_side` unset
/// (`JoinSide::None`), since a KNN join cannot be executed without an explicit
/// probe-side designation. The join schema parameter is currently unused but
/// kept for interface stability.
fn determine_knn_build_probe_plans<'a>(
    knn_pred: &KNNPredicate,
    left_plan: &'a Arc<dyn ExecutionPlan>,
    right_plan: &'a Arc<dyn ExecutionPlan>,
    _join_schema: &SchemaRef,
) -> Result<BuildProbePlans<'a>> {
    let build_probe = match knn_pred.probe_side {
        // Left side probes, so the right side is indexed (build).
        JoinSide::Left => (right_plan, left_plan),
        // Right side probes, so the left side is indexed (build).
        JoinSide::Right => (left_plan, right_plan),
        JoinSide::None => {
            return Err(DataFusionError::Internal(
                "KNN join requires explicit probe_side designation".to_string(),
            ));
        }
    };
    Ok(build_probe)
}
/// Physical execution plan for performing spatial joins between two tables. It uses a spatial
/// index to speed up the join operation.
///
/// ## Algorithm Overview
///
/// The spatial join execution follows a hash-join-like pattern:
/// 1. **Build Phase**: The left (smaller) table geometries are indexed using a spatial index
/// 2. **Probe Phase**: Each geometry from the right table is used to query the spatial index
/// 3. **Refinement**: Candidate pairs from the index are refined using exact spatial predicates
/// 4. **Output**: Matching pairs are combined according to the specified join type
#[derive(Debug)]
pub struct SpatialJoinExec {
    /// left (build) side which gets hashed
    pub left: Arc<dyn ExecutionPlan>,
    /// right (probe) side which are filtered by the hash table
    pub right: Arc<dyn ExecutionPlan>,
    /// Primary spatial join condition (the expression in the ON clause of the join)
    pub on: SpatialPredicate,
    /// Additional filters which are applied while finding matching rows. It could contain part of
    /// the ON clause, or expressions in the WHERE clause.
    pub filter: Option<JoinFilter>,
    /// How the join is performed (`OUTER`, `INNER`, etc)
    pub join_type: JoinType,
    /// The schema after join. Please be careful when using this schema,
    /// if there is a projection, the schema isn't the same as the output schema.
    join_schema: SchemaRef,
    /// Runtime metrics shared between the build and probe phases
    metrics: ExecutionPlanMetricsSet,
    /// The projection indices of the columns in the output schema of join
    projection: Option<Vec<usize>>,
    /// Information of index and left / right placement of columns
    column_indices: Vec<ColumnIndex>,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
    /// Once future for building the spatial index.
    /// This futures run only once before the spatial index probing phase. It can also be disposed
    /// by the last finished stream so that the spatial index does not have to live as long as
    /// `SpatialJoinExec`.
    once_async_spatial_index: Arc<Mutex<Option<OnceAsync<SpatialIndex>>>>,
    /// Indicates if this SpatialJoin was converted from a HashJoin
    /// When true, we preserve HashJoin's equivalence properties and partitioning
    converted_from_hash_join: bool,
}
impl SpatialJoinExec {
    /// Try to create a new [`SpatialJoinExec`]
    ///
    /// Convenience constructor equivalent to [`Self::try_new_with_options`]
    /// with `converted_from_hash_join = false`.
    pub fn try_new(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        on: SpatialPredicate,
        filter: Option<JoinFilter>,
        join_type: &JoinType,
        projection: Option<Vec<usize>>,
    ) -> Result<Self> {
        Self::try_new_with_options(left, right, on, filter, join_type, projection, false)
    }

    /// Create a new SpatialJoinExec with additional options
    ///
    /// Validates the child schemas, builds the joined schema and column-index
    /// mapping, and precomputes the cached [`PlanProperties`]. Errors from
    /// schema validation or property computation are propagated to the caller.
    pub fn try_new_with_options(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        on: SpatialPredicate,
        filter: Option<JoinFilter>,
        join_type: &JoinType,
        projection: Option<Vec<usize>>,
        converted_from_hash_join: bool,
    ) -> Result<Self> {
        let left_schema = left.schema();
        let right_schema = right.schema();
        // No equi-join columns are passed here; the spatial predicate (and the
        // optional filter) define the join condition.
        check_join_is_valid(&left_schema, &right_schema, &[])?;
        let (join_schema, column_indices) =
            build_join_schema(&left_schema, &right_schema, join_type);
        let join_schema = Arc::new(join_schema);
        let cache = Self::compute_properties(
            &left,
            &right,
            Arc::clone(&join_schema),
            *join_type,
            projection.as_ref(),
            filter.as_ref(),
            converted_from_hash_join,
        )?;
        Ok(SpatialJoinExec {
            left,
            right,
            on,
            filter,
            join_type: *join_type,
            join_schema,
            column_indices,
            projection,
            metrics: Default::default(),
            cache,
            // The index-build future is created lazily on first execute()
            once_async_spatial_index: Arc::new(Mutex::new(None)),
            converted_from_hash_join,
        })
    }

    /// How the join is performed
    pub fn join_type(&self) -> &JoinType {
        &self.join_type
    }

    /// Returns a vector indicating whether the left and right inputs maintain their order.
    /// The first element corresponds to the left input, and the second to the right.
    ///
    /// The left (build-side) input's order may change, but the right (probe-side) input's
    /// order is maintained for INNER, RIGHT, RIGHT ANTI, and RIGHT SEMI joins.
    ///
    /// Maintaining the right input's order helps optimize the nodes down the pipeline
    /// (See [`ExecutionPlan::maintains_input_order`]).
    ///
    /// This is a separate method because it is also called when computing properties, before
    /// a [`NestedLoopJoinExec`] is created. It also takes [`JoinType`] as an argument, as
    /// opposed to `Self`, for the same reason.
    fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
        vec![
            // Build (left) side: never order-preserving
            false,
            // Probe (right) side: preserved only for join types that emit
            // probe rows in probe order
            matches!(
                join_type,
                JoinType::Inner | JoinType::Right | JoinType::RightAnti | JoinType::RightSemi
            ),
        ]
    }

    /// Does this join has a projection on the joined columns
    pub fn contains_projection(&self) -> bool {
        self.projection.is_some()
    }

    /// This function creates the cache object that stores the plan properties such as schema,
    /// equivalence properties, ordering, partitioning, etc.
    ///
    /// NOTICE: The implementation of this function should be identical to the one in
    /// [`datafusion_physical_plan::physical_plan::join::NestedLoopJoinExec::compute_properties`].
    /// This is because SpatialJoinExec is transformed from NestedLoopJoinExec in physical plan
    /// optimization phase. If the properties are not the same, the plan will be incorrect.
    ///
    /// When converted from HashJoin, we preserve HashJoin's equivalence properties by extracting
    /// equality conditions from the filter.
    fn compute_properties(
        left: &Arc<dyn ExecutionPlan>,
        right: &Arc<dyn ExecutionPlan>,
        schema: SchemaRef,
        join_type: JoinType,
        projection: Option<&Vec<usize>>,
        filter: Option<&JoinFilter>,
        converted_from_hash_join: bool,
    ) -> Result<PlanProperties> {
        // Extract equality conditions from filter if this was converted from HashJoin
        let on_columns = if converted_from_hash_join {
            filter.map_or(vec![], extract_equality_conditions)
        } else {
            vec![]
        };
        // NOTE: `join_equivalence_properties` returns a Result here (see the
        // `?` applied below); unwrapping is deferred so the projection branch
        // can project the properties before the final `?`.
        let mut eq_properties = join_equivalence_properties(
            left.equivalence_properties().clone(),
            right.equivalence_properties().clone(),
            &join_type,
            Arc::clone(&schema),
            &[false, false],
            None,
            // Pass extracted equality conditions to preserve equivalences
            &on_columns,
        );
        // Use symmetric partitioning (like HashJoin) when converted from HashJoin
        // Otherwise use asymmetric partitioning (like NestedLoopJoin)
        let mut output_partitioning = if converted_from_hash_join {
            // Replicate HashJoin's symmetric partitioning logic
            // HashJoin preserves partitioning from both sides for inner joins
            // and from one side for outer joins
            match join_type {
                JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => {
                    left.output_partitioning().clone()
                }
                JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => {
                    right.output_partitioning().clone()
                }
                JoinType::Full => {
                    // For full outer join, we can't preserve partitioning
                    Partitioning::UnknownPartitioning(left.output_partitioning().partition_count())
                }
                // Mark joins and any future variants fall back to the
                // NestedLoopJoin-style asymmetric partitioning
                _ => asymmetric_join_output_partitioning(left, right, &join_type),
            }
        } else {
            asymmetric_join_output_partitioning(left, right, &join_type)
        };
        if let Some(projection) = projection {
            // construct a map from the input expressions to the output expression of the Projection
            let projection_mapping = ProjectionMapping::from_indices(projection, &schema)?;
            let out_schema = project_schema(&schema, Some(projection))?;
            let eq_props = eq_properties?;
            output_partitioning = output_partitioning.project(&projection_mapping, &eq_props);
            // Re-wrap so the final `eq_properties?` below works for both branches
            eq_properties = Ok(eq_props.project(&projection_mapping, out_schema));
        }
        let emission_type = if left.boundedness().is_unbounded() {
            EmissionType::Final
        } else if right.pipeline_behavior() == EmissionType::Incremental {
            match join_type {
                // If we only need to generate matched rows from the probe side,
                // we can emit rows incrementally.
                JoinType::Inner
                | JoinType::LeftSemi
                | JoinType::RightSemi
                | JoinType::Right
                | JoinType::RightAnti => EmissionType::Incremental,
                // If we need to generate unmatched rows from the *build side*,
                // we need to emit them at the end.
                JoinType::Left
                | JoinType::LeftAnti
                | JoinType::LeftMark
                | JoinType::RightMark
                | JoinType::Full => EmissionType::Both,
            }
        } else {
            right.pipeline_behavior()
        };
        Ok(PlanProperties::new(
            eq_properties?,
            output_partitioning,
            emission_type,
            boundedness_from_children([left, right]),
        ))
    }
}
impl DisplayAs for SpatialJoinExec {
    /// Render this node for `EXPLAIN` output.
    ///
    /// `Default`/`Verbose` modes print the join type, spatial predicate,
    /// optional filter, and optional projection; `TreeRender` only prints the
    /// join type when it differs from the default `Inner`.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                let display_on = format!(", on={}", self.on);
                let display_filter = self.filter.as_ref().map_or_else(
                    || "".to_string(),
                    |f| format!(", filter={}", f.expression()),
                );
                // Projection entries are rendered as `name@index` against the
                // unprojected join schema
                let display_projections = if self.contains_projection() {
                    format!(
                        ", projection=[{}]",
                        self.projection
                            .as_ref()
                            .unwrap()
                            .iter()
                            .map(|index| format!(
                                "{}@{}",
                                self.join_schema.fields().get(*index).unwrap().name(),
                                index
                            ))
                            .collect::<Vec<_>>()
                            .join(", ")
                    )
                } else {
                    "".to_string()
                };
                write!(
                    f,
                    "SpatialJoinExec: join_type={:?}{}{}{}",
                    self.join_type, display_on, display_filter, display_projections
                )
            }
            DisplayFormatType::TreeRender => {
                if *self.join_type() != JoinType::Inner {
                    writeln!(f, "join_type={:?}", self.join_type)
                } else {
                    Ok(())
                }
            }
        }
    }
}
impl ExecutionPlan for SpatialJoinExec {
    fn name(&self) -> &str {
        "SpatialJoinExec"
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        Self::maintains_input_order(self.join_type)
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.left, &self.right]
    }

    /// Rebuild this node with new children (expects exactly two; indexing
    /// panics otherwise, matching the trait's usual contract).
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(SpatialJoinExec {
            left: children[0].clone(),
            right: children[1].clone(),
            on: self.on.clone(),
            filter: self.filter.clone(),
            join_type: self.join_type,
            join_schema: self.join_schema.clone(),
            column_indices: self.column_indices.clone(),
            projection: self.projection.clone(),
            // Fresh metrics and index future for the new node
            metrics: Default::default(),
            // NOTE(review): the cached PlanProperties are reused as-is; if a
            // new child can have different partitioning/ordering, this cache
            // may be stale — confirm against how the optimizer rewrites plans.
            cache: self.cache.clone(),
            once_async_spatial_index: Arc::new(Mutex::new(None)),
            converted_from_hash_join: self.converted_from_hash_join,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    /// Execute one output partition.
    ///
    /// KNN predicates are dispatched to `execute_knn`; all other spatial
    /// predicates follow the standard left=build / right=probe path: the build
    /// side is fully consumed into a spatial index (created at most once via
    /// the shared `OnceAsync`), then this partition's probe stream is joined
    /// against it.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        match &self.on {
            SpatialPredicate::KNearestNeighbors(_) => self.execute_knn(partition, context),
            _ => {
                // Regular spatial join logic - standard left=build, right=probe semantics
                let session_config = context.session_config();
                let target_output_batch_size = session_config.options().execution.batch_size;
                let sedona_options = session_config
                    .options()
                    .extensions
                    .get::<SedonaOptions>()
                    .cloned()
                    .unwrap_or_default();
                // Regular join semantics: left is build, right is probe
                let (build_plan, probe_plan) = (&self.left, &self.right);
                // Build the spatial index. The OnceAsync guarantees the build
                // runs once even though every partition calls execute().
                let once_fut_spatial_index = {
                    let mut once_async = self.once_async_spatial_index.lock();
                    once_async
                        .get_or_insert(OnceAsync::default())
                        .try_once(|| {
                            let build_side = build_plan;
                            // All build-side partitions are drained into the index
                            let num_partitions = build_side.output_partitioning().partition_count();
                            let mut build_streams = Vec::with_capacity(num_partitions);
                            for k in 0..num_partitions {
                                let stream = build_side.execute(k, Arc::clone(&context))?;
                                build_streams.push(stream);
                            }
                            let probe_thread_count =
                                self.right.output_partitioning().partition_count();
                            Ok(build_index(
                                Arc::clone(&context),
                                build_side.schema(),
                                build_streams,
                                self.on.clone(),
                                self.join_type,
                                probe_thread_count,
                                self.metrics.clone(),
                            ))
                        })?
                };
                // Column indices for regular joins - no swapping needed
                let column_indices_after_projection = match &self.projection {
                    Some(projection) => projection
                        .iter()
                        .map(|i| self.column_indices[*i].clone())
                        .collect(),
                    None => self.column_indices.clone(),
                };
                let join_metrics = SpatialJoinProbeMetrics::new(partition, &self.metrics);
                let probe_stream = probe_plan.execute(partition, Arc::clone(&context))?;
                // For regular joins: probe is right side (index 1)
                let probe_side_ordered =
                    self.maintains_input_order()[1] && self.right.output_ordering().is_some();
                Ok(Box::pin(SpatialJoinStream::new(
                    self.schema(),
                    &self.on,
                    self.filter.clone(),
                    self.join_type,
                    probe_stream,
                    column_indices_after_projection,
                    probe_side_ordered,
                    join_metrics,
                    sedona_options.spatial_join,
                    target_output_batch_size,
                    once_fut_spatial_index,
                    Arc::clone(&self.once_async_spatial_index),
                )))
            }
        }
    }
}
impl SpatialJoinExec {
    /// Execute KNN (K-Nearest Neighbors) spatial join with specialized logic for asymmetric KNN semantics
    ///
    /// Unlike the regular path, KNN joins may assign either child as the build
    /// side based on the optimizer's `probe_side` designation; when the sides
    /// are swapped relative to left=build/right=probe, the output column sides
    /// must be flipped to match.
    fn execute_knn(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let session_config = context.session_config();
        let target_output_batch_size = session_config.options().execution.batch_size;
        let sedona_options = session_config
            .options()
            .extensions
            .get::<SedonaOptions>()
            .cloned()
            .unwrap_or_default();
        // Extract KNN predicate for type safety
        let knn_pred = match &self.on {
            SpatialPredicate::KNearestNeighbors(knn_pred) => knn_pred,
            // The only caller is execute(), which matches on KNN first
            _ => unreachable!("execute_knn called with non-KNN predicate"),
        };
        // Determine which execution plan should be build vs probe using join schema analysis
        let (build_plan, probe_plan) =
            determine_knn_build_probe_plans(knn_pred, &self.left, &self.right, &self.join_schema)?;
        // Determine if probe plan is the left execution plan (for column index swapping logic).
        // Pointer identity is sufficient because probe_plan is one of the two children.
        let actual_probe_plan_is_left = std::ptr::eq(probe_plan.as_ref(), self.left.as_ref());
        // Build the spatial index (at most once across all partitions, via OnceAsync)
        let once_fut_spatial_index = {
            let mut once_async = self.once_async_spatial_index.lock();
            once_async
                .get_or_insert(OnceAsync::default())
                .try_once(|| {
                    let build_side = build_plan;
                    let num_partitions = build_side.output_partitioning().partition_count();
                    let mut build_streams = Vec::with_capacity(num_partitions);
                    for k in 0..num_partitions {
                        let stream = build_side.execute(k, Arc::clone(&context))?;
                        build_streams.push(stream);
                    }
                    let probe_thread_count = probe_plan.output_partitioning().partition_count();
                    Ok(build_index(
                        Arc::clone(&context),
                        build_side.schema(),
                        build_streams,
                        self.on.clone(),
                        self.join_type,
                        probe_thread_count,
                        self.metrics.clone(),
                    ))
                })?
        };
        // Handle column indices for KNN - need to swap if we swapped execution plans
        let mut column_indices_after_projection = match &self.projection {
            Some(projection) => projection
                .iter()
                .map(|i| self.column_indices[*i].clone())
                .collect(),
            None => self.column_indices.clone(),
        };
        // If we swapped execution plans for KNN, we need to swap the column indices too
        if !actual_probe_plan_is_left {
            for col_idx in &mut column_indices_after_projection {
                match col_idx.side {
                    datafusion_common::JoinSide::Left => {
                        col_idx.side = datafusion_common::JoinSide::Right
                    }
                    datafusion_common::JoinSide::Right => {
                        col_idx.side = datafusion_common::JoinSide::Left
                    }
                    datafusion_common::JoinSide::None => {} // No change needed
                }
            }
        }
        let join_metrics = SpatialJoinProbeMetrics::new(partition, &self.metrics);
        let probe_stream = probe_plan.execute(partition, Arc::clone(&context))?;
        // Determine if probe side ordering is maintained for KNN
        let probe_side_ordered = if actual_probe_plan_is_left {
            // Actual probe is left plan
            self.maintains_input_order()[0] && self.left.output_ordering().is_some()
        } else {
            // Actual probe is right plan
            self.maintains_input_order()[1] && self.right.output_ordering().is_some()
        };
        Ok(Box::pin(SpatialJoinStream::new(
            self.schema(),
            &self.on,
            self.filter.clone(),
            self.join_type,
            probe_stream,
            column_indices_after_projection,
            probe_side_ordered,
            join_metrics,
            sedona_options.spatial_join,
            target_output_batch_size,
            once_fut_spatial_index,
            Arc::clone(&self.once_async_spatial_index),
        )))
    }
}
#[cfg(test)]
mod tests {
use arrow_array::RecordBatch;
use arrow_schema::{DataType, Field, Schema};
use datafusion::{
catalog::{MemTable, TableProvider},
execution::SessionStateBuilder,
prelude::{SessionConfig, SessionContext},
};
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use geo_types::{Coord, Rect};
use rstest::rstest;
use sedona_geometry::types::GeometryTypeId;
use sedona_schema::datatypes::{SedonaType, WKB_GEOGRAPHY, WKB_GEOMETRY};
use sedona_testing::datagen::RandomPartitionedDataBuilder;
use tokio::sync::OnceCell;
use crate::register_spatial_join_optimizer;
use sedona_common::{
option::{add_sedona_option_extension, ExecutionMode, SpatialJoinOptions},
SpatialLibrary,
};
use super::*;
// (schema, partitions-of-batches) pair used by all tests in this module
type TestPartitions = (SchemaRef, Vec<Vec<RecordBatch>>);

/// Creates standard test data with left (Polygon) and right (Point) partitions
fn create_default_test_data() -> Result<(TestPartitions, TestPartitions)> {
    create_test_data_with_size_range((1.0, 10.0), WKB_GEOMETRY)
}
/// Creates test data with custom size range
///
/// Left side: 2 partitions x 2 batches x 30 rows of random polygons.
/// Right side: 4 partitions x 4 batches x 30 rows of random points.
/// Both sides share the same bounds and include ~10% null geometries.
fn create_test_data_with_size_range(
    size_range: (f64, f64),
    sedona_type: SedonaType,
) -> Result<(TestPartitions, TestPartitions)> {
    let bounds = Rect::new(Coord { x: 0.0, y: 0.0 }, Coord { x: 100.0, y: 100.0 });
    let left_data = RandomPartitionedDataBuilder::new()
        .seed(1)
        .num_partitions(2)
        .batches_per_partition(2)
        .rows_per_batch(30)
        .geometry_type(GeometryTypeId::Polygon)
        .sedona_type(sedona_type.clone())
        .bounds(bounds)
        .size_range(size_range)
        .null_rate(0.1)
        .build()?;
    let right_data = RandomPartitionedDataBuilder::new()
        .seed(2)
        .num_partitions(4)
        .batches_per_partition(4)
        .rows_per_batch(30)
        .geometry_type(GeometryTypeId::Point)
        .sedona_type(sedona_type)
        .bounds(bounds)
        .size_range(size_range)
        .null_rate(0.1)
        .build()?;
    Ok((left_data, right_data))
}
/// Creates test data with empty partitions inserted at beginning and end
///
/// Wraps [`create_default_test_data`] and pads each side's partition list
/// with an empty partition at the front and at the back, so joins must cope
/// with partitions that produce no batches.
fn create_test_data_with_empty_partitions() -> Result<(TestPartitions, TestPartitions)> {
    let (mut left, mut right) = create_default_test_data()?;
    for partitions in [&mut left.1, &mut right.1] {
        partitions.insert(0, vec![]);
        partitions.push(vec![]);
    }
    Ok((left, right))
}
/// Build a SessionContext for spatial-join tests.
///
/// When `options` is Some, the spatial join optimizer rule is registered and
/// the given SpatialJoinOptions are installed in the Sedona config extension;
/// when None, queries fall back to DataFusion's default join plans. All
/// Sedona scalar UDFs (with geos kernels) are registered either way.
fn setup_context(
    options: Option<SpatialJoinOptions>,
    batch_size: usize,
) -> Result<SessionContext> {
    let mut session_config = SessionConfig::from_env()?
        .with_information_schema(true)
        .with_batch_size(batch_size);
    session_config = add_sedona_option_extension(session_config);
    let mut state_builder = SessionStateBuilder::new();
    if let Some(options) = options {
        state_builder = register_spatial_join_optimizer(state_builder);
        let opts = session_config
            .options_mut()
            .extensions
            .get_mut::<SedonaOptions>()
            .unwrap();
        opts.spatial_join = options;
    }
    let state = state_builder.with_config(session_config).build();
    let ctx = SessionContext::new_with_state(state);
    let mut function_set = sedona_functions::register::default_function_set();
    let scalar_kernels = sedona_geos::register::scalar_kernels();
    function_set.scalar_udfs().for_each(|udf| {
        ctx.register_udf(udf.clone().into());
    });
    // Attach geos kernels to their UDFs, then (re-)register the updated UDFs
    for (name, kernel) in scalar_kernels.into_iter() {
        let udf = function_set.add_scalar_udf_kernel(name, kernel)?;
        ctx.register_udf(udf.clone().into());
    }
    Ok(ctx)
}
/// A join over entirely empty inputs must produce zero rows without erroring,
/// both with a single empty partition and with multiple empty partitions.
#[tokio::test]
async fn test_empty_data() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("dist", DataType::Float64, false),
        WKB_GEOMETRY.to_storage_field("geometry", true).unwrap(),
    ]));
    // One empty partition, then two empty partitions
    let test_data_vec = vec![vec![vec![]], vec![vec![], vec![]]];
    let options = SpatialJoinOptions {
        execution_mode: ExecutionMode::PrepareNone,
        ..Default::default()
    };
    let ctx = setup_context(Some(options.clone()), 10)?;
    for test_data in test_data_vec {
        let left_partitions = test_data.clone();
        let right_partitions = test_data;
        let mem_table_left: Arc<dyn TableProvider> = Arc::new(MemTable::try_new(
            Arc::clone(&schema),
            left_partitions.clone(),
        )?);
        let mem_table_right: Arc<dyn TableProvider> = Arc::new(MemTable::try_new(
            Arc::clone(&schema),
            right_partitions.clone(),
        )?);
        // Re-register tables for each iteration of the loop
        ctx.deregister_table("L")?;
        ctx.deregister_table("R")?;
        ctx.register_table("L", Arc::clone(&mem_table_left))?;
        ctx.register_table("R", Arc::clone(&mem_table_right))?;
        let sql = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id";
        let df = ctx.sql(sql).await?;
        let result_batches = df.collect().await?;
        for result_batch in result_batches {
            assert_eq!(result_batch.num_rows(), 0);
        }
    }
    Ok(())
}
// Shared test data and expected results - computed only once across all parameterized test cases
// Using tokio::sync::OnceCell for async lazy initialization to avoid recomputing expensive
// test data generation and nested loop join results for each test parameter combination
static TEST_DATA: OnceCell<(TestPartitions, TestPartitions)> = OnceCell::const_new();
static RANGE_JOIN_EXPECTED_RESULTS: OnceCell<Vec<RecordBatch>> = OnceCell::const_new();
static DIST_JOIN_EXPECTED_RESULTS: OnceCell<Vec<RecordBatch>> = OnceCell::const_new();

// Range join variants: explicit column projection vs. SELECT *
const RANGE_JOIN_SQL1: &str = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id";
const RANGE_JOIN_SQL2: &str =
    "SELECT * FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY L.id, R.id";
const RANGE_JOIN_SQLS: &[&str] = &[RANGE_JOIN_SQL1, RANGE_JOIN_SQL2];

// Distance join variants: constant radius, left-derived radius, right-derived
// radius, and ST_DWithin
const DIST_JOIN_SQL1: &str = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON ST_Distance(L.geometry, R.geometry) < 1.0 ORDER BY l_id, r_id";
const DIST_JOIN_SQL2: &str = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON ST_Distance(L.geometry, R.geometry) < L.dist / 10.0 ORDER BY l_id, r_id";
const DIST_JOIN_SQL3: &str = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON ST_Distance(L.geometry, R.geometry) < R.dist / 10.0 ORDER BY l_id, r_id";
const DIST_JOIN_SQL4: &str = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON ST_DWithin(L.geometry, R.geometry, 1.0) ORDER BY l_id, r_id";
const DIST_JOIN_SQLS: &[&str] = &[
    DIST_JOIN_SQL1,
    DIST_JOIN_SQL2,
    DIST_JOIN_SQL3,
    DIST_JOIN_SQL4,
];
/// Get test data, computing it only once
fn_doc_placeholder
/// Get expected results, computing them only once
async fn get_expected_range_join_results() -> &'static Vec<RecordBatch> {
    get_or_init_expected_join_results(&RANGE_JOIN_EXPECTED_RESULTS, RANGE_JOIN_SQLS).await
}

/// Distance-join counterpart of [`get_expected_range_join_results`]
async fn get_expected_distance_join_results() -> &'static Vec<RecordBatch> {
    get_or_init_expected_join_results(&DIST_JOIN_EXPECTED_RESULTS, DIST_JOIN_SQLS).await
}
/// Lazily compute the expected result set for each SQL query, exactly once.
///
/// Runs each query with the spatial join optimizer disabled (options = None),
/// so DataFusion's default (nested loop) join produces the ground truth that
/// optimized runs are compared against.
async fn get_or_init_expected_join_results<'a>(
    lazy_init_results: &'a OnceCell<Vec<RecordBatch>>,
    sql_queries: &[&str],
) -> &'a Vec<RecordBatch> {
    lazy_init_results
        .get_or_init(|| async {
            let test_data = get_default_test_data().await;
            let ((left_schema, left_partitions), (right_schema, right_partitions)) = test_data;
            let batch_size = 10;
            // Run nested loop join to get expected results
            let mut expected_results = Vec::with_capacity(sql_queries.len());
            for (i, sql) in sql_queries.iter().enumerate() {
                let result = run_spatial_join_query(
                    left_schema,
                    right_schema,
                    left_partitions.clone(),
                    right_partitions.clone(),
                    None,
                    batch_size,
                    sql,
                )
                .await
                .unwrap_or_else(|_| panic!("Failed to generate expected result {}", i + 1));
                expected_results.push(result);
            }
            expected_results
        })
        .await
}
/// Range joins must match the nested-loop-join ground truth across batch
/// sizes, execution modes, and spatial libraries (cartesian product via rstest).
#[rstest]
#[tokio::test]
async fn test_range_join_with_conf(
    #[values(10, 30, 1000)] max_batch_size: usize,
    #[values(
        ExecutionMode::PrepareNone,
        ExecutionMode::PrepareBuild,
        ExecutionMode::PrepareProbe,
        ExecutionMode::Speculative(20)
    )]
    execution_mode: ExecutionMode,
    #[values(SpatialLibrary::Geo, SpatialLibrary::Geos, SpatialLibrary::Tg)]
    spatial_library: SpatialLibrary,
) -> Result<()> {
    let test_data = get_default_test_data().await;
    let expected_results = get_expected_range_join_results().await;
    let ((left_schema, left_partitions), (right_schema, right_partitions)) = test_data;
    let options = SpatialJoinOptions {
        spatial_library,
        execution_mode,
        ..Default::default()
    };
    for (idx, sql) in RANGE_JOIN_SQLS.iter().enumerate() {
        let actual_result = run_spatial_join_query(
            left_schema,
            right_schema,
            left_partitions.clone(),
            right_partitions.clone(),
            Some(options.clone()),
            max_batch_size,
            sql,
        )
        .await?;
        assert_eq!(&actual_result, &expected_results[idx]);
    }
    Ok(())
}
/// Distance joins must match the nested-loop-join ground truth across batch
/// sizes and spatial libraries.
#[rstest]
#[tokio::test]
async fn test_distance_join_with_conf(
    #[values(30, 1000)] max_batch_size: usize,
    #[values(SpatialLibrary::Geo, SpatialLibrary::Geos, SpatialLibrary::Tg)]
    spatial_library: SpatialLibrary,
) -> Result<()> {
    let test_data = get_default_test_data().await;
    let expected_results = get_expected_distance_join_results().await;
    let ((left_schema, left_partitions), (right_schema, right_partitions)) = test_data;
    let options = SpatialJoinOptions {
        spatial_library,
        ..Default::default()
    };
    for (idx, sql) in DIST_JOIN_SQLS.iter().enumerate() {
        let actual_result = run_spatial_join_query(
            left_schema,
            right_schema,
            left_partitions.clone(),
            right_partitions.clone(),
            Some(options.clone()),
            max_batch_size,
            sql,
        )
        .await?;
        assert_eq!(&actual_result, &expected_results[idx]);
    }
    Ok(())
}
/// Spatial join combined with an extra non-spatial filter (L.dist < R.dist)
/// in the ON clause, checked across several batch sizes and projections.
#[tokio::test]
async fn test_spatial_join_with_filter() -> Result<()> {
    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
        create_test_data_with_size_range((0.1, 10.0), WKB_GEOMETRY)?;
    for max_batch_size in [10, 30, 100] {
        let options = SpatialJoinOptions {
            execution_mode: ExecutionMode::PrepareNone,
            ..Default::default()
        };
        test_spatial_join_query(&left_schema, &right_schema, left_partitions.clone(), right_partitions.clone(), &options, max_batch_size,
            "SELECT * FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) AND L.dist < R.dist ORDER BY L.id, R.id").await?;
        test_spatial_join_query(&left_schema, &right_schema, left_partitions.clone(), right_partitions.clone(), &options, max_batch_size,
            "SELECT L.id l_id, R.id r_id FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) AND L.dist < R.dist ORDER BY l_id, r_id").await?;
        test_spatial_join_query(&left_schema, &right_schema, left_partitions.clone(), right_partitions.clone(), &options, max_batch_size,
            "SELECT L.id l_id, R.id r_id, L.dist l_dist, R.dist r_dist FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) AND L.dist < R.dist ORDER BY l_id, r_id").await?;
    }
    Ok(())
}
/// Range joins must still be correct when inputs contain empty partitions
/// at the beginning and end of each side.
#[tokio::test]
async fn test_range_join_with_empty_partitions() -> Result<()> {
    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
        create_test_data_with_empty_partitions()?;
    for max_batch_size in [10, 30, 1000] {
        let options = SpatialJoinOptions {
            execution_mode: ExecutionMode::PrepareNone,
            ..Default::default()
        };
        test_spatial_join_query(&left_schema, &right_schema, left_partitions.clone(), right_partitions.clone(), &options, max_batch_size,
            "SELECT L.id l_id, R.id r_id FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id").await?;
        test_spatial_join_query(&left_schema, &right_schema, left_partitions.clone(), right_partitions.clone(), &options, max_batch_size,
            "SELECT * FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY L.id, R.id").await?;
    }
    Ok(())
}
/// INNER join correctness (delegates to the shared join-type harness)
#[tokio::test]
async fn test_inner_join() -> Result<()> {
    test_with_join_types(JoinType::Inner).await?;
    Ok(())
}
/// LEFT join correctness; LeftSemi/LeftAnti are commented out (not yet covered)
#[rstest]
#[tokio::test]
async fn test_left_joins(
    #[values(JoinType::Left, /* JoinType::LeftSemi, JoinType::LeftAnti */)] join_type: JoinType,
) -> Result<()> {
    test_with_join_types(join_type).await?;
    Ok(())
}
/// RIGHT join correctness; RightSemi/RightAnti are commented out (not yet covered)
#[rstest]
#[tokio::test]
async fn test_right_joins(
    #[values(JoinType::Right, /* JoinType::RightSemi, JoinType::RightAnti */)]
    join_type: JoinType,
) -> Result<()> {
    test_with_join_types(join_type).await?;
    Ok(())
}
/// FULL OUTER join correctness (delegates to the shared join-type harness)
#[tokio::test]
async fn test_full_outer_join() -> Result<()> {
    test_with_join_types(JoinType::Full).await?;
    Ok(())
}
/// Geography (as opposed to geometry) joins must NOT be rewritten to
/// SpatialJoinExec — the optimizer should leave the default join plan intact.
#[tokio::test]
async fn test_geography_join_is_not_optimized() -> Result<()> {
    let options = SpatialJoinOptions::default();
    let ctx = setup_context(Some(options), 10)?;
    // Prepare geography tables
    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
        create_test_data_with_size_range((0.1, 10.0), WKB_GEOGRAPHY)?;
    let mem_table_left: Arc<dyn TableProvider> =
        Arc::new(MemTable::try_new(left_schema, left_partitions)?);
    let mem_table_right: Arc<dyn TableProvider> =
        Arc::new(MemTable::try_new(right_schema, right_partitions)?);
    ctx.register_table("L", mem_table_left)?;
    ctx.register_table("R", mem_table_right)?;
    // Execute geography join query
    let df = ctx
        .sql("SELECT * FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry)")
        .await?;
    let plan = df.create_physical_plan().await?;
    // Verify that no SpatialJoinExec is present (geography join should not be optimized)
    let spatial_joins = collect_spatial_join_exec(&plan)?;
    assert!(
        spatial_joins.is_empty(),
        "Geography joins should not be optimized to SpatialJoinExec"
    );
    Ok(())
}
#[tokio::test]
async fn test_query_window_in_subquery() -> Result<()> {
    // A scalar subquery supplying the query window must still match the
    // NestedLoopJoinExec baseline.
    let options = SpatialJoinOptions::default();
    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
        create_test_data_with_size_range((50.0, 60.0), WKB_GEOMETRY)?;
    let sql = "SELECT id FROM L WHERE ST_Intersects(L.geometry, (SELECT R.geometry FROM R WHERE R.id = 1))";
    test_spatial_join_query(
        &left_schema,
        &right_schema,
        left_partitions,
        right_partitions,
        &options,
        10,
        sql,
    )
    .await
    .map(|_| ())
}
/// Runs a spatial join expressed in SQL for the given `join_type` and checks
/// it against the NestedLoopJoinExec baseline via `test_spatial_join_query`.
///
/// For the outer-join variants it additionally asserts that the outer result
/// is strictly larger than the inner-join result, proving the test data
/// actually exercises the null-padding code paths.
async fn test_with_join_types(join_type: JoinType) -> Result<RecordBatch> {
    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
        create_test_data_with_empty_partitions()?;
    let options = SpatialJoinOptions {
        execution_mode: ExecutionMode::PrepareNone,
        ..Default::default()
    };
    let batch_size = 30;

    // The inner-join query doubles as the baseline for the outer-join sanity check below.
    let inner_sql = "SELECT L.id l_id, R.id r_id FROM L INNER JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id";
    let sql = match join_type {
        JoinType::Inner => inner_sql,
        JoinType::Left => "SELECT L.id l_id, R.id r_id FROM L LEFT JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id",
        JoinType::Right => "SELECT L.id l_id, R.id r_id FROM L RIGHT JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id",
        JoinType::Full => "SELECT L.id l_id, R.id r_id FROM L FULL OUTER JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id",
        // Semi/anti joins are expressed through (NOT) EXISTS subqueries.
        JoinType::LeftSemi => "SELECT L.id l_id FROM L WHERE EXISTS (SELECT 1 FROM R WHERE ST_Intersects(L.geometry, R.geometry)) ORDER BY l_id",
        JoinType::RightSemi => "SELECT R.id r_id FROM R WHERE EXISTS (SELECT 1 FROM L WHERE ST_Intersects(L.geometry, R.geometry)) ORDER BY r_id",
        JoinType::LeftAnti => "SELECT L.id l_id FROM L WHERE NOT EXISTS (SELECT 1 FROM R WHERE ST_Intersects(L.geometry, R.geometry)) ORDER BY l_id",
        JoinType::RightAnti => "SELECT R.id r_id FROM R WHERE NOT EXISTS (SELECT 1 FROM L WHERE ST_Intersects(L.geometry, R.geometry)) ORDER BY r_id",
        JoinType::LeftMark => {
            unreachable!("LeftMark is not directly supported in SQL, will be tested in other tests");
        }
        JoinType::RightMark => {
            unreachable!("RightMark is not directly supported in SQL, will be tested in other tests");
        }
    };

    let joined = test_spatial_join_query(
        &left_schema,
        &right_schema,
        left_partitions.clone(),
        right_partitions.clone(),
        &options,
        batch_size,
        sql,
    )
    .await?;

    if matches!(join_type, JoinType::Left | JoinType::Right | JoinType::Full) {
        // Make sure that we are effectively testing outer joins. If outer joins produces the same result as inner join,
        // it means that the test data is not suitable for testing outer joins.
        let inner_result = run_spatial_join_query(
            &left_schema,
            &right_schema,
            left_partitions,
            right_partitions,
            Some(options),
            batch_size,
            inner_sql,
        )
        .await?;
        assert!(inner_result.num_rows() < joined.num_rows());
    }
    Ok(joined)
}
/// Executes `sql` twice — once through SpatialJoinExec and once through the
/// default NestedLoopJoinExec planner — and asserts both produce identical,
/// non-empty results. Returns the optimized result batch.
async fn test_spatial_join_query(
    left_schema: &SchemaRef,
    right_schema: &SchemaRef,
    left_partitions: Vec<Vec<RecordBatch>>,
    right_partitions: Vec<Vec<RecordBatch>>,
    options: &SpatialJoinOptions,
    batch_size: usize,
    sql: &str,
) -> Result<RecordBatch> {
    // Optimized path: planned with SpatialJoinExec.
    let actual = run_spatial_join_query(
        left_schema,
        right_schema,
        left_partitions.clone(),
        right_partitions.clone(),
        Some(options.clone()),
        batch_size,
        sql,
    )
    .await?;
    // Baseline path: no spatial-join options, falls back to NestedLoopJoinExec.
    let expected = run_spatial_join_query(
        left_schema,
        right_schema,
        left_partitions,
        right_partitions,
        None,
        batch_size,
        sql,
    )
    .await?;
    // An empty baseline would make the equality check meaningless.
    assert!(expected.num_rows() > 0);
    assert_eq!(expected, actual);
    Ok(actual)
}
/// Registers the two partitioned inputs as tables `L` and `R`, runs `sql`,
/// verifies whether the physical plan used SpatialJoinExec (exactly one node
/// when `options` is provided, none otherwise), and returns all results
/// concatenated into a single `RecordBatch`.
async fn run_spatial_join_query(
    left_schema: &SchemaRef,
    right_schema: &SchemaRef,
    left_partitions: Vec<Vec<RecordBatch>>,
    right_partitions: Vec<Vec<RecordBatch>>,
    options: Option<SpatialJoinOptions>,
    batch_size: usize,
    sql: &str,
) -> Result<RecordBatch> {
    // Capture the expectation before `options` is moved into the context.
    let expect_spatial_join = options.is_some();
    let ctx = setup_context(options, batch_size)?;

    let left_table: Arc<dyn TableProvider> =
        Arc::new(MemTable::try_new(left_schema.to_owned(), left_partitions)?);
    let right_table: Arc<dyn TableProvider> = Arc::new(MemTable::try_new(
        right_schema.to_owned(),
        right_partitions,
    )?);
    ctx.register_table("L", left_table)?;
    ctx.register_table("R", right_table)?;

    let df = ctx.sql(sql).await?;
    let output_schema = df.schema().as_arrow().clone();

    // Confirm the planner made the expected optimization decision.
    let physical_plan = df.clone().create_physical_plan().await?;
    let spatial_joins = collect_spatial_join_exec(&physical_plan)?;
    if expect_spatial_join {
        assert_eq!(spatial_joins.len(), 1);
    } else {
        assert!(spatial_joins.is_empty());
    }

    // Flatten all output batches into one for easy comparison by callers.
    let batches = df.collect().await?;
    let merged = arrow::compute::concat_batches(&Arc::new(output_schema), &batches)?;
    Ok(merged)
}
/// Walks the physical plan tree and collects references to every
/// `SpatialJoinExec` node found in it.
fn collect_spatial_join_exec(plan: &Arc<dyn ExecutionPlan>) -> Result<Vec<&SpatialJoinExec>> {
    let mut found = Vec::new();
    plan.apply(|node| {
        if let Some(exec) = node.as_any().downcast_ref::<SpatialJoinExec>() {
            found.push(exec);
        }
        Ok(TreeNodeRecursion::Continue)
    })?;
    Ok(found)
}
}