blob: bd01821b1e0eb8c36bbd7a3047c8fbe858259029 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::sync::Arc;
use crate::exec::SpatialJoinExec;
use crate::spatial_predicate::{
DistancePredicate, KNNPredicate, RelationPredicate, SpatialPredicate, SpatialRelationType,
};
use arrow_schema::{Schema, SchemaRef};
use datafusion::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
use datafusion::physical_optimizer::sanity_checker::SanityCheckPlan;
use datafusion::{
config::ConfigOptions, execution::session_state::SessionStateBuilder,
physical_optimizer::PhysicalOptimizerRule,
};
use datafusion_common::ScalarValue;
use datafusion_common::{
tree_node::{Transformed, TreeNode},
JoinSide,
};
use datafusion_common::{HashMap, Result};
use datafusion_expr::{Expr, Filter, Join, JoinType, LogicalPlan, Operator};
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
use datafusion_physical_expr::{PhysicalExpr, ScalarFunctionExpr};
use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion_physical_plan::joins::utils::ColumnIndex;
use datafusion_physical_plan::joins::{HashJoinExec, NestedLoopJoinExec};
use datafusion_physical_plan::projection::ProjectionExec;
use datafusion_physical_plan::{joins::utils::JoinFilter, ExecutionPlan};
use sedona_common::{option::SedonaOptions, sedona_internal_err};
use sedona_expr::utils::{parse_distance_predicate, ParsedDistancePredicate};
use sedona_schema::datatypes::SedonaType;
use sedona_schema::matchers::ArgMatcher;
/// Physical planner extension for spatial joins
///
/// This extension recognizes nested loop join operations with spatial predicates
/// and converts them to SpatialJoinExec, which is specially optimized for spatial joins.
/// Hash joins whose filter carries an ST_KNN predicate are converted as well.
///
/// The same type is also registered as a logical optimizer rule, which turns a spatial
/// `Filter` sitting on top of a condition-less join into a join with a filter so that
/// the physical rewrite can later pick it up.
///
/// The type is a stateless unit struct; whether the rewrites run is decided by the
/// `SedonaOptions` extension found in the session configuration.
#[derive(Debug, Default)]
pub struct SpatialJoinOptimizer;
impl SpatialJoinOptimizer {
pub fn new() -> Self {
Self
}
}
impl PhysicalOptimizerRule for SpatialJoinOptimizer {
    /// Rewrite eligible joins in `plan` to `SpatialJoinExec`.
    ///
    /// The plan is returned untouched when the `SedonaOptions` extension is not
    /// registered in `config` or when spatial join optimization is disabled there.
    /// Otherwise every node is visited bottom-up and handed to
    /// [SpatialJoinOptimizer::try_optimize_join].
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let spatial_join_enabled = matches!(
            config.extensions.get::<SedonaOptions>(),
            Some(sedona_options) if sedona_options.spatial_join.enable
        );
        if !spatial_join_enabled {
            return Ok(plan);
        }

        plan.transform_up(|node| self.try_optimize_join(node, config))
            .map(|rewritten| rewritten.data)
    }

    /// A human readable name for this optimizer rule
    fn name(&self) -> &str {
        "spatial_join_optimizer"
    }

    /// A flag to indicate whether the physical planner should validate that the rule
    /// does not change the schema of the plan after the rewriting.
    /// Some optimization rules might change the nullable properties of the schema
    /// and should disable the schema check; this rule keeps the check enabled.
    fn schema_check(&self) -> bool {
        true
    }
}
impl OptimizerRule for SpatialJoinOptimizer {
    /// A human readable name for this logical optimizer rule
    fn name(&self) -> &str {
        "spatial_join_optimizer"
    }

    /// The rule asks to be applied to the plan tree bottom-up.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::BottomUp)
    }

    /// Try to rewrite a plan containing a spatial Filter on top of a join that has neither
    /// `on` keys nor a filter into a theta-join carrying that filter. For instance, the
    /// following query plan:
    ///
    /// ```text
    /// Filter: st_intersects(l.geom, _scalar_sq_1.geom)
    ///   Left Join (no on, no filter):
    ///     TableScan: l projection=[id, geom]
    ///     SubqueryAlias: __scalar_sq_1
    ///       Projection: r.geom
    ///         Filter: r.id = Int32(1)
    ///           TableScan: r projection=[id, geom]
    /// ```
    ///
    /// will be rewritten to
    ///
    /// ```text
    /// Inner Join: Filter: st_intersects(l.geom, _scalar_sq_1.geom)
    ///   TableScan: l projection=[id, geom]
    ///   SubqueryAlias: __scalar_sq_1
    ///     Projection: r.geom
    ///       Filter: r.id = Int32(1)
    ///         TableScan: r projection=[id, geom]
    /// ```
    ///
    /// This enables the logical join operator to be converted to a NestedLoopJoin physical
    /// node with a spatial predicate, so that it can subsequently be optimized to a
    /// SpatialJoin physical node. Please refer to the `PhysicalOptimizerRule` implementation
    /// of this struct and [SpatialJoinOptimizer::try_optimize_join] for details.
    ///
    /// The plan is returned unchanged when the Sedona options are missing, spatial join
    /// optimization is disabled, the node is not a `Filter` with a spatial predicate, or the
    /// filter input is not an Inner/Left/Right join without `on` keys and without a filter.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        let options = config.options();
        let spatial_join_enabled = matches!(
            options.extensions.get::<SedonaOptions>(),
            Some(sedona_options) if sedona_options.spatial_join.enable
        );
        if !spatial_join_enabled {
            return Ok(Transformed::no(plan));
        }

        // Build the replacement join (if any) while `plan` is only borrowed, so the
        // original plan can still be handed back untouched when nothing matches.
        let replacement = match &plan {
            LogicalPlan::Filter(Filter {
                predicate, input, ..
            }) if is_spatial_predicate(predicate) => match input.as_ref() {
                LogicalPlan::Join(Join {
                    left,
                    right,
                    on,
                    filter,
                    join_type,
                    join_constraint,
                    null_equality,
                    ..
                }) if matches!(
                    join_type,
                    JoinType::Inner | JoinType::Left | JoinType::Right
                ) && on.is_empty()
                    && filter.is_none() =>
                {
                    // The filter predicate becomes the join filter of an inner join
                    Some(Join::try_new(
                        Arc::clone(left),
                        Arc::clone(right),
                        on.clone(),
                        Some(predicate.clone()),
                        JoinType::Inner,
                        *join_constraint,
                        *null_equality,
                    )?)
                }
                _ => None,
            },
            _ => None,
        };

        Ok(match replacement {
            Some(join) => Transformed::yes(LogicalPlan::Join(join)),
            None => Transformed::no(plan),
        })
    }
}
/// Tell whether a logical expression contains a spatial predicate component.
///
/// `expr` is assumed to evaluate to a boolean and to originate from a filter logical node.
/// The following shapes are recognized:
///
/// * a call to one of the known spatial predicate functions (matched case-insensitively),
/// * `st_distance(..) < x` / `st_distance(..) <= x` and the mirrored `x > st_distance(..)` /
///   `x >= st_distance(..)` comparisons,
/// * an `AND` where at least one operand is recognized by the rules above.
fn is_spatial_predicate(expr: &Expr) -> bool {
    /// Lower-case names of the functions that count as spatial predicates.
    const SPATIAL_PREDICATE_NAMES: [&str; 12] = [
        "st_intersects",
        "st_contains",
        "st_within",
        "st_covers",
        "st_covered_by",
        "st_coveredby",
        "st_touches",
        "st_crosses",
        "st_overlaps",
        "st_equals",
        "st_dwithin",
        "st_knn",
    ];

    /// True when `candidate` is a call to `st_distance` (case-insensitive).
    fn is_distance_call(candidate: &Expr) -> bool {
        matches!(
            candidate,
            Expr::ScalarFunction(call) if call.func.name().to_lowercase() == "st_distance"
        )
    }

    match expr {
        Expr::BinaryExpr(binary) => match binary.op {
            Operator::And => {
                is_spatial_predicate(&binary.left) || is_spatial_predicate(&binary.right)
            }
            // distance on the left: `st_distance(a, b) < d`
            Operator::Lt | Operator::LtEq => is_distance_call(&binary.left),
            // distance on the right: `d > st_distance(a, b)`
            Operator::Gt | Operator::GtEq => is_distance_call(&binary.right),
            _ => false,
        },
        Expr::ScalarFunction(call) => {
            let lowered_name = call.func.name().to_lowercase();
            SPATIAL_PREDICATE_NAMES.contains(&lowered_name.as_str())
        }
        _ => false,
    }
}
impl SpatialJoinOptimizer {
/// Rewrite `plan` to a SpatialJoinExec when it is a NestedLoopJoinExec or HashJoinExec
/// with a convertible spatial predicate.
///
/// Returns `Transformed::yes` with the replacement node when a conversion applied and
/// `Transformed::no` with the original node otherwise. `_config` is currently unused.
fn try_optimize_join(
    &self,
    plan: Arc<dyn ExecutionPlan>,
    _config: &ConfigOptions,
) -> Result<Transformed<Arc<dyn ExecutionPlan>>> {
    let node = plan.as_any();

    // First candidate: a nested loop join with a spatial join filter
    let from_nested_loop = match node.downcast_ref::<NestedLoopJoinExec>() {
        Some(nested_loop_join) => self.try_convert_to_spatial_join(nested_loop_join)?,
        None => None,
    };

    // Second candidate: a hash join whose filter carries a spatial predicate
    let replacement = match from_nested_loop {
        Some(spatial_join) => Some(spatial_join),
        None => match node.downcast_ref::<HashJoinExec>() {
            Some(hash_join) => self.try_convert_hash_join_to_spatial(hash_join)?,
            None => None,
        },
    };

    Ok(match replacement {
        Some(spatial_join) => Transformed::yes(spatial_join),
        // No optimization applied, hand back the original plan
        None => Transformed::no(plan),
    })
}
/// Try to convert a NestedLoopJoinExec whose join filter contains a spatial predicate
/// into a SpatialJoinExec, which executes the join with an algorithm optimized for
/// spatial predicates.
///
/// Returns `Ok(None)` when the join has no filter, when no spatial predicate can be
/// extracted from the filter, or when the geospatial types involved are not supported.
fn try_convert_to_spatial_join(
    &self,
    nested_loop_join: &NestedLoopJoinExec,
) -> Result<Option<Arc<dyn ExecutionPlan>>> {
    let Some(join_filter) = nested_loop_join.filter() else {
        return Ok(None);
    };
    let Some((spatial_predicate, remainder)) = transform_join_filter(join_filter) else {
        return Ok(None);
    };

    // NestedLoopJoinExec requires its left side to have a single partition, whereas
    // SpatialJoinExec can consume the build-side streams in parallel. A
    // CoalescePartitionsExec on the left is therefore unnecessary: skip it and use
    // its input directly.
    let original_left = nested_loop_join.left();
    let build_input = match original_left
        .as_any()
        .downcast_ref::<CoalescePartitionsExec>()
    {
        Some(coalesce_partitions) => Arc::clone(coalesce_partitions.input()),
        None => Arc::clone(original_left),
    };
    let probe_input = Arc::clone(nested_loop_join.right());

    // Bail out when the geospatial types used by the predicate are not supported
    let supported = is_spatial_predicate_supported(
        &spatial_predicate,
        &build_input.schema(),
        &probe_input.schema(),
    )?;
    if !supported {
        return Ok(None);
    }

    let spatial_join = SpatialJoinExec::try_new(
        build_input,
        probe_input,
        spatial_predicate,
        remainder,
        nested_loop_join.join_type(),
        nested_loop_join.projection().cloned(),
    )?;
    Ok(Some(Arc::new(spatial_join)))
}
/// Try to convert a HashJoinExec whose filter contains an ST_KNN predicate into a
/// SpatialJoinExec.
///
/// This covers joins that have an equi-join condition (like `c.id = r.id`) next to the
/// ST_KNN predicate. The roles are flipped: the spatial predicate drives the join and
/// the equi-conditions become part of the post-join filter. The result is wrapped in a
/// ProjectionExec so that the output schema matches that of the original hash join.
///
/// Returns `Ok(None)` when the hash join has no filter, when the extracted spatial
/// predicate is not ST_KNN (other spatial predicates stay in the hash join), or when the
/// geospatial types involved are not supported.
fn try_convert_hash_join_to_spatial(
    &self,
    hash_join: &HashJoinExec,
) -> Result<Option<Arc<dyn ExecutionPlan>>> {
    let Some(join_filter) = hash_join.filter() else {
        return Ok(None);
    };
    // transform_join_filter prioritizes ST_KNN predicates over other spatial predicates
    let Some((spatial_predicate, residual)) = transform_join_filter(join_filter) else {
        return Ok(None);
    };
    let SpatialPredicate::KNearestNeighbors(_) = &spatial_predicate else {
        return Ok(None);
    };

    // Only planar geometries are supported by the optimized join
    let supported = is_spatial_predicate_supported(
        &spatial_predicate,
        &hash_join.left().schema(),
        &hash_join.right().schema(),
    )?;
    if !supported {
        return Ok(None);
    }

    // Turn the equi-join keys into a filter and merge it with the residual filter
    let equi_filter = self.create_equi_filter_from_hash_join(hash_join)?;
    let post_join_filter = self.combine_filters(residual, equi_filter)?;

    // The SpatialJoinExec is created without a projection and flagged as converted
    // from a hash join.
    let spatial_join = Arc::new(SpatialJoinExec::try_new_with_options(
        Arc::clone(hash_join.left()),
        Arc::clone(hash_join.right()),
        spatial_predicate,
        post_join_filter,
        hash_join.join_type(),
        None, // No projection in SpatialJoinExec
        true, // converted_from_hash_join = true
    )?);

    // Select exactly the columns the HashJoinExec would have produced
    let expected_schema = hash_join.schema();
    let spatial_schema = spatial_join.schema();
    let projected =
        self.create_schema_matching_projection(spatial_join, &expected_schema, &spatial_schema)?;
    Ok(Some(projected))
}
/// Create a filter expression from the hash join's equi-join conditions
fn create_equi_filter_from_hash_join(
&self,
hash_join: &HashJoinExec,
) -> Result<Option<JoinFilter>> {
let join_keys = hash_join.on();
if join_keys.is_empty() {
return Ok(None);
}
// Build filter expressions from the equi-join conditions
let mut expressions = vec![];
// Get the left schema size to calculate right column offsets
let left_schema_size = hash_join.left().schema().fields().len();
for (left_key, right_key) in join_keys.iter() {
// Create equality expression: left_key = right_key
// But we need to adjust the column indices for SpatialJoinExec schema
if let (Some(left_col), Some(right_col)) = (
left_key.as_any().downcast_ref::<Column>(),
right_key.as_any().downcast_ref::<Column>(),
) {
// In SpatialJoinExec schema: [left_fields..., right_fields...]
// Left columns keep their indices, right columns get offset by left_schema_size
let left_idx = left_col.index();
let right_idx = left_schema_size + right_col.index();
let left_expr =
Arc::new(Column::new(left_col.name(), left_idx)) as Arc<dyn PhysicalExpr>;
let right_expr =
Arc::new(Column::new(right_col.name(), right_idx)) as Arc<dyn PhysicalExpr>;
let eq_expr = Arc::new(BinaryExpr::new(left_expr, Operator::Eq, right_expr))
as Arc<dyn PhysicalExpr>;
expressions.push(eq_expr);
}
}
// IMPORTANT: Create column indices for ALL columns in the spatial join schema
// not just the filter columns. This is required by build_batch_from_indices.
let left_schema = hash_join.left().schema();
let right_schema = hash_join.right().schema();
let mut column_indices = vec![];
// Add all left side columns
for (i, _field) in left_schema.fields().iter().enumerate() {
column_indices.push(ColumnIndex {
index: i,
side: JoinSide::Left,
});
}
// Add all right side columns
for (i, _field) in right_schema.fields().iter().enumerate() {
column_indices.push(ColumnIndex {
index: i,
side: JoinSide::Right,
});
}
// Combine all conditions with AND
let filter_expr = if expressions.len() == 1 {
expressions.into_iter().next().unwrap()
} else {
expressions
.into_iter()
.reduce(|acc, expr| {
Arc::new(BinaryExpr::new(acc, Operator::And, expr)) as Arc<dyn PhysicalExpr>
})
.unwrap()
};
// Create JoinFilter
// IMPORTANT: The filter expression uses spatial join indices (id@0 = id@3)
// So we need to create the filter schema that matches the spatial join schema,
// not the hash join schema
let left_schema = hash_join.left().schema();
let right_schema = hash_join.right().schema();
let mut spatial_filter_fields = left_schema.fields().to_vec();
spatial_filter_fields.extend_from_slice(right_schema.fields());
let spatial_filter_schema = Arc::new(arrow_schema::Schema::new(spatial_filter_fields));
// Filter expression uses spatial join indices (e.g. id@0 = id@3)
// Schema should match the spatial join schema (left + right)
Ok(Some(JoinFilter::new(
filter_expr,
column_indices,
spatial_filter_schema,
)))
}
/// Combine two optional filters with AND
fn combine_filters(
&self,
filter1: Option<JoinFilter>,
filter2: Option<JoinFilter>,
) -> Result<Option<JoinFilter>> {
match (filter1, filter2) {
(None, None) => Ok(None),
(Some(f), None) | (None, Some(f)) => Ok(Some(f)),
(Some(f1), Some(f2)) => {
// Combine f1 AND f2
let combined_expr = Arc::new(BinaryExpr::new(
f1.expression().clone(),
Operator::And,
f2.expression().clone(),
)) as Arc<dyn PhysicalExpr>;
// Combine column indices
let mut combined_indices = f1.column_indices().to_vec();
combined_indices.extend_from_slice(f2.column_indices());
Ok(Some(JoinFilter::new(
combined_expr,
combined_indices,
f1.schema().clone(),
)))
}
}
}
/// Create a ProjectionExec that makes the SpatialJoinExec output match the schema the
/// HashJoinExec would have produced.
///
/// The expected schema may look like `[id, name, name]` (with duplicate names) while the
/// spatial schema looks like `[id, location, name, id, location, name]` (left + right).
/// Each expected field is mapped, in order, to the first spatial field with the same name
/// and data type that has not been claimed by an earlier expected field.
///
/// NOTE(review): because matching is by name and type only, an expected field that
/// originates from the right side can be bound to an identically named and typed left-side
/// column that precedes it — verify this is acceptable for projected hash joins.
///
/// Returns an internal error when an expected field has no unclaimed match.
fn create_schema_matching_projection(
    &self,
    spatial_join: Arc<SpatialJoinExec>,
    expected_schema: &SchemaRef,
    spatial_schema: &SchemaRef,
) -> Result<Arc<dyn ExecutionPlan>> {
    let spatial_fields = spatial_schema.fields();
    // claimed[i] is true once spatial column i has been assigned to an expected field
    let mut claimed = vec![false; spatial_fields.len()];
    let mut projection_exprs = Vec::with_capacity(expected_schema.fields().len());

    for (expected_idx, expected_field) in expected_schema.fields().iter().enumerate() {
        let candidate = spatial_fields
            .iter()
            .enumerate()
            .find(|(spatial_idx, spatial_field)| {
                spatial_field.name() == expected_field.name()
                    && spatial_field.data_type() == expected_field.data_type()
                    && !claimed[*spatial_idx]
            });

        let Some((spatial_idx, spatial_field)) = candidate else {
            return sedona_internal_err!(
                "Cannot find matching field for '{}' ({:?}) at position {} in spatial join output. \
                Please check column name mappings and schema compatibility between HashJoinExec and SpatialJoinExec.",
                expected_field.name(),
                expected_field.data_type(),
                expected_idx
            );
        };

        claimed[spatial_idx] = true;
        let column =
            Arc::new(Column::new(spatial_field.name(), spatial_idx)) as Arc<dyn PhysicalExpr>;
        projection_exprs.push((column, expected_field.name().clone()));
    }

    let projection = ProjectionExec::try_new(projection_exprs, spatial_join)?;
    Ok(Arc::new(projection))
}
}
/// Register the spatial join optimizer with a session state builder.
///
/// Three rules are appended, in this order: the logical [SpatialJoinOptimizer] rule, the
/// physical [SpatialJoinOptimizer] rule, and a `SanityCheckPlan` physical rule after it.
pub fn register_spatial_join_optimizer(
    session_state_builder: SessionStateBuilder,
) -> SessionStateBuilder {
    let with_logical_rule =
        session_state_builder.with_optimizer_rule(Arc::new(SpatialJoinOptimizer::new()));
    let with_physical_rule =
        with_logical_rule.with_physical_optimizer_rule(Arc::new(SpatialJoinOptimizer::new()));
    with_physical_rule.with_physical_optimizer_rule(Arc::new(SanityCheckPlan::new()))
}
/// Split a join filter into a spatial predicate and a remainder.
///
/// * The spatial predicate is the spatial predicate extracted from the join filter.
/// * The remainder is everything other than the spatial predicate, wrapped in a new
///   join filter; it is `None` when the filter consists of the spatial predicate only.
///
/// The remainder may reference fewer columns than the original join filter. If that's the
/// case, the columns that are not referenced by the remainder are pruned.
///
/// Returns `None` when no spatial predicate can be extracted.
fn transform_join_filter(
    join_filter: &JoinFilter,
) -> Option<(SpatialPredicate, Option<JoinFilter>)> {
    let (spatial_predicate, residual_expr) =
        extract_spatial_predicate(join_filter.expression(), join_filter.column_indices())?;

    let residual_filter = match &residual_expr {
        Some(residual) => Some(replace_join_filter_expr(residual, join_filter)),
        None => None,
    };

    Some((spatial_predicate, residual_filter))
}
/// Extract one spatial predicate from a join filter expression and return it together
/// with the remaining filter expression (if any).
///
/// ST_KNN predicates are prioritized since they cannot be used as filters: the whole
/// expression is scanned for ST_KNN first. Otherwise the expression itself is tried as a
/// relation predicate, then as a distance predicate, and finally — when it is an `AND` —
/// its left and then right operands are searched recursively. The operand that did not
/// yield the predicate is kept in the remainder.
///
/// Returns `None` when no spatial predicate is found.
fn extract_spatial_predicate(
    expr: &Arc<dyn PhysicalExpr>,
    column_indices: &[ColumnIndex],
) -> Option<(SpatialPredicate, Option<Arc<dyn PhysicalExpr>>)> {
    // ST_KNN must be the join condition, so look for it anywhere in the tree first
    if let Some((knn_predicate, remainder)) =
        extract_knn_predicate_prioritized(expr, column_indices)
    {
        let predicate = SpatialPredicate::KNearestNeighbors(knn_predicate);
        return Some((predicate, remainder));
    }

    // The expression itself may be a relation predicate such as ST_Intersects(a, b) ...
    let relation_predicate = expr
        .as_any()
        .downcast_ref::<ScalarFunctionExpr>()
        .and_then(|scalar_fn| match_relation_predicate(scalar_fn, column_indices));
    if let Some(relation_predicate) = relation_predicate {
        return Some((SpatialPredicate::Relation(relation_predicate), None));
    }

    // ... or a distance predicate
    if let Some(distance_predicate) = match_distance_predicate(expr, column_indices) {
        return Some((SpatialPredicate::Distance(distance_predicate), None));
    }

    // Otherwise only a conjunction can still contain a spatial predicate
    let conjunction = expr.as_any().downcast_ref::<BinaryExpr>()?;
    if !matches!(conjunction.op(), Operator::And) {
        return None;
    }
    let lhs = conjunction.left();
    let rhs = conjunction.right();

    // Prefer a spatial predicate found in the left operand
    if let Some((spatial_predicate, lhs_rest)) = extract_spatial_predicate(lhs, column_indices) {
        let remainder = match lhs_rest {
            Some(lhs_rest) => {
                Arc::new(BinaryExpr::new(lhs_rest, Operator::And, Arc::clone(rhs)))
                    as Arc<dyn PhysicalExpr>
            }
            None => Arc::clone(rhs),
        };
        return Some((spatial_predicate, Some(remainder)));
    }

    // Left operand has no spatial predicate; fall back to the right operand
    let (spatial_predicate, rhs_rest) = extract_spatial_predicate(rhs, column_indices)?;
    let remainder = match rhs_rest {
        Some(rhs_rest) => {
            Arc::new(BinaryExpr::new(Arc::clone(lhs), Operator::And, rhs_rest))
                as Arc<dyn PhysicalExpr>
        }
        None => Arc::clone(lhs),
    };
    Some((spatial_predicate, Some(remainder)))
}
/// Extract an ST_KNN predicate from anywhere in a tree of `AND`ed expressions, collecting
/// all other conjuncts as the remainder.
///
/// The expression itself is checked first; if it is an `AND`, the left operand is searched
/// before the right one. Returns `None` when no ST_KNN predicate is found.
fn extract_knn_predicate_prioritized(
    expr: &Arc<dyn PhysicalExpr>,
    column_indices: &[ColumnIndex],
) -> Option<(KNNPredicate, Option<Arc<dyn PhysicalExpr>>)> {
    // The expression itself may be the ST_KNN call
    let direct_match = expr
        .as_any()
        .downcast_ref::<ScalarFunctionExpr>()
        .and_then(|scalar_fn| match_knn_predicate(scalar_fn, column_indices));
    if let Some(knn_predicate) = direct_match {
        return Some((knn_predicate, None));
    }

    // Only conjunctions are searched further
    let conjunction = expr.as_any().downcast_ref::<BinaryExpr>()?;
    if !matches!(conjunction.op(), Operator::And) {
        return None;
    }
    let lhs = conjunction.left();
    let rhs = conjunction.right();

    // ST_KNN in the left operand: whatever is left of it is ANDed with the right operand
    if let Some((knn_predicate, lhs_rest)) = extract_knn_predicate_prioritized(lhs, column_indices)
    {
        let remainder = match lhs_rest {
            Some(lhs_rest) => {
                Arc::new(BinaryExpr::new(lhs_rest, Operator::And, Arc::clone(rhs)))
                    as Arc<dyn PhysicalExpr>
            }
            None => Arc::clone(rhs),
        };
        return Some((knn_predicate, Some(remainder)));
    }

    // ST_KNN in the right operand: the left operand is ANDed with whatever is left of it
    let (knn_predicate, rhs_rest) = extract_knn_predicate_prioritized(rhs, column_indices)?;
    let remainder = match rhs_rest {
        Some(rhs_rest) => {
            Arc::new(BinaryExpr::new(Arc::clone(lhs), Operator::And, rhs_rest))
                as Arc<dyn PhysicalExpr>
        }
        None => Arc::clone(lhs),
    };
    Some((knn_predicate, Some(remainder)))
}
/// Match the scalar function expression to a spatial relation predicate such as ST_Intersects(lhs.geom, rhs.geom).
/// The input arguments of the ST_ function should reference columns from different sides.
fn match_relation_predicate(
scalar_fn: &ScalarFunctionExpr,
column_indices: &[ColumnIndex],
) -> Option<RelationPredicate> {
if let Some(relation_type) = SpatialRelationType::from_name(scalar_fn.fun().name()) {
// Try to find the expressions that evaluates to the arguments of the spatial function
let args = scalar_fn.args();
assert!(args.len() >= 2);
let arg0 = &args[0];
let arg1 = &args[1];
// Try to find the expressions that evaluates to the arguments of the spatial function
let arg0_refs = collect_column_references(arg0, column_indices);
let arg1_refs = collect_column_references(arg1, column_indices);
let (arg0_side, arg1_side) = resolve_column_reference_sides(&arg0_refs, &arg1_refs)?;
let arg0_reprojected =
reproject_column_references_for_side(arg0, column_indices, arg0_side);
let arg1_reprojected =
reproject_column_references_for_side(arg1, column_indices, arg1_side);
return match (arg0_side, arg1_side) {
(JoinSide::Left, JoinSide::Right) => Some(RelationPredicate::new(
arg0_reprojected,
arg1_reprojected,
relation_type,
)),
(JoinSide::Right, JoinSide::Left) => {
// The spatial predicate needs to be inverted
Some(RelationPredicate::new(
arg1_reprojected,
arg0_reprojected,
relation_type.invert(),
))
}
_ => None,
};
}
None
}
/// Match the expression to a distance predicate such as ST_DWithin(geom1, geom2, distance)
/// or ST_Distance(geom1, geom2) <= distance.
///
/// The geometry arguments must reference columns from different join sides. The distance
/// argument must not reference columns from both sides simultaneously; it may reference
/// no column at all. The resulting predicate always carries the left-side geometry first.
///
/// Returns `None` when the expression is not a distance predicate or the side
/// requirements above are not met.
fn match_distance_predicate(
    expr: &Arc<dyn PhysicalExpr>,
    column_indices: &[ColumnIndex],
) -> Option<DistancePredicate> {
    let parsed = parse_distance_predicate(expr)?;
    let ParsedDistancePredicate {
        arg0: geom0,
        arg1: geom1,
        arg_distance: distance,
    } = parsed;

    // Determine the join side every argument is evaluated on
    let geom0_refs = collect_column_references(&geom0, column_indices);
    let geom1_refs = collect_column_references(&geom1, column_indices);
    let distance_refs = collect_column_references(&distance, column_indices);
    let distance_side = side_of_column_references(&distance_refs)?;
    let (geom0_side, geom1_side) = resolve_column_reference_sides(&geom0_refs, &geom1_refs)?;

    // Rewrite column references so that they address the input schema of their side
    let geom0_on_side = reproject_column_references_for_side(&geom0, column_indices, geom0_side);
    let geom1_on_side = reproject_column_references_for_side(&geom1, column_indices, geom1_side);
    let distance_on_side =
        reproject_column_references_for_side(&distance, column_indices, distance_side);

    // Order the geometries as (left, right)
    let (left_geom, right_geom) = match (geom0_side, geom1_side) {
        (JoinSide::Left, JoinSide::Right) => (geom0_on_side, geom1_on_side),
        (JoinSide::Right, JoinSide::Left) => (geom1_on_side, geom0_on_side),
        _ => return None,
    };

    Some(DistancePredicate::new(
        left_geom,
        right_geom,
        distance_on_side,
        distance_side,
    ))
}
/// Match the scalar function expression to a KNN predicate such as ST_KNN(geom1, geom2, k, use_spheroid).
/// The geometry input arguments of the ST_KNN function should reference columns from different sides.
/// The k and use_spheroid arguments must be literal values.
fn match_knn_predicate(
scalar_fn: &ScalarFunctionExpr,
column_indices: &[ColumnIndex],
) -> Option<KNNPredicate> {
// Check if this is an ST_KNN function
if scalar_fn.fun().name() != "st_knn" {
return None;
}
let args = scalar_fn.args();
if args.len() < 4 {
return None; // ST_KNN requires 4 arguments: (queries_geom, objects_geom, k, use_spheroid)
}
let queries_geom = &args[0];
let objects_geom = &args[1];
let k_expr = &args[2];
let use_spheroid_expr = &args[3];
// Extract literal values for k and use_spheroid
let k = extract_literal_u32(k_expr)?;
let use_spheroid = extract_literal_bool(use_spheroid_expr)?;
// Collect column references for geometry arguments
let queries_refs = collect_column_references(queries_geom, column_indices);
let objects_refs = collect_column_references(objects_geom, column_indices);
let (queries_side, objects_side) =
resolve_column_reference_sides(&queries_refs, &objects_refs)?;
// Reproject geometry arguments to their respective sides
let queries_reprojected =
reproject_column_references_for_side(queries_geom, column_indices, queries_side);
let objects_reprojected =
reproject_column_references_for_side(objects_geom, column_indices, objects_side);
match (queries_side, objects_side) {
(JoinSide::Left, JoinSide::Right) => {
Some(KNNPredicate::new(
queries_reprojected,
objects_reprojected,
k,
use_spheroid,
JoinSide::Left, // Probe side is left plan
))
}
(JoinSide::Right, JoinSide::Left) => {
// Preserve the original query semantics: first argument is always probe, second is always build
Some(KNNPredicate::new(
queries_reprojected, // First argument (probe side)
objects_reprojected, // Second argument (build side)
k,
use_spheroid,
JoinSide::Right, // Probe side is right plan (since queries_side=Right)
))
}
_ => None,
}
}
/// Collect the join-side column descriptors of every `Column` referenced by `expr`.
///
/// Each `Column` index found in the expression tree is looked up in `column_indices`; the
/// matching entries are returned in traversal order, duplicates included.
///
/// # Panics
///
/// Panics when a `Column` index is out of bounds for `column_indices`.
fn collect_column_references(
    expr: &Arc<dyn PhysicalExpr>,
    column_indices: &[ColumnIndex],
) -> Vec<ColumnIndex> {
    let mut referenced = Vec::with_capacity(column_indices.len());
    expr.apply(|node| {
        let column = node.as_any().downcast_ref::<Column>();
        // Adds one entry for a column node, nothing for any other node
        referenced.extend(column.map(|column| column_indices[column.index()].clone()));
        Ok(datafusion_common::tree_node::TreeNodeRecursion::Continue)
    })
    .expect("Failed to collect column references");
    referenced
}
/// Resolve the join sides referenced by two sets of column references.
///
/// Returns the pair of sides when each set references a single side (or no column at all)
/// and the two sides differ; returns `None` otherwise.
fn resolve_column_reference_sides(
    left_refs: &[ColumnIndex],
    right_refs: &[ColumnIndex],
) -> Option<(JoinSide, JoinSide)> {
    let sides = (
        side_of_column_references(left_refs)?,
        side_of_column_references(right_refs)?,
    );
    (sides.0 != sides.1).then_some(sides)
}
/// Determine the single join side referenced by a set of column references.
///
/// * No references at all yields `Some(JoinSide::None)`.
/// * References that all share one side yield that side.
/// * References to both sides simultaneously yield `None`.
fn side_of_column_references(column_indices: &[ColumnIndex]) -> Option<JoinSide> {
    let Some((first, rest)) = column_indices.split_first() else {
        return Some(JoinSide::None);
    };
    let first_side = first.side;
    rest.iter()
        .all(|col_idx| col_idx.side == first_side)
        .then_some(first_side)
}
/// Rewrite the `Column` references of `expr` according to `index_map`.
///
/// Every `Column` whose index is a key of `index_map` is replaced by a column with the
/// same name and the mapped index; all other nodes are left as they are. Should the tree
/// rewrite fail, the original expression is returned unchanged.
fn reproject_column_references(
    expr: &Arc<dyn PhysicalExpr>,
    index_map: &HashMap<usize, usize>,
) -> Arc<dyn PhysicalExpr> {
    let rewritten = Arc::clone(expr).transform_down(|node| {
        // A replacement exists only for column nodes whose index is mapped
        let replacement = node
            .as_any()
            .downcast_ref::<Column>()
            .and_then(|column| {
                index_map
                    .get(&column.index())
                    .map(|&new_index| Column::new(column.name(), new_index))
            });

        Ok(match replacement {
            Some(column) => Transformed::yes(Arc::new(column) as Arc<dyn PhysicalExpr>),
            None => Transformed::no(node),
        })
    });

    match rewritten {
        Ok(transformed) => transformed.data,
        Err(_) => Arc::clone(expr),
    }
}
/// Rewrite the `Column` references of `expr` so that they address the input of `side`.
///
/// Positions in `column_indices` that belong to `side` are mapped to the column index
/// recorded for them. With `JoinSide::None` the expression is returned as-is.
fn reproject_column_references_for_side(
    expr: &Arc<dyn PhysicalExpr>,
    column_indices: &[ColumnIndex],
    side: JoinSide,
) -> Arc<dyn PhysicalExpr> {
    if let JoinSide::None = side {
        return Arc::clone(expr);
    }

    // intermediate (filter schema) position -> column index within the side's input
    let side_positions: HashMap<usize, usize> = column_indices
        .iter()
        .enumerate()
        .filter(|(_, col_idx)| col_idx.side == side)
        .map(|(position, col_idx)| (position, col_idx.index))
        .collect();

    reproject_column_references(expr, &side_positions)
}
/// Extract a literal u32 value from an expression.
///
/// Accepts non-null `UInt32`, `Int32`, `Int64` and `UInt64` literals. Returns `None` if the
/// expression is not such a literal or if the value does not fit into the u32 range.
fn extract_literal_u32(expr: &Arc<dyn PhysicalExpr>) -> Option<u32> {
    let scalar = expr.as_any().downcast_ref::<Literal>()?.value();
    match scalar {
        ScalarValue::UInt32(Some(value)) => Some(*value),
        // Checked conversions reject negative and too-large values
        ScalarValue::Int32(Some(value)) => u32::try_from(*value).ok(),
        ScalarValue::Int64(Some(value)) => u32::try_from(*value).ok(),
        ScalarValue::UInt64(Some(value)) => u32::try_from(*value).ok(),
        _ => None,
    }
}
/// Extract a literal boolean value from an expression.
///
/// Returns `None` if the expression is not a literal, is not of boolean type, or is a
/// null boolean.
fn extract_literal_bool(expr: &Arc<dyn PhysicalExpr>) -> Option<bool> {
    match expr.as_any().downcast_ref::<Literal>()?.value() {
        ScalarValue::Boolean(flag) => *flag,
        _ => None,
    }
}
/// Replace the join filter expression with a new expression. The replaced join filter expression
/// may reference fewer columns than the original join filter expression. If that's the case,
/// the columns that are not referenced by the replaced join filter expression will be pruned.
fn replace_join_filter_expr(expr: &Arc<dyn PhysicalExpr>, join_filter: &JoinFilter) -> JoinFilter {
let column_indices = join_filter.column_indices();
let column_refs = collect_column_references(expr, column_indices);
// column_refs could be a subset of column_indices. If that's the case, we can prune column_indices
// to only include the columns that are referenced by the remainder.
let referenced_columns: Vec<_> = column_indices
.iter()
.enumerate()
.filter(|(_, col_idx)| column_refs.contains(col_idx))
.collect();
let pruned_column_indices: Vec<_> = referenced_columns
.iter()
.map(|(_, col_idx)| (*col_idx).clone())
.collect();
let column_index_mapping: HashMap<_, _> = referenced_columns
.iter()
.enumerate()
.map(|(new_idx, (old_idx, _))| (*old_idx, new_idx))
.collect();
let project: Vec<_> = referenced_columns
.iter()
.map(|(old_idx, _)| *old_idx)
.collect();
let pruned_schema = join_filter
.schema()
.project(&project)
.expect("Failed to project schema");
let remainder_reprojected = reproject_column_references(expr, &column_index_mapping);
JoinFilter::new(
remainder_reprojected,
pruned_column_indices,
Arc::new(pruned_schema),
)
}
/// Report whether `spatial_predicate` can be handled by the optimized spatial join.
///
/// The predicate is supported when the operand evaluated against `left_schema` and the
/// operand evaluated against `right_schema` both produce a type accepted by the geometry
/// matcher. For KNN predicates the operands are swapped when the probe side is the right
/// side, so that each operand is checked against the schema it belongs to.
///
/// # Errors
///
/// Returns an error if an operand's return field or Sedona type cannot be resolved, or if a
/// KNN predicate carries a probe side other than `JoinSide::Left` / `JoinSide::Right`.
fn is_spatial_predicate_supported(
    spatial_predicate: &SpatialPredicate,
    left_schema: &Schema,
    right_schema: &Schema,
) -> Result<bool> {
    /// Only spatial predicates working with planar geometry are supported for optimization.
    /// Geography (spherical) types are explicitly excluded and will not trigger optimized
    /// spatial joins.
    fn is_geometry_type_supported(expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> Result<bool> {
        let return_field = expr.return_field(schema)?;
        let sedona_type = SedonaType::from_storage_field(&return_field)?;
        Ok(ArgMatcher::is_geometry().match_type(&sedona_type))
    }

    // First decide which operand belongs to which input, then run one shared check.
    let (left_operand, right_operand) = match spatial_predicate {
        SpatialPredicate::Relation(RelationPredicate { left, right, .. })
        | SpatialPredicate::Distance(DistancePredicate { left, right, .. }) => (left, right),
        SpatialPredicate::KNearestNeighbors(KNNPredicate {
            left,
            right,
            probe_side,
            ..
        }) => match probe_side {
            JoinSide::Left => (left, right),
            JoinSide::Right => (right, left),
            _ => {
                return sedona_internal_err!(
                    "Invalid probe side in KNN predicate: {:?}",
                    probe_side
                )
            }
        },
    };

    Ok(is_geometry_type_supported(left_operand, left_schema)?
        && is_geometry_type_supported(right_operand, right_schema)?)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::spatial_predicate::{SpatialPredicate, SpatialRelationType};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{JoinSide, ScalarValue};
use datafusion_expr::Operator;
use datafusion_expr::{col, lit, ColumnarValue, Expr, ScalarUDF, SimpleScalarUDF};
use datafusion_physical_expr::expressions::{BinaryExpr, Column, IsNotNullExpr, Literal};
use datafusion_physical_expr::{PhysicalExpr, ScalarFunctionExpr};
use datafusion_physical_plan::joins::utils::ColumnIndex;
use datafusion_physical_plan::joins::utils::JoinFilter;
use sedona_schema::datatypes::{WKB_GEOGRAPHY, WKB_GEOMETRY};
use std::sync::Arc;
// Builds the four-column intermediate schema shared by the join-filter tests.
fn create_test_schema() -> Arc<Schema> {
    let left_id = Field::new("left_id", DataType::Int32, false); // index 0
    let left_geom = WKB_GEOMETRY.to_storage_field("left_geom", false).unwrap(); // index 1
    let right_geom = WKB_GEOMETRY.to_storage_field("right_geom", false).unwrap(); // index 2
    let right_distance = Field::new("right_distance", DataType::Float64, false); // index 3

    Arc::new(Schema::new(vec![
        left_id,
        left_geom,
        right_geom,
        right_distance,
    ]))
}
// Builds the join-filter column indices used by the tests: one entry per intermediate
// column, recording which join side it comes from and its index on that side.
fn create_test_column_indices() -> Vec<ColumnIndex> {
    [
        (0, JoinSide::Left),  // left_id
        (1, JoinSide::Left),  // left_geom
        (0, JoinSide::Right), // right_geom
        (1, JoinSide::Right), // right_distance
    ]
    .into_iter()
    .map(|(index, side)| ColumnIndex { index, side })
    .collect()
}
// Dummy `st_intersects` UDF for testing: declares two geometry-storage arguments and a
// boolean return type, and its implementation always yields the scalar `true`.
fn create_dummy_st_intersects_udf() -> Arc<ScalarUDF> {
    let geom_type = WKB_GEOMETRY.storage_type().clone();
    let dummy = SimpleScalarUDF::new(
        "st_intersects",
        vec![geom_type.clone(), geom_type],
        DataType::Boolean,
        datafusion_expr::Volatility::Immutable,
        Arc::new(|_| Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))))),
    );
    Arc::new(ScalarUDF::from(dummy))
}
// Dummy `st_dwithin` UDF for testing: declares two geometry-storage arguments plus a
// Float64 distance and a boolean return type; its implementation always yields `true`.
fn create_dummy_st_dwithin_udf() -> Arc<ScalarUDF> {
    let geom_type = WKB_GEOMETRY.storage_type().clone();
    let dummy = SimpleScalarUDF::new(
        "st_dwithin",
        vec![geom_type.clone(), geom_type, DataType::Float64],
        DataType::Boolean,
        datafusion_expr::Volatility::Immutable,
        Arc::new(|_| Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))))),
    );
    Arc::new(ScalarUDF::from(dummy))
}
// Dummy `st_distance` UDF for testing: declares two geometry-storage arguments and a
// Float64 return type; its implementation always yields the scalar `100.0`.
fn create_dummy_st_distance_udf() -> Arc<ScalarUDF> {
    let geom_type = WKB_GEOMETRY.storage_type().clone();
    let dummy = SimpleScalarUDF::new(
        "st_distance",
        vec![geom_type.clone(), geom_type],
        DataType::Float64,
        datafusion_expr::Volatility::Immutable,
        Arc::new(|_| Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(100.0))))),
    );
    Arc::new(ScalarUDF::from(dummy))
}
// Dummy `st_within` UDF for testing: declares two geometry-storage arguments and a
// boolean return type; its implementation always yields the scalar `true`.
fn create_dummy_st_within_udf() -> Arc<ScalarUDF> {
    let geom_type = WKB_GEOMETRY.storage_type().clone();
    let dummy = SimpleScalarUDF::new(
        "st_within",
        vec![geom_type.clone(), geom_type],
        DataType::Boolean,
        datafusion_expr::Volatility::Immutable,
        Arc::new(|_| Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))))),
    );
    Arc::new(ScalarUDF::from(dummy))
}
// Helper to create spatial function expressions using the dummy UDFs
fn create_spatial_function_expr(
udf: Arc<ScalarUDF>,
args: Vec<Arc<dyn PhysicalExpr>>,
) -> Arc<ScalarFunctionExpr> {
let return_type = udf.return_type(&[]).unwrap();
let field = Arc::new(arrow::datatypes::Field::new("result", return_type, false));
// TODO: Pipe actual ConfigOptions from session instead of using defaults
// See: https://github.com/apache/sedona-db/issues/248
Arc::new(ScalarFunctionExpr::new(
udf.name(),
Arc::clone(&udf),
args,
field,
Arc::new(ConfigOptions::default()),
))
}
#[test]
fn test_collect_column_references() {
    let column_indices = create_test_column_indices();

    // A lone column yields a single reference: index 1 on the left side
    let single: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let single_refs = collect_column_references(&single, &column_indices);
    assert_eq!(single_refs.len(), 1);
    assert_eq!(single_refs[0].index, 1);
    assert_eq!(single_refs[0].side, JoinSide::Left);

    // A binary expression over one column from each side yields one reference per side
    let both_sides: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("left_geom", 1)),
        Operator::Eq,
        Arc::new(Column::new("right_geom", 2)),
    ));
    let both_refs = collect_column_references(&both_sides, &column_indices);
    assert_eq!(both_refs.len(), 2);
    let count_on = |side: JoinSide| both_refs.iter().filter(|r| r.side == side).count();
    assert_eq!(count_on(JoinSide::Left), 1);
    assert_eq!(count_on(JoinSide::Right), 1);

    // A literal references no columns at all
    let constant: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));
    let constant_refs = collect_column_references(&constant, &column_indices);
    assert_eq!(constant_refs.len(), 0);
}
#[test]
fn test_side_of_column_references() {
    let at = |index: usize, side: JoinSide| ColumnIndex { index, side };

    // Every reference on the left side
    let only_left = vec![at(0, JoinSide::Left), at(1, JoinSide::Left)];
    assert_eq!(side_of_column_references(&only_left), Some(JoinSide::Left));

    // Every reference on the right side
    let only_right = vec![at(0, JoinSide::Right)];
    assert_eq!(
        side_of_column_references(&only_right),
        Some(JoinSide::Right)
    );

    // References from both sides have no single side (should return None)
    let both = vec![at(0, JoinSide::Left), at(0, JoinSide::Right)];
    assert_eq!(side_of_column_references(&both), None);

    // No references at all (should return JoinSide::None)
    let nothing: Vec<ColumnIndex> = Vec::new();
    assert_eq!(side_of_column_references(&nothing), Some(JoinSide::None));
}
#[test]
fn test_resolve_column_reference_sides() {
    let left_only = vec![ColumnIndex {
        index: 0,
        side: JoinSide::Left,
    }];
    let right_only = vec![ColumnIndex {
        index: 0,
        side: JoinSide::Right,
    }];
    let mixed = vec![
        ColumnIndex {
            index: 0,
            side: JoinSide::Left,
        },
        ColumnIndex {
            index: 1,
            side: JoinSide::Right,
        },
    ];

    // (first argument, second argument, expected resolution)
    let cases = [
        // Valid: the two arguments come from different sides
        (
            &left_only,
            &right_only,
            Some((JoinSide::Left, JoinSide::Right)),
        ),
        (
            &right_only,
            &left_only,
            Some((JoinSide::Right, JoinSide::Left)),
        ),
        // Invalid: both arguments come from the same side
        (&left_only, &left_only, None),
        (&right_only, &right_only, None),
        // Invalid: at least one argument mixes both sides
        (&left_only, &mixed, None),
        (&mixed, &mixed, None),
    ];

    for (first, second, expected) in cases {
        assert_eq!(resolve_column_reference_sides(first, second), expected);
    }
}
#[test]
fn test_reproject_column_references() {
    // Index mapping: old index 2 -> new index 0
    let index_map: HashMap<_, _> = [(2usize, 0usize)].into_iter().collect();

    // A column whose index is in the mapping is renumbered and keeps its name
    let mapped: Arc<dyn PhysicalExpr> = Arc::new(Column::new("test_col", 2));
    let mapped_out = reproject_column_references(&mapped, &index_map);
    let mapped_col = mapped_out.as_any().downcast_ref::<Column>().unwrap();
    assert_eq!(mapped_col.index(), 0);
    assert_eq!(mapped_col.name(), "test_col");

    // A column whose index is absent from the mapping is left unchanged
    let unmapped: Arc<dyn PhysicalExpr> = Arc::new(Column::new("other_col", 5));
    let unmapped_out = reproject_column_references(&unmapped, &index_map);
    let unmapped_col = unmapped_out.as_any().downcast_ref::<Column>().unwrap();
    assert_eq!(unmapped_col.index(), 5); // Should remain unchanged
}
#[test]
fn test_reproject_column_reference_complex() {
    // Downcast helpers; each one panics if the expression is not of the expected type.
    fn as_column(expr: &Arc<dyn PhysicalExpr>) -> &Column {
        expr.as_any().downcast_ref::<Column>().unwrap()
    }
    fn as_binary(expr: &Arc<dyn PhysicalExpr>) -> &BinaryExpr {
        expr.as_any().downcast_ref::<BinaryExpr>().unwrap()
    }
    fn is_not_null_arg(expr: &Arc<dyn PhysicalExpr>) -> &Column {
        let not_null = expr.as_any().downcast_ref::<IsNotNullExpr>().unwrap();
        as_column(not_null.arg())
    }

    // ST_Intersects(left_geom[1], right_geom[2])
    //   AND (IS NOT NULL(left_id[0]) AND IS NOT NULL(right_distance[3]))
    let spatial = create_spatial_function_expr(
        create_dummy_st_intersects_udf(),
        vec![
            Arc::new(Column::new("left_geom", 1)),
            Arc::new(Column::new("right_geom", 2)),
        ],
    );
    let id_not_null = Arc::new(IsNotNullExpr::new(Arc::new(Column::new("left_id", 0))));
    let distance_not_null = Arc::new(IsNotNullExpr::new(Arc::new(Column::new(
        "right_distance",
        3,
    ))));
    let null_checks = Arc::new(BinaryExpr::new(
        id_not_null,
        Operator::And,
        distance_not_null,
    ));
    let expr: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(spatial, Operator::And, null_checks));

    // Index mapping: 1 -> 10, 2 -> 11, 3 -> 12 (index 0 is deliberately left out)
    let index_map: HashMap<_, _> = [(1usize, 10usize), (2, 11), (3, 12)]
        .into_iter()
        .collect();

    let reprojected = reproject_column_references(&expr, &index_map);

    // Expected result: ST_Intersects(left_geom[10], right_geom[11])
    //   AND (IS NOT NULL(left_id[0]) AND IS NOT NULL(right_distance[12]))
    let top_level = as_binary(&reprojected);
    assert_eq!(top_level.op(), &Operator::And);

    // Left operand of the top-level AND is the ST_Intersects call
    let st_intersects = top_level
        .left()
        .as_any()
        .downcast_ref::<ScalarFunctionExpr>()
        .unwrap();
    assert_eq!(st_intersects.fun().name(), "st_intersects");
    let spatial_args = st_intersects.args();
    assert_eq!(spatial_args.len(), 2);

    // First argument: left_geom, remapped to index 10
    let left_geom_col = as_column(&spatial_args[0]);
    assert_eq!(left_geom_col.name(), "left_geom");
    assert_eq!(left_geom_col.index(), 10);

    // Second argument: right_geom, remapped to index 11
    let right_geom_col = as_column(&spatial_args[1]);
    assert_eq!(right_geom_col.name(), "right_geom");
    assert_eq!(right_geom_col.index(), 11);

    // Right operand of the top-level AND is the nested AND of the two null checks
    let nested_and = as_binary(top_level.right());
    assert_eq!(nested_and.op(), &Operator::And);

    // IS NOT NULL(left_id[0])
    let left_id_col = is_not_null_arg(nested_and.left());
    assert_eq!(left_id_col.name(), "left_id");
    assert_eq!(left_id_col.index(), 0); // Should remain 0 (not remapped)

    // IS NOT NULL(right_distance[12])
    let right_distance_col = is_not_null_arg(nested_and.right());
    assert_eq!(right_distance_col.name(), "right_distance");
    assert_eq!(right_distance_col.index(), 12); // Should be remapped from 3 to 12
}
#[test]
fn test_reproject_column_references_for_left_side() {
    // Expression referencing a left side column through intermediate index 1
    let left_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));

    // Reproject onto the left side
    let reprojected = reproject_column_references_for_side(
        &left_geom,
        &create_test_column_indices(),
        JoinSide::Left,
    );

    let column = reprojected.as_any().downcast_ref::<Column>().unwrap();
    assert_eq!(column.index(), 1); // Should map to original left side index
}
#[test]
fn test_reproject_column_references_for_right_side() {
    // Expression referencing a right side column through intermediate index 2
    let right_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));

    // Reproject onto the right side
    let reprojected = reproject_column_references_for_side(
        &right_geom,
        &create_test_column_indices(),
        JoinSide::Right,
    );

    let column = reprojected.as_any().downcast_ref::<Column>().unwrap();
    assert_eq!(column.index(), 0); // Should map to original right side index
}
#[test]
fn test_reproject_column_references_for_none_side() {
    let original: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));

    // With JoinSide::None the expression is expected to come back untouched
    let returned = reproject_column_references_for_side(
        &original,
        &create_test_column_indices(),
        JoinSide::None,
    );

    // Not merely equal: it should be the very same allocation
    assert!(Arc::ptr_eq(&original, &returned));
}
#[test]
fn test_match_relation_predicate_st_intersects() {
    // Returns the (name, index) of an expression that must be a plain column.
    fn column_of(expr: &Arc<dyn PhysicalExpr>) -> (&str, usize) {
        let column = expr.as_any().downcast_ref::<Column>().unwrap();
        (column.name(), column.index())
    }

    let column_indices = create_test_column_indices();

    // ST_Intersects(left_geom, right_geom)
    let left_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let right_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    let st_intersects = create_spatial_function_expr(
        create_dummy_st_intersects_udf(),
        vec![left_geom, right_geom],
    );

    let predicate = match_relation_predicate(&st_intersects, &column_indices);
    assert!(predicate.is_some());
    let pred = predicate.unwrap();

    // The relation type is Intersects
    assert_eq!(pred.relation_type, SpatialRelationType::Intersects);
    // Left argument is reprojected to the left side (index 1 on the left side)
    assert_eq!(column_of(&pred.left), ("left_geom", 1));
    // Right argument is reprojected to the right side (index 0 on the right side)
    assert_eq!(column_of(&pred.right), ("right_geom", 0));
}
#[test]
fn test_match_relation_predicate_st_within_inverted() {
    // Returns the (name, index) of an expression that must be a plain column.
    fn column_of(expr: &Arc<dyn PhysicalExpr>) -> (&str, usize) {
        let column = expr.as_any().downcast_ref::<Column>().unwrap();
        (column.name(), column.index())
    }

    let column_indices = create_test_column_indices();
    let left_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let right_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));

    // ST_Within(right_geom, left_geom): the arguments are deliberately passed in
    // right, left order, so the match should invert them back to left, right order
    let st_within = create_spatial_function_expr(
        create_dummy_st_within_udf(),
        vec![right_geom, left_geom],
    );

    let predicate = match_relation_predicate(&st_within, &column_indices);
    assert!(predicate.is_some());
    let pred = predicate.unwrap();

    // Within is reported as Contains once the arguments are swapped
    assert_eq!(pred.relation_type, SpatialRelationType::Contains);
    // After inversion the left argument is the original left_geom
    assert_eq!(column_of(&pred.left), ("left_geom", 1));
    // After inversion the right argument is the original right_geom
    assert_eq!(column_of(&pred.right), ("right_geom", 0));
}
#[test]
fn test_match_relation_predicate_same_side_fails() {
    // ST_Intersects(left_geom, left_id): both arguments come from the left side
    let left_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let left_id: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_id", 0));
    let st_intersects = create_spatial_function_expr(
        create_dummy_st_intersects_udf(),
        vec![left_geom, left_id],
    );

    let predicate = match_relation_predicate(&st_intersects, &create_test_column_indices());

    // No predicate should be matched when both arguments are from the same side
    assert!(predicate.is_none());
}
#[test]
fn test_match_distance_predicate_st_dwithin() {
    // Returns the (name, index) of an expression that must be a plain column.
    fn column_of(expr: &Arc<dyn PhysicalExpr>) -> (&str, usize) {
        let column = expr.as_any().downcast_ref::<Column>().unwrap();
        (column.name(), column.index())
    }

    let column_indices = create_test_column_indices();

    // ST_DWithin(left_geom, right_geom, 1000.0)
    let left_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let right_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    let distance: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Float64(Some(1000.0))));
    let st_dwithin: Arc<dyn PhysicalExpr> = create_spatial_function_expr(
        create_dummy_st_dwithin_udf(),
        vec![left_geom, right_geom, distance],
    );

    let predicate = match_distance_predicate(&st_dwithin, &column_indices);
    assert!(predicate.is_some());
    let pred = predicate.unwrap();

    // Left argument is reprojected to the left side
    assert_eq!(column_of(&pred.left), ("left_geom", 1));
    // Right argument is reprojected to the right side
    assert_eq!(column_of(&pred.right), ("right_geom", 0));

    // The distance is a literal, so it belongs to neither side
    assert_eq!(pred.distance_side, JoinSide::None);
    let distance_literal = pred.distance.as_any().downcast_ref::<Literal>().unwrap();
    let ScalarValue::Float64(Some(val)) = distance_literal.value() else {
        panic!("Expected Float64 literal");
    };
    assert_eq!(val, &1000.0);
}
#[test]
fn test_match_distance_predicate_st_distance_comparison() {
    // Returns the (name, index) of an expression that must be a plain column.
    fn column_of(expr: &Arc<dyn PhysicalExpr>) -> (&str, usize) {
        let column = expr.as_any().downcast_ref::<Column>().unwrap();
        (column.name(), column.index())
    }

    let column_indices = create_test_column_indices();

    // ST_Distance(left_geom, right_geom) <= 1000.0
    let left_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let right_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    let st_distance: Arc<dyn PhysicalExpr> = create_spatial_function_expr(
        create_dummy_st_distance_udf(),
        vec![left_geom, right_geom],
    );
    let threshold: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Float64(Some(1000.0))));
    let comparison: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(st_distance, Operator::LtEq, threshold));

    let predicate = match_distance_predicate(&comparison, &column_indices);
    assert!(predicate.is_some());
    let pred = predicate.unwrap();

    // Both geometry arguments are reprojected to their own side
    assert_eq!(column_of(&pred.left), ("left_geom", 1));
    assert_eq!(column_of(&pred.right), ("right_geom", 0));

    // The distance is a literal, so it belongs to neither side
    assert_eq!(pred.distance_side, JoinSide::None);
}
#[test]
fn test_match_distance_predicate_with_column_distance() {
    // Returns the (name, index) of an expression that must be a plain column.
    fn column_of(expr: &Arc<dyn PhysicalExpr>) -> (&str, usize) {
        let column = expr.as_any().downcast_ref::<Column>().unwrap();
        (column.name(), column.index())
    }

    let column_indices = create_test_column_indices();

    // ST_DWithin(left_geom, right_geom, right_distance): the distance is a right side column
    let left_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let right_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    let right_distance: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_distance", 3));
    let st_dwithin: Arc<dyn PhysicalExpr> = create_spatial_function_expr(
        create_dummy_st_dwithin_udf(),
        vec![left_geom, right_geom, right_distance],
    );

    let predicate = match_distance_predicate(&st_dwithin, &column_indices);
    assert!(predicate.is_some());
    let pred = predicate.unwrap();

    // Both geometry arguments are reprojected to their own side
    assert_eq!(column_of(&pred.left), ("left_geom", 1));
    assert_eq!(column_of(&pred.right), ("right_geom", 0));

    // The distance comes from the right side
    assert_eq!(pred.distance_side, JoinSide::Right);
    let distance_col = pred.distance.as_any().downcast_ref::<Column>().unwrap();
    assert_eq!(distance_col.index(), 1); // Should be reprojected to right side index 1
    assert_eq!(distance_col.name(), "right_distance");
}
#[test]
fn test_extract_spatial_predicate_simple() {
    // Returns the index of an expression that must be a plain column.
    fn column_index_of(expr: &Arc<dyn PhysicalExpr>) -> usize {
        expr.as_any().downcast_ref::<Column>().unwrap().index()
    }

    let column_indices = create_test_column_indices();

    // A bare ST_Intersects(left_geom, right_geom)
    let left_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let right_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    let st_intersects: Arc<dyn PhysicalExpr> = create_spatial_function_expr(
        create_dummy_st_intersects_udf(),
        vec![left_geom, right_geom],
    );

    let result = extract_spatial_predicate(&st_intersects, &column_indices);
    assert!(result.is_some());
    let (spatial_pred, remainder) = result.unwrap();

    let rel_pred = match spatial_pred {
        SpatialPredicate::Relation(rel_pred) => rel_pred,
        _ => panic!("Expected SpatialPredicate::Relation"),
    };
    assert_eq!(column_index_of(&rel_pred.left), 1);
    assert_eq!(column_index_of(&rel_pred.right), 0);
    assert_eq!(rel_pred.relation_type, SpatialRelationType::Intersects);

    // Nothing is left over once the spatial predicate has been extracted
    assert!(remainder.is_none());
}
#[test]
fn test_extract_spatial_predicate_with_and() {
    // Returns the (name, index) of an expression that must be a plain column.
    fn column_of(expr: &Arc<dyn PhysicalExpr>) -> (&str, usize) {
        let column = expr.as_any().downcast_ref::<Column>().unwrap();
        (column.name(), column.index())
    }

    let column_indices = create_test_column_indices();

    // ST_Intersects(left_geom, right_geom) AND left_id = 1
    let left_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let right_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    let st_intersects: Arc<dyn PhysicalExpr> = create_spatial_function_expr(
        create_dummy_st_intersects_udf(),
        vec![left_geom, right_geom],
    );
    let id_filter: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("left_id", 0)),
        Operator::Eq,
        Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
    ));
    let and_expr: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(st_intersects, Operator::And, id_filter));

    let result = extract_spatial_predicate(&and_expr, &column_indices);
    assert!(result.is_some());
    let (spatial_pred, remainder) = result.unwrap();

    let rel_pred = match spatial_pred {
        SpatialPredicate::Relation(rel_pred) => rel_pred,
        _ => panic!("Expected SpatialPredicate::Relation"),
    };
    assert_eq!(column_of(&rel_pred.left).1, 1);
    assert_eq!(column_of(&rel_pred.right).1, 0);
    assert_eq!(rel_pred.relation_type, SpatialRelationType::Intersects);

    // The id filter should be handed back as the remainder: left_id = 1
    assert!(remainder.is_some());
    let remainder = remainder.unwrap();
    let remainder_binary = remainder.as_any().downcast_ref::<BinaryExpr>().unwrap();
    assert_eq!(remainder_binary.op(), &Operator::Eq);

    // Left operand is the left_id column
    assert_eq!(column_of(remainder_binary.left()), ("left_id", 0));

    // Right operand is the literal 1
    let literal = remainder_binary
        .right()
        .as_any()
        .downcast_ref::<Literal>()
        .unwrap();
    let ScalarValue::Int32(Some(val)) = literal.value() else {
        panic!("Expected Int32(1) literal");
    };
    assert_eq!(val, &1);
}
#[test]
fn test_extract_spatial_predicate_distance_in_and() {
    // Returns the (name, index) of an expression that must be a plain column.
    fn column_of(expr: &Arc<dyn PhysicalExpr>) -> (&str, usize) {
        let column = expr.as_any().downcast_ref::<Column>().unwrap();
        (column.name(), column.index())
    }

    let column_indices = create_test_column_indices();

    // left_id = 1 AND ST_DWithin(left_geom, right_geom, 1000.0)
    let id_filter: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("left_id", 0)),
        Operator::Eq,
        Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
    ));
    let left_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let right_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    let distance: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Float64(Some(1000.0))));
    let st_dwithin: Arc<dyn PhysicalExpr> = create_spatial_function_expr(
        create_dummy_st_dwithin_udf(),
        vec![left_geom, right_geom, distance],
    );
    let and_expr: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(id_filter, Operator::And, st_dwithin));

    let result = extract_spatial_predicate(&and_expr, &column_indices);
    assert!(result.is_some());
    let (spatial_pred, remainder) = result.unwrap();

    let dist_pred = match spatial_pred {
        SpatialPredicate::Distance(dist_pred) => dist_pred,
        _ => panic!("Expected SpatialPredicate::Distance"),
    };
    assert_eq!(column_of(&dist_pred.left).1, 1);
    assert_eq!(column_of(&dist_pred.right).1, 0);
    let distance_literal = dist_pred
        .distance
        .as_any()
        .downcast_ref::<Literal>()
        .unwrap();
    assert_eq!(
        distance_literal.value(),
        &ScalarValue::Float64(Some(1000.0))
    );
    assert_eq!(dist_pred.distance_side, JoinSide::None);

    // The id filter should be handed back as the remainder: left_id = 1
    assert!(remainder.is_some());
    let remainder = remainder.unwrap();
    let remainder_binary = remainder.as_any().downcast_ref::<BinaryExpr>().unwrap();
    assert_eq!(remainder_binary.op(), &Operator::Eq);

    // Left operand is the left_id column
    assert_eq!(column_of(remainder_binary.left()), ("left_id", 0));

    // Right operand is the literal 1
    let literal = remainder_binary
        .right()
        .as_any()
        .downcast_ref::<Literal>()
        .unwrap();
    let ScalarValue::Int32(Some(val)) = literal.value() else {
        panic!("Expected Int32(1) literal");
    };
    assert_eq!(val, &1);
}
#[test]
fn test_extract_spatial_predicate_no_spatial() {
    // A purely non-spatial expression: left_id = right_distance
    let non_spatial: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("left_id", 0)),
        Operator::Eq,
        Arc::new(Column::new("right_distance", 3)),
    ));

    let result = extract_spatial_predicate(&non_spatial, &create_test_column_indices());

    // There is no spatial predicate to extract
    assert!(result.is_none());
}
#[test]
fn test_replace_join_filter_expr() {
    // Original join filter over all four intermediate columns, with a placeholder expression
    let placeholder: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Boolean(Some(true))));
    let original_filter = JoinFilter::new(
        placeholder,
        create_test_column_indices(),
        create_test_schema(),
    );

    // Replacement expression touching a single column: right_distance = 1
    let new_expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("right_distance", 3)),
        Operator::Eq,
        Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
    ));

    let new_filter = replace_join_filter_expr(&new_expr, &original_filter);

    // Only right_distance survives the pruning, in both the indices and the schema
    assert_eq!(new_filter.column_indices().len(), 1);
    assert_eq!(new_filter.schema().fields().len(), 1);
    let expected_index = ColumnIndex {
        index: 1,
        side: JoinSide::Right,
    };
    assert_eq!(new_filter.column_indices()[0], expected_index);

    // The column reference is renumbered to position 0 of the pruned layout
    let binary_expr = new_filter
        .expression()
        .as_any()
        .downcast_ref::<BinaryExpr>()
        .unwrap();
    assert_eq!(binary_expr.op(), &Operator::Eq);
    let lhs = binary_expr
        .left()
        .as_any()
        .downcast_ref::<Column>()
        .unwrap();
    assert_eq!(lhs.index(), 0);
}
#[test]
fn test_transform_join_filter_with_spatial_predicate() {
    // Join filter consisting solely of ST_Intersects(left_geom, right_geom)
    let left_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let right_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    let st_intersects: Arc<dyn PhysicalExpr> = create_spatial_function_expr(
        create_dummy_st_intersects_udf(),
        vec![left_geom, right_geom],
    );
    let join_filter = JoinFilter::new(
        st_intersects,
        create_test_column_indices(),
        create_test_schema(),
    );

    let result = transform_join_filter(&join_filter);
    assert!(result.is_some());
    let (spatial_pred, remainder) = result.unwrap();

    assert!(matches!(spatial_pred, SpatialPredicate::Relation(_)));
    // A filter that is only the spatial predicate leaves no remainder
    assert!(remainder.is_none());
}
#[test]
fn test_transform_join_filter_with_spatial_and_non_spatial() {
    // ST_DWithin(left_geom, right_geom, 1000.0) AND left_id = 42
    let left_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let right_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    let distance: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Float64(Some(1000.0))));
    let st_dwithin: Arc<dyn PhysicalExpr> = create_spatial_function_expr(
        create_dummy_st_dwithin_udf(),
        vec![left_geom, right_geom, distance],
    );
    let id_filter: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("left_id", 0)),
        Operator::Eq,
        Arc::new(Literal::new(ScalarValue::Int32(Some(42)))),
    ));
    let combined: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(st_dwithin, Operator::And, id_filter));
    let join_filter = JoinFilter::new(
        combined,
        create_test_column_indices(),
        create_test_schema(),
    );

    let result = transform_join_filter(&join_filter);
    assert!(result.is_some());
    let (spatial_pred, remainder) = result.unwrap();

    assert!(matches!(spatial_pred, SpatialPredicate::Distance(_)));
    // The id filter should come back as the remainder
    assert!(remainder.is_some());

    // The remainder only references left_id, so it should carry fewer columns
    let remainder_filter = remainder.unwrap();
    assert!(remainder_filter.column_indices().len() < join_filter.column_indices().len());
}
// A spatial predicate buried inside a chain of ANDs must be extracted, and the
// two non-spatial filters around it must be kept together as the remainder.
#[test]
fn test_complex_nested_spatial_and_filters() {
    // Leaves of:
    //   (left_id > 10 AND ST_Intersects(left_geom, right_geom)) AND right_distance < 500.0
    let id_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_id", 0));
    let geom_l: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let geom_r: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    let distance_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_distance", 3));
    let ten: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(10))));
    let five_hundred: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Float64(Some(500.0))));

    // left_id > 10
    let id_gt_ten: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(id_col, Operator::Gt, ten));
    // ST_Intersects(left_geom, right_geom)
    let intersects: Arc<dyn PhysicalExpr> =
        create_spatial_function_expr(create_dummy_st_intersects_udf(), vec![geom_l, geom_r]);
    // right_distance < 500.0
    let distance_lt: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(distance_col, Operator::Lt, five_hundred));

    // (left_id > 10 AND ST_Intersects(...)) AND right_distance < 500.0
    let nested_and: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(id_gt_ten, Operator::And, intersects));
    let full_filter: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(nested_and, Operator::And, distance_lt));

    let join_filter = JoinFilter::new(
        full_filter,
        create_test_column_indices(),
        create_test_schema(),
    );
    let outcome = transform_join_filter(&join_filter);
    assert!(outcome.is_some());
    let (spatial_pred, remainder) = outcome.unwrap();
    assert!(matches!(spatial_pred, SpatialPredicate::Relation(_)));
    // Both non-spatial filters have to survive in the remainder.
    assert!(remainder.is_some());
    let remainder = remainder.unwrap();

    // Root of the remainder: <left filter> AND <right filter>.
    let root = remainder
        .expression()
        .as_any()
        .downcast_ref::<BinaryExpr>()
        .unwrap();
    assert_eq!(root.op(), &Operator::And);

    // Left operand: left_id > 10, with left_id still at index 0.
    let gt_expr = root.left().as_any().downcast_ref::<BinaryExpr>().unwrap();
    assert_eq!(gt_expr.op(), &Operator::Gt);
    let gt_column = gt_expr.left().as_any().downcast_ref::<Column>().unwrap();
    assert_eq!(gt_column.name(), "left_id");
    assert_eq!(gt_column.index(), 0);

    // Right operand: right_distance < 500.0, with right_distance moved from index 3 to 1.
    let lt_expr = root.right().as_any().downcast_ref::<BinaryExpr>().unwrap();
    assert_eq!(lt_expr.op(), &Operator::Lt);
    let lt_column = lt_expr.left().as_any().downcast_ref::<Column>().unwrap();
    assert_eq!(lt_column.name(), "right_distance");
    assert_eq!(lt_column.index(), 1);

    // Only the two columns referenced by the remainder are kept in its column indices.
    assert_eq!(remainder.column_indices().len(), 2);
    let expected_first = ColumnIndex {
        index: 0,
        side: JoinSide::Left,
    };
    assert_eq!(remainder.column_indices()[0], expected_first);
    let expected_second = ColumnIndex {
        index: 1,
        side: JoinSide::Right,
    };
    assert_eq!(remainder.column_indices()[1], expected_second);
}
// Test helper: builds a stand-in scalar UDF named "st_knn".
//
// The signature is (geometry storage type, geometry storage type, Int32, Boolean)
// returning Boolean, and the implementation ignores its input and always yields
// the scalar `true`.
fn create_dummy_st_knn_udf() -> Arc<ScalarUDF> {
    let geometry_storage = WKB_GEOMETRY.storage_type().clone();
    let argument_types = vec![
        geometry_storage.clone(),
        geometry_storage,
        DataType::Int32,
        DataType::Boolean,
    ];
    let udf = SimpleScalarUDF::new(
        "st_knn",
        argument_types,
        DataType::Boolean,
        datafusion_expr::Volatility::Immutable,
        Arc::new(|_| Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))))),
    );
    Arc::new(ScalarUDF::from(udf))
}
// ST_KNN(left_geom, right_geom, 5, false) with literal k / use_spheroid is matched
// and its fields are populated from the call arguments.
#[test]
fn test_match_knn_predicate_basic() {
    let column_indices = create_test_column_indices();

    // Arguments of ST_KNN(left_geom, right_geom, 5, false)
    let geom_l: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let geom_r: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    let k: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(5))));
    let spheroid: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Boolean(Some(false))));

    let knn_call = create_spatial_function_expr(
        create_dummy_st_knn_udf(),
        vec![geom_l, geom_r, k, spheroid],
    );
    let matched = match_knn_predicate(&knn_call, &column_indices);
    assert!(matched.is_some());
    let pred = matched.unwrap();

    // First geometry argument: still left_geom at index 1.
    let first = pred.left.as_any().downcast_ref::<Column>().unwrap();
    assert_eq!(first.index(), 1);
    assert_eq!(first.name(), "left_geom");

    // Second geometry argument: right_geom, renumbered from 2 to 0.
    let second = pred.right.as_any().downcast_ref::<Column>().unwrap();
    assert_eq!(second.index(), 0);
    assert_eq!(second.name(), "right_geom");

    // k and use_spheroid carry the literal values 5 and false.
    assert_eq!(pred.k, 5);
    assert!(!pred.use_spheroid);
}
// ST_KNN called with its geometry arguments in (right, left) order.
//
// The assertions below pin that the predicate keeps the call's argument order:
// `pred.left` is the first argument (right_geom) and `pred.right` is the second
// (left_geom). The arguments are NOT swapped back into left/right table order.
#[test]
fn test_match_knn_predicate_inverted() {
    let column_indices = create_test_column_indices();

    // Arguments of ST_KNN(right_geom, left_geom, 3, false)
    let left_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let right_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    let k_literal: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(3))));
    let use_spheroid_literal: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Boolean(Some(false))));

    let st_knn_udf = create_dummy_st_knn_udf();
    // Note the order: the right-side geometry comes first.
    let args = vec![right_geom, left_geom, k_literal, use_spheroid_literal];
    let st_knn = create_spatial_function_expr(st_knn_udf, args);

    let predicate = match_knn_predicate(&st_knn, &column_indices);
    assert!(predicate.is_some());
    let pred = predicate.unwrap();

    // `pred.left` is the first call argument: right_geom, renumbered from 2 to 0.
    let left_arg_col = pred.left.as_any().downcast_ref::<Column>().unwrap();
    assert_eq!(left_arg_col.index(), 0);
    assert_eq!(left_arg_col.name(), "right_geom");

    // `pred.right` is the second call argument: left_geom, still at index 1.
    let right_arg_col = pred.right.as_any().downcast_ref::<Column>().unwrap();
    assert_eq!(right_arg_col.index(), 1);
    assert_eq!(right_arg_col.name(), "left_geom");
}
// ST_KNN whose two geometry arguments both reference left-side columns must not match.
#[test]
fn test_match_knn_predicate_same_side_fails() {
    let column_indices = create_test_column_indices();

    // Arguments of ST_KNN(left_geom, left_id, 5, false)
    let first_left: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let second_left: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_id", 0));
    let k: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(5))));
    let spheroid: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Boolean(Some(false))));

    let knn_call = create_spatial_function_expr(
        create_dummy_st_knn_udf(),
        vec![first_left, second_left, k, spheroid],
    );

    // Rejected: both geometry arguments come from the same join side.
    let matched = match_knn_predicate(&knn_call, &column_indices);
    assert!(matched.is_none());
}
// ST_KNN(left_geom, right_geom, 5, true) is accepted and the spheroid flag is kept.
#[test]
fn test_match_knn_predicate_spheroid_true_accepted() {
    let column_indices = create_test_column_indices();

    let geom_l: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let geom_r: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    let k: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(5))));
    // use_spheroid = true is the case under test.
    let spheroid: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Boolean(Some(true))));

    let knn_call = create_spatial_function_expr(
        create_dummy_st_knn_udf(),
        vec![geom_l, geom_r, k, spheroid],
    );
    let matched = match_knn_predicate(&knn_call, &column_indices);

    // The call matches, and both literal parameters are carried over.
    assert!(matched.is_some());
    let knn_pred = matched.unwrap();
    assert_eq!(knn_pred.k, 5);
    assert!(knn_pred.use_spheroid);
}
// A bare ST_KNN expression is extracted as a KNN predicate with no remainder.
#[test]
fn test_extract_spatial_predicate_knn() {
    let column_indices = create_test_column_indices();

    // Arguments of ST_KNN(left_geom, right_geom, 10, false)
    let geom_l: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let geom_r: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    let k: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(10))));
    let spheroid: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Boolean(Some(false))));

    let knn_expr: Arc<dyn PhysicalExpr> = create_spatial_function_expr(
        create_dummy_st_knn_udf(),
        vec![geom_l, geom_r, k, spheroid],
    );

    let extracted = extract_spatial_predicate(&knn_expr, &column_indices);
    assert!(extracted.is_some());
    let (spatial_pred, remainder) = extracted.unwrap();
    let SpatialPredicate::KNearestNeighbors(knn_pred) = spatial_pred else {
        panic!("Expected SpatialPredicate::KNearestNeighbors");
    };

    // Geometry operands end up at indices 1 and 0 respectively.
    let left_column = knn_pred.left.as_any().downcast_ref::<Column>().unwrap();
    assert_eq!(left_column.index(), 1);
    let right_column = knn_pred.right.as_any().downcast_ref::<Column>().unwrap();
    assert_eq!(right_column.index(), 0);

    // Nothing besides the spatial predicate was in the expression.
    assert!(remainder.is_none());
}
// `ST_KNN(...) AND left_id = 1` splits into the KNN predicate plus the equality
// filter as the remainder.
#[test]
fn test_extract_spatial_predicate_knn_with_and() {
    let column_indices = create_test_column_indices();

    // ST_KNN(left_geom, right_geom, 5, false)
    let geom_l: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let geom_r: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    let k: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(5))));
    let spheroid: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Boolean(Some(false))));
    let knn_expr: Arc<dyn PhysicalExpr> = create_spatial_function_expr(
        create_dummy_st_knn_udf(),
        vec![geom_l, geom_r, k, spheroid],
    );

    // left_id = 1
    let id_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_id", 0));
    let one: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));
    let id_equals_one: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(id_col, Operator::Eq, one));

    // ST_KNN(...) AND left_id = 1
    let conjunction: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(knn_expr, Operator::And, id_equals_one));

    let extracted = extract_spatial_predicate(&conjunction, &column_indices);
    assert!(extracted.is_some());
    let (spatial_pred, remainder) = extracted.unwrap();
    let SpatialPredicate::KNearestNeighbors(knn_pred) = spatial_pred else {
        panic!("Expected SpatialPredicate::KNearestNeighbors");
    };

    // Geometry operands end up at indices 1 and 0 respectively.
    let knn_left = knn_pred.left.as_any().downcast_ref::<Column>().unwrap();
    assert_eq!(knn_left.index(), 1);
    let knn_right = knn_pred.right.as_any().downcast_ref::<Column>().unwrap();
    assert_eq!(knn_right.index(), 0);

    // The id filter is what is left over.
    assert!(remainder.is_some());
    let remainder = remainder.unwrap();

    // Remainder shape: <column left_id @ 0> = <literal Int32(1)>
    let equality = remainder.as_any().downcast_ref::<BinaryExpr>().unwrap();
    assert_eq!(equality.op(), &Operator::Eq);

    let lhs = equality.left().as_any().downcast_ref::<Column>().unwrap();
    assert_eq!(lhs.name(), "left_id");
    assert_eq!(lhs.index(), 0);

    let rhs = equality.right().as_any().downcast_ref::<Literal>().unwrap();
    let ScalarValue::Int32(Some(val)) = rhs.value() else {
        panic!("Expected Int32(1) literal");
    };
    assert_eq!(val, &1);
}
// A join filter made of a single ST_KNN call transforms into a KNN predicate
// and leaves no remainder filter.
#[test]
fn test_transform_join_filter_with_knn_predicate() {
    // Arguments of ST_KNN(left_geom, right_geom, 3, false)
    let geom_l: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let geom_r: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    let k: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(3))));
    let spheroid: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Boolean(Some(false))));

    let knn_expr: Arc<dyn PhysicalExpr> = create_spatial_function_expr(
        create_dummy_st_knn_udf(),
        vec![geom_l, geom_r, k, spheroid],
    );
    let join_filter = JoinFilter::new(
        knn_expr,
        create_test_column_indices(),
        create_test_schema(),
    );

    let outcome = transform_join_filter(&join_filter);
    assert!(outcome.is_some());
    let (spatial_pred, remainder) = outcome.unwrap();
    assert!(matches!(
        spatial_pred,
        SpatialPredicate::KNearestNeighbors(_)
    ));
    // The whole filter was the spatial predicate, so nothing remains.
    assert!(remainder.is_none());
}
// ST_KNN called with three arguments (no use_spheroid) must not match.
#[test]
fn test_match_knn_predicate_insufficient_args() {
    let column_indices = create_test_column_indices();

    let geom_l: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let geom_r: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    let k: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(5))));

    // Only three arguments: the fourth (use_spheroid) is deliberately left out.
    let knn_call =
        create_spatial_function_expr(create_dummy_st_knn_udf(), vec![geom_l, geom_r, k]);

    let matched = match_knn_predicate(&knn_call, &column_indices);
    assert!(matched.is_none());
}
// A call with KNN-shaped arguments but a different function (the dummy
// ST_Intersects UDF) must not be matched as a KNN predicate.
#[test]
fn test_match_knn_predicate_wrong_function_name() {
    let column_indices = create_test_column_indices();

    let geom_l: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let geom_r: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    let k: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(5))));
    let spheroid: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Boolean(Some(false))));

    // Same four arguments as a valid ST_KNN call, but bound to the wrong UDF.
    let not_knn_call = create_spatial_function_expr(
        create_dummy_st_intersects_udf(),
        vec![geom_l, geom_r, k, spheroid],
    );

    let matched = match_knn_predicate(&not_knn_call, &column_indices);
    assert!(matched.is_none());
}
// ST_KNN whose geometry arguments are binary literals rather than column
// references must not match.
#[test]
fn test_match_knn_predicate_non_column_arguments() {
    let column_indices = create_test_column_indices();

    // Literal geometry arguments: neither one references a join column.
    let first_geom: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Binary(Some(
        b"POINT(0 0)".to_vec(),
    ))));
    let second_geom: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Binary(Some(
        b"POINT(1 1)".to_vec(),
    ))));
    let k: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(5))));
    let spheroid: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Boolean(Some(false))));

    let knn_call = create_spatial_function_expr(
        create_dummy_st_knn_udf(),
        vec![first_geom, second_geom, k, spheroid],
    );

    let matched = match_knn_predicate(&knn_call, &column_indices);
    assert!(matched.is_none());
}
// ST_KNN whose k argument is a computed expression (left_id + 2) instead of a
// literal must not match.
#[test]
fn test_match_knn_predicate_complex_k_expressions() {
    let column_indices = create_test_column_indices();

    let geom_l: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let geom_r: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));

    // k = left_id + 2
    let id_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_id", 0));
    let two: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(2))));
    let computed_k: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(id_col, Operator::Plus, two));

    let spheroid: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Boolean(Some(false))));

    let knn_call = create_spatial_function_expr(
        create_dummy_st_knn_udf(),
        vec![geom_l, geom_r, computed_k, spheroid],
    );

    // Rejected: k is not a plain literal.
    let matched = match_knn_predicate(&knn_call, &column_indices);
    assert!(matched.is_none());
}
// ST_KNN whose use_spheroid argument is a computed expression (left_id > 1)
// instead of a literal must not match.
#[test]
fn test_match_knn_predicate_complex_use_spheroid_expressions() {
    let column_indices = create_test_column_indices();

    let geom_l: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let geom_r: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    let k: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(5))));

    // use_spheroid = left_id > 1
    let id_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_id", 0));
    let one: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));
    let computed_spheroid: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(id_col, Operator::Gt, one));

    let knn_call = create_spatial_function_expr(
        create_dummy_st_knn_udf(),
        vec![geom_l, geom_r, k, computed_spheroid],
    );

    // Rejected: use_spheroid is not a plain literal.
    let matched = match_knn_predicate(&knn_call, &column_indices);
    assert!(matched.is_none());
}
// ST_KNN whose k argument combines a left-side and a right-side column must not match.
#[test]
fn test_match_knn_predicate_both_sides_in_k_expression() {
    let column_indices = create_test_column_indices();

    let geom_l: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let geom_r: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));

    // k = left_id + right_distance, i.e. it reads columns from both join inputs.
    let id_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_id", 0));
    let distance_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_distance", 3));
    let mixed_k: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(id_col, Operator::Plus, distance_col));

    let spheroid: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Boolean(Some(false))));

    let knn_call = create_spatial_function_expr(
        create_dummy_st_knn_udf(),
        vec![geom_l, geom_r, mixed_k, spheroid],
    );

    let matched = match_knn_predicate(&knn_call, &column_indices);
    assert!(matched.is_none());
}
// `ST_KNN(...) AND left_id > 0 AND right_distance < 100.0` yields the KNN
// predicate plus an AND-rooted remainder.
#[test]
fn test_extract_spatial_predicate_knn_with_multiple_clauses() {
    let column_indices = create_test_column_indices();

    // ST_KNN(left_geom, right_geom, 3, false)
    let geom_l: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let geom_r: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    let k: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(3))));
    let spheroid: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Boolean(Some(false))));
    let knn_expr: Arc<dyn PhysicalExpr> = create_spatial_function_expr(
        create_dummy_st_knn_udf(),
        vec![geom_l, geom_r, k, spheroid],
    );

    // left_id > 0
    let id_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_id", 0));
    let zero: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(0))));
    let id_positive: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(id_col, Operator::Gt, zero));

    // right_distance < 100.0
    let distance_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_distance", 3));
    let hundred: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Float64(Some(100.0))));
    let distance_small: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(distance_col, Operator::Lt, hundred));

    // (ST_KNN(...) AND left_id > 0) AND right_distance < 100.0
    let knn_and_id: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(knn_expr, Operator::And, id_positive));
    let whole_filter: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(knn_and_id, Operator::And, distance_small));

    let extracted = extract_spatial_predicate(&whole_filter, &column_indices);
    assert!(extracted.is_some());
    let (spatial_pred, remainder) = extracted.unwrap();
    let SpatialPredicate::KNearestNeighbors(knn_pred) = spatial_pred else {
        panic!("Expected SpatialPredicate::KNearestNeighbors");
    };

    // Literal parameters are carried over from the call.
    assert_eq!(knn_pred.k, 3);
    assert!(!knn_pred.use_spheroid);

    // The remainder exists and its root operator is AND.
    assert!(remainder.is_some());
    let remainder_expr = remainder.unwrap();
    let remainder_root = remainder_expr
        .as_any()
        .downcast_ref::<BinaryExpr>()
        .unwrap();
    assert_eq!(remainder_root.op(), &Operator::And);
}
// ST_KNN still matches when each geometry argument is a column wrapped in
// another expression (IsNotNull here) rather than a bare column.
#[test]
fn test_match_knn_predicate_nested_expressions() {
    let column_indices = create_test_column_indices();

    // Wrap each geometry column in an IsNotNull expression.
    let inner_left: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let inner_right: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    let wrapped_left: Arc<dyn PhysicalExpr> = Arc::new(IsNotNullExpr::new(inner_left));
    let wrapped_right: Arc<dyn PhysicalExpr> = Arc::new(IsNotNullExpr::new(inner_right));

    let k: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(5))));
    let spheroid: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Boolean(Some(false))));

    let knn_call = create_spatial_function_expr(
        create_dummy_st_knn_udf(),
        vec![wrapped_left, wrapped_right, k, spheroid],
    );
    let matched = match_knn_predicate(&knn_call, &column_indices);
    assert!(matched.is_some());
    let pred = matched.unwrap();

    // Left operand: IsNotNull(left_geom), inner column at index 1.
    let left_wrapper = pred.left.as_any().downcast_ref::<IsNotNullExpr>().unwrap();
    let left_inner = left_wrapper
        .arg()
        .as_any()
        .downcast_ref::<Column>()
        .unwrap();
    assert_eq!(left_inner.name(), "left_geom");
    assert_eq!(left_inner.index(), 1);

    // Right operand: IsNotNull(right_geom), inner column renumbered from 2 to 0.
    let right_wrapper = pred.right.as_any().downcast_ref::<IsNotNullExpr>().unwrap();
    let right_inner = right_wrapper
        .arg()
        .as_any()
        .downcast_ref::<Column>()
        .unwrap();
    assert_eq!(right_inner.name(), "right_geom");
    assert_eq!(right_inner.index(), 0);
}
// A standalone ST_KNN(left_geom, right_geom, 7, false) is extracted with its
// literal parameters and without any remainder.
#[test]
fn test_extract_spatial_predicate_knn_no_remainder() {
    let column_indices = create_test_column_indices();

    let geom_l: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let geom_r: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    let k: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(7))));
    let spheroid: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Boolean(Some(false))));

    // The expression is the ST_KNN call alone, with no AND/OR around it.
    let knn_expr: Arc<dyn PhysicalExpr> = create_spatial_function_expr(
        create_dummy_st_knn_udf(),
        vec![geom_l, geom_r, k, spheroid],
    );

    let extracted = extract_spatial_predicate(&knn_expr, &column_indices);
    assert!(extracted.is_some());
    let (spatial_pred, remainder) = extracted.unwrap();
    let SpatialPredicate::KNearestNeighbors(knn_pred) = spatial_pred else {
        panic!("Expected SpatialPredicate::KNearestNeighbors");
    };

    // k and use_spheroid come straight from the literals.
    assert_eq!(knn_pred.k, 7);
    assert!(!knn_pred.use_spheroid);

    // Nothing is left over.
    assert!(remainder.is_none());
}
// A join filter whose ST_KNN call takes k from a column (left_id) instead of a
// literal is not transformed.
#[test]
fn test_transform_join_filter_with_complex_knn_predicate() {
    let geom_l: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_geom", 1));
    let geom_r: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right_geom", 2));
    // The left_id column is passed in the k position.
    let k_from_column: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left_id", 0));
    let spheroid: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Boolean(Some(false))));

    let knn_expr: Arc<dyn PhysicalExpr> = create_spatial_function_expr(
        create_dummy_st_knn_udf(),
        vec![geom_l, geom_r, k_from_column, spheroid],
    );
    let join_filter = JoinFilter::new(
        knn_expr,
        create_test_column_indices(),
        create_test_schema(),
    );

    // No spatial predicate is produced because k is not a literal.
    let outcome = transform_join_filter(&join_filter);
    assert!(outcome.is_none());
}
// An Intersects relation is reported as supported on geometry-typed columns and
// as unsupported on geography-typed columns.
#[test]
fn test_is_spatial_predicate_supported() {
    // Geometry column on both sides: supported.
    let geometry_schema = Arc::new(Schema::new(vec![WKB_GEOMETRY
        .to_storage_field("geom", false)
        .unwrap()]));
    let geometry_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("geom", 0));
    let geometry_pred = SpatialPredicate::Relation(RelationPredicate::new(
        Arc::clone(&geometry_col),
        geometry_col,
        SpatialRelationType::Intersects,
    ));
    let geometry_supported =
        super::is_spatial_predicate_supported(&geometry_pred, &geometry_schema, &geometry_schema)
            .unwrap();
    assert!(geometry_supported);

    // Geography column on both sides: not supported.
    let geography_schema = Arc::new(Schema::new(vec![WKB_GEOGRAPHY
        .to_storage_field("geog", false)
        .unwrap()]));
    let geography_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("geog", 0));
    let geography_pred = SpatialPredicate::Relation(RelationPredicate::new(
        Arc::clone(&geography_col),
        geography_col,
        SpatialRelationType::Intersects,
    ));
    let geography_supported = super::is_spatial_predicate_supported(
        &geography_pred,
        &geography_schema,
        &geography_schema,
    )
    .unwrap();
    assert!(!geography_supported);
}
// KNN predicates are reported as supported for geometry inputs with either
// probe side, and as unsupported once either input schema holds geography.
#[test]
fn test_is_knn_predicate_supported() {
    // Left input: [geom]; right input: [id, geom].
    let left_schema = Arc::new(Schema::new(vec![WKB_GEOMETRY
        .to_storage_field("geom", false)
        .unwrap()]));
    let right_schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        WKB_GEOMETRY.to_storage_field("geom", false).unwrap(),
    ]));
    let left_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("geom", 0));
    let right_geom: Arc<dyn PhysicalExpr> = Arc::new(Column::new("geom", 1));

    // ST_KNN(left, right)
    let left_first = SpatialPredicate::KNearestNeighbors(KNNPredicate::new(
        Arc::clone(&left_geom),
        Arc::clone(&right_geom),
        5,
        false,
        JoinSide::Left,
    ));
    assert!(
        super::is_spatial_predicate_supported(&left_first, &left_schema, &right_schema).unwrap()
    );

    // ST_KNN(right, left)
    let right_first = SpatialPredicate::KNearestNeighbors(KNNPredicate::new(
        right_geom,
        left_geom,
        5,
        false,
        JoinSide::Right,
    ));
    assert!(
        super::is_spatial_predicate_supported(&right_first, &left_schema, &right_schema).unwrap()
    );

    // Same ST_KNN(right, left) predicate, but the left input is geography: unsupported.
    let left_geog_schema = Arc::new(Schema::new(vec![WKB_GEOGRAPHY
        .to_storage_field("geog", false)
        .unwrap()]));
    let with_left_geog =
        super::is_spatial_predicate_supported(&right_first, &left_geog_schema, &right_schema)
            .unwrap();
    assert!(!with_left_geog);

    // Same predicate, but the right input is geography: unsupported.
    let right_geog_schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        WKB_GEOGRAPHY.to_storage_field("geog", false).unwrap(),
    ]));
    let with_right_geog =
        super::is_spatial_predicate_supported(&right_first, &left_schema, &right_geog_schema)
            .unwrap();
    assert!(!with_right_geog);
}
// Exercises `is_spatial_predicate` on logical expressions: which shapes count as
// spatial predicates and which do not.
#[test]
fn test_is_spatial_predicate() {
    // Small builder for `left <op> right` logical expressions.
    let binary = |left: Expr, op: Operator, right: Expr| {
        Expr::BinaryExpr(datafusion_expr::expr::BinaryExpr {
            left: Box::new(left),
            op,
            right: Box::new(right),
        })
    };

    // ST_Intersects(geom1, geom2): spatial.
    let intersects_call = Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction {
        func: create_dummy_st_intersects_udf(),
        args: vec![col("geom1"), col("geom2")],
    });
    assert!(super::is_spatial_predicate(&intersects_call));

    // ST_Distance(geom1, geom2) < 100: spatial.
    let distance_call = Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction {
        func: create_dummy_st_distance_udf(),
        args: vec![col("geom1"), col("geom2")],
    });
    let distance_below = binary(distance_call.clone(), Operator::Lt, lit(100.0));
    assert!(super::is_spatial_predicate(&distance_below));

    // ST_Distance(geom1, geom2) > 100: not spatial.
    let distance_above = binary(distance_call, Operator::Gt, lit(100.0));
    assert!(!super::is_spatial_predicate(&distance_above));

    // ST_Intersects(...) AND id = 1: spatial, because one conjunct is.
    let spatial_conjunction = binary(intersects_call, Operator::And, col("id").eq(lit(1)));
    assert!(super::is_spatial_predicate(&spatial_conjunction));

    // Plain column comparison: not spatial.
    let plain_comparison = col("id").eq(lit(1));
    assert!(!super::is_spatial_predicate(&plain_comparison));

    // A scalar function that is not one of the spatial relationship functions: not spatial.
    let unrelated_udf = Arc::new(ScalarUDF::from(SimpleScalarUDF::new(
        "st_non_spatial_relation_func",
        vec![DataType::Int32],
        DataType::Boolean,
        datafusion_expr::Volatility::Immutable,
        Arc::new(|_| Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))))),
    )));
    let unrelated_call = Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction {
        func: unrelated_udf,
        args: vec![col("id")],
    });
    assert!(!super::is_spatial_predicate(&unrelated_call));

    // AND of two non-spatial comparisons: not spatial.
    let plain_conjunction = binary(
        col("id").eq(lit(1)),
        Operator::And,
        col("name").eq(lit("test")),
    );
    assert!(!super::is_spatial_predicate(&plain_conjunction));
}
}