Add optional output projection to logical Join and HashJoinExec
diff --git a/datafusion/core/src/physical_optimizer/dist_enforcement.rs b/datafusion/core/src/physical_optimizer/dist_enforcement.rs
index cb98e69..1e539d1 100644
--- a/datafusion/core/src/physical_optimizer/dist_enforcement.rs
+++ b/datafusion/core/src/physical_optimizer/dist_enforcement.rs
@@ -154,6 +154,7 @@
join_type,
mode,
null_equals_null,
+ projection,
..
}) = plan_any.downcast_ref::<HashJoinExec>()
{
@@ -169,6 +170,7 @@
join_type,
PartitionMode::Partitioned,
*null_equals_null,
+ projection.clone(), // keep the matched join's output projection
)?) as Arc<dyn ExecutionPlan>)
};
Some(reorder_partitioned_join_keys(
@@ -541,6 +543,7 @@
join_type,
mode,
null_equals_null,
+ projection,
..
}) = plan_any.downcast_ref::<HashJoinExec>()
{
@@ -570,6 +573,7 @@
join_type,
PartitionMode::Partitioned,
*null_equals_null,
+ projection.clone(), // keep the matched join's output projection
)?))
} else {
Ok(plan)
@@ -1123,6 +1127,7 @@
join_type,
PartitionMode::Partitioned,
false,
+ None,
)
.unwrap(),
)
diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs
index a9dec73..ac24c62 100644
--- a/datafusion/core/src/physical_optimizer/join_selection.rs
+++ b/datafusion/core/src/physical_optimizer/join_selection.rs
@@ -137,6 +137,7 @@
&swap_join_type(*hash_join.join_type()),
partition_mode,
hash_join.null_equals_null(),
+ None, // NOTE(review): drops `hash_join`'s output projection on swap — confirm the swapped plan keeps the projected schema
)?;
if matches!(
hash_join.join_type(),
@@ -333,6 +334,7 @@
hash_join.join_type(),
PartitionMode::CollectLeft,
hash_join.null_equals_null(),
+ hash_join.projection.clone(),
)?)))
}
}
@@ -344,6 +346,7 @@
hash_join.join_type(),
PartitionMode::CollectLeft,
hash_join.null_equals_null(),
+ hash_join.projection.clone(),
)?))),
(false, true) => {
if supports_swap(*hash_join.join_type()) {
@@ -371,6 +374,7 @@
hash_join.join_type(),
PartitionMode::Partitioned,
hash_join.null_equals_null(),
+ hash_join.projection.clone(),
)?))
}
}
@@ -495,6 +499,7 @@
&JoinType::Left,
PartitionMode::CollectLeft,
false,
+ None,
)
.unwrap();
@@ -543,6 +548,7 @@
&JoinType::Left,
PartitionMode::CollectLeft,
false,
+ None,
)
.unwrap();
@@ -594,6 +600,7 @@
&join_type,
PartitionMode::Partitioned,
false,
+ None,
)
.unwrap();
@@ -659,6 +666,7 @@
&JoinType::Inner,
PartitionMode::CollectLeft,
false,
+ None,
)
.unwrap();
let child_schema = child_join.schema();
@@ -675,6 +683,7 @@
&JoinType::Left,
PartitionMode::CollectLeft,
false,
+ None,
)
.unwrap();
@@ -712,6 +721,7 @@
&JoinType::Inner,
PartitionMode::CollectLeft,
false,
+ None,
)
.unwrap();
@@ -937,6 +947,7 @@
&JoinType::Inner,
PartitionMode::Auto,
false,
+ None,
)
.unwrap();
diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs
index caae774..cfe3455 100644
--- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs
+++ b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs
@@ -644,6 +644,7 @@
&t.initial_join_type,
t.initial_mode,
false,
+ None, // no output projection
)?;
let initial_hash_join_state = PipelineStatePropagator {
diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs b/datafusion/core/src/physical_plan/joins/hash_join.rs
index a3c553c..b38811e 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join.rs
@@ -122,6 +122,8 @@
column_indices: Vec<ColumnIndex>,
/// If null_equals_null is true, null == null else null != null
pub(crate) null_equals_null: bool,
+ /// Optional output projection, forwarded to `build_join_schema`; `None` keeps every join column
+ pub projection: Option<Vec<Column>>,
}
impl HashJoinExec {
@@ -136,6 +138,7 @@
join_type: &JoinType,
partition_mode: PartitionMode,
null_equals_null: bool,
+ projection: Option<Vec<Column>>,
) -> Result<Self> {
let left_schema = left.schema();
let right_schema = right.schema();
@@ -148,7 +151,7 @@
check_join_is_valid(&left_schema, &right_schema, &on)?;
let (schema, column_indices) =
- build_join_schema(&left_schema, &right_schema, join_type);
+ build_join_schema(&left_schema, &right_schema, join_type, projection.clone());
let random_state = RandomState::with_seeds(0, 0, 0, 0);
@@ -165,6 +168,7 @@
metrics: ExecutionPlanMetricsSet::new(),
column_indices,
null_equals_null,
+ projection,
})
}
@@ -337,6 +341,7 @@
&self.join_type,
self.mode,
self.null_equals_null,
+ self.projection.clone(),
)?))
}
@@ -1358,6 +1363,7 @@
join_type,
PartitionMode::CollectLeft,
null_equals_null,
+ None,
)
}
@@ -1377,6 +1383,7 @@
join_type,
PartitionMode::CollectLeft,
null_equals_null,
+ None,
)
}
@@ -1431,6 +1438,7 @@
join_type,
PartitionMode::Partitioned,
null_equals_null,
+ None,
)?;
let columns = columns(&join.schema());
@@ -3164,6 +3172,7 @@
&join_type,
PartitionMode::Partitioned,
false,
+ None,
)?;
let stream = join.execute(1, task_ctx)?;
diff --git a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
index 6586456..5a2b0e8 100644
--- a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
+++ b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
@@ -106,7 +106,7 @@
let right_schema = right.schema();
check_join_is_valid(&left_schema, &right_schema, &[])?;
let (schema, column_indices) =
- build_join_schema(&left_schema, &right_schema, join_type);
+ build_join_schema(&left_schema, &right_schema, join_type, None); // no output projection
Ok(NestedLoopJoinExec {
left,
right,
diff --git a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs
index bc8c686..324f858 100644
--- a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs
+++ b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs
@@ -177,7 +177,7 @@
};
let schema =
- Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0);
+ Arc::new(build_join_schema(&left_schema, &right_schema, &join_type, None).0); // no output projection
Ok(Self {
left,
diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
index b46aba2..7df848e 100644
--- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
@@ -295,7 +295,7 @@
// Build the join schema from the left and right schemas:
let (schema, column_indices) =
- build_join_schema(&left_schema, &right_schema, join_type);
+ build_join_schema(&left_schema, &right_schema, join_type, None); // no output projection
// Initialize the random state for the join operation:
let random_state = RandomState::with_seeds(0, 0, 0, 0);
@@ -1862,6 +1862,7 @@
join_type,
PartitionMode::Partitioned,
null_equals_null,
+ None,
)?;
let mut batches = vec![];
@@ -3026,7 +3027,7 @@
// Build the join schema from the left and right schemas
let (schema, join_column_indices) =
- build_join_schema(&left_schema, &right_schema, &join_type);
+ build_join_schema(&left_schema, &right_schema, &join_type, None);
let join_schema = Arc::new(schema);
// Sort information for MemoryExec
diff --git a/datafusion/core/src/physical_plan/joins/utils.rs b/datafusion/core/src/physical_plan/joins/utils.rs
index 627bdee..7722013 100644
--- a/datafusion/core/src/physical_plan/joins/utils.rs
+++ b/datafusion/core/src/physical_plan/joins/utils.rs
@@ -350,6 +350,7 @@
left: &Schema,
right: &Schema,
join_type: &JoinType,
+ projection: Option<Vec<Column>>, // NOTE(review): no hunk in this patch reads `projection`, so the returned schema presumably still has every join column — confirm
) -> (Schema, Vec<ColumnIndex>) {
let (fields, column_indices): (SchemaBuilder, Vec<ColumnIndex>) = match join_type {
JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
@@ -1197,7 +1198,7 @@
];
for (left_in, right_in, join_type, left_out, right_out) in cases {
- let (schema, _) = build_join_schema(left_in, right_in, &join_type);
+ let (schema, _) = build_join_schema(left_in, right_in, &join_type, None);
let expected_fields = left_out
.fields()
diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index 7556620..525a6e3 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -901,6 +901,7 @@
join_type,
null_equals_null,
schema: join_schema,
+ projection,
..
}) => {
let null_equals_null = *null_equals_null;
@@ -990,6 +991,37 @@
})
.collect::<Result<join_utils::JoinOn>>()?;
+ // Translate the logical output projection into physical columns whose
+ // index is the position in the unprojected join output: left then right
+ // fields, only left fields for LeftSemi/LeftAnti, and only right fields
+ // for RightSemi/RightAnti.
+ // NOTE(review): in this patch only the HashJoinExec branches receive this
+ // projection; confirm other join operators still match `join_schema`.
+ let projection: Option<Vec<Column>> = projection
+ .as_ref()
+ .map(|proj| {
+ proj.iter()
+ .map(|col| -> Result<Column> {
+ let index = match join_type {
+ datafusion_expr::JoinType::LeftSemi
+ | datafusion_expr::JoinType::LeftAnti => {
+ left_df_schema.index_of_column(col)?
+ }
+ datafusion_expr::JoinType::RightSemi
+ | datafusion_expr::JoinType::RightAnti => {
+ right_df_schema.index_of_column(col)?
+ }
+ _ => match left_df_schema.index_of_column(col) {
+ Ok(index) => index,
+ Err(_) => left_df_schema.fields().len() + right_df_schema.index_of_column(col)?,
+ },
+ };
+ Ok(Column::new(&col.name, index))
+ })
+ .collect::<Result<Vec<_>>>()
+ })
+ .transpose()?;
+
let join_filter = match filter {
Some(expr) => {
// Extract columns from filter expression and saved in a HashSet
@@ -1095,6 +1127,7 @@
join_type,
partition_mode,
null_equals_null,
+ projection.clone(),
)?))
} else {
Ok(Arc::new(HashJoinExec::try_new(
@@ -1105,6 +1138,7 @@
join_type,
PartitionMode::CollectLeft,
null_equals_null,
+ projection.clone(),
)?))
}
}
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 3d34c08..34ae0eb 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -721,7 +721,7 @@
.map(|(l, r)| (Expr::Column(l), Expr::Column(r)))
.collect();
let join_schema =
- build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
+ build_join_schema(self.plan.schema(), right.schema(), &join_type, None)?;
Ok(Self::from(LogicalPlan::Join(Join {
left: Arc::new(self.plan),
@@ -732,6 +732,7 @@
join_constraint: JoinConstraint::On,
schema: DFSchemaRef::new(join_schema),
null_equals_null,
+ projection: None,
})))
}
@@ -754,7 +755,7 @@
let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect();
let join_schema =
- build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
+ build_join_schema(self.plan.schema(), right.schema(), &join_type, None)?;
let mut join_on: Vec<(Expr, Expr)> = vec![];
let mut filters: Option<Expr> = None;
for (l, r) in &on {
@@ -796,6 +797,7 @@
join_constraint: JoinConstraint::Using,
schema: DFSchemaRef::new(join_schema),
null_equals_null: false,
+ projection: None,
})))
}
}
@@ -1012,7 +1014,7 @@
.collect::<Result<Vec<_>>>()?;
let join_schema =
- build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
+ build_join_schema(self.plan.schema(), right.schema(), &join_type, None)?;
Ok(Self::from(LogicalPlan::Join(Join {
left: Arc::new(self.plan),
@@ -1023,6 +1025,7 @@
join_constraint: JoinConstraint::On,
schema: DFSchemaRef::new(join_schema),
null_equals_null: false,
+ projection: None,
})))
}
@@ -1038,6 +1041,7 @@
left: &DFSchema,
right: &DFSchema,
join_type: &JoinType,
+ projection: Option<&Vec<Column>>,
) -> Result<DFSchema> {
fn nullify_fields(fields: &[DFField]) -> Vec<DFField> {
fields
@@ -1049,51 +1053,64 @@
let right_fields = right.fields();
let left_fields = left.fields();
let fields: Vec<DFField> = match join_type {
JoinType::Inner => {
// left then right
left_fields
.iter()
.chain(right_fields.iter())
.cloned()
.collect()
}
JoinType::Left => {
// left then right, right set to nullable in case of not matched scenario
left_fields
.iter()
.chain(&nullify_fields(right_fields))
.cloned()
.collect()
}
JoinType::Right => {
// left then right, left set to nullable in case of not matched scenario
nullify_fields(left_fields)
.iter()
.chain(right_fields.iter())
.cloned()
.collect()
}
JoinType::Full => {
// left then right, all set to nullable in case of not matched scenario
nullify_fields(left_fields)
.iter()
.chain(&nullify_fields(right_fields))
.cloned()
.collect()
}
JoinType::LeftSemi | JoinType::LeftAnti => {
// Only use the left side for the schema
left_fields.clone()
}
JoinType::RightSemi | JoinType::RightAnti => {
// Only use the right side for the schema
right_fields.clone()
}
};
let mut metadata = left.metadata().clone();
metadata.extend(right.metadata().clone());
- DFSchema::new_with_metadata(fields, metadata)
+ let schema = DFSchema::new_with_metadata(fields, metadata)?;
+ match projection {
+ // Keep only the requested columns, in the requested order. They are looked
+ // up in the full join schema so the nullability applied above for outer
+ // joins (and the single-side output of semi/anti joins) is preserved.
+ Some(columns) => {
+ let projected_fields = columns
+ .iter()
+ .map(|col| schema.field_from_column(col).cloned())
+ .collect::<Result<Vec<_>>>()?;
+ DFSchema::new_with_metadata(projected_fields, schema.metadata().clone())
+ }
+ None => Ok(schema),
+ }
}
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index ab45047..d451791 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -1660,6 +1660,8 @@
pub schema: DFSchemaRef,
/// If null_equals_null is true, null == null else null != null
pub null_equals_null: bool,
+ /// Optional output projection: the join output columns to keep, in order (`None` keeps all)
+ pub projection: Option<Vec<Column>>,
}
impl Join {
@@ -1681,8 +1683,12 @@
.zip(column_on.1.into_iter())
.map(|(l, r)| (Expr::Column(l), Expr::Column(r)))
.collect();
- let join_schema =
- build_join_schema(left.schema(), right.schema(), &original_join.join_type)?;
+ let join_schema = build_join_schema(
+ left.schema(),
+ right.schema(),
+ &original_join.join_type,
+ original_join.projection.as_ref(),
+ )?;
Ok(Join {
left,
@@ -1693,6 +1699,8 @@
join_constraint: original_join.join_constraint,
schema: Arc::new(join_schema),
null_equals_null: original_join.null_equals_null,
+ // Must be the projection used to build `join_schema` above so the two agree.
+ projection: original_join.projection.clone(),
})
}
}
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 2b6fc57..2ae116e 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -845,10 +845,15 @@
join_constraint,
on,
null_equals_null,
+ projection,
..
}) => {
- let schema =
- build_join_schema(inputs[0].schema(), inputs[1].schema(), join_type)?;
+ let schema = build_join_schema(
+ inputs[0].schema(),
+ inputs[1].schema(),
+ join_type,
+ projection.as_ref(),
+ )?;
let equi_expr_count = on.len();
assert!(expr.len() >= equi_expr_count);
@@ -881,6 +886,7 @@
filter: filter_expr,
schema: DFSchemaRef::new(schema),
null_equals_null: *null_equals_null,
+ projection: projection.clone(), // same projection used to rebuild `schema` above
}))
}
LogicalPlan::CrossJoin(_) => {
diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs
index 533566a..e32311d 100644
--- a/datafusion/optimizer/src/eliminate_cross_join.rs
+++ b/datafusion/optimizer/src/eliminate_cross_join.rs
@@ -214,6 +214,7 @@
left_input.schema(),
right_input.schema(),
&JoinType::Inner,
+ None,
)?);
return Ok(LogicalPlan::Join(Join {
@@ -225,6 +226,7 @@
filter: None,
schema: join_schema,
null_equals_null: false,
+ projection: None, // matches `join_schema`, which was built without a projection
}));
}
}
@@ -233,6 +235,7 @@
left_input.schema(),
right.schema(),
&JoinType::Inner,
+ None,
)?);
Ok(LogicalPlan::CrossJoin(CrossJoin {
diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs
index e4d57f0..e2df09d 100644
--- a/datafusion/optimizer/src/eliminate_outer_join.rs
+++ b/datafusion/optimizer/src/eliminate_outer_join.rs
@@ -105,6 +105,7 @@
filter: join.filter.clone(),
schema: join.schema.clone(),
null_equals_null: join.null_equals_null,
+ projection: join.projection.clone(), // kept together with the reused `schema`
});
let new_plan = plan.with_new_inputs(&[new_join])?;
Ok(Some(new_plan))
diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs
index 20b9c62..41db207 100644
--- a/datafusion/optimizer/src/extract_equijoin_predicate.rs
+++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs
@@ -55,6 +55,7 @@
join_constraint,
schema,
null_equals_null,
+ projection,
}) => {
let left_schema = left.schema();
let right_schema = right.schema();
@@ -80,6 +81,7 @@
join_constraint: *join_constraint,
schema: schema.clone(),
null_equals_null: *null_equals_null,
+ projection: projection.clone(), // kept together with the reused `schema`
})
});
diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs
index 6703a1d..42915b2 100644
--- a/datafusion/optimizer/src/push_down_limit.rs
+++ b/datafusion/optimizer/src/push_down_limit.rs
@@ -266,6 +266,7 @@
join_constraint: join.join_constraint,
schema: join.schema.clone(),
null_equals_null: join.null_equals_null,
+ projection: join.projection.clone(), // kept together with the reused `schema`
})
}
}
diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs
index 4773a94..65409b0 100644
--- a/datafusion/optimizer/src/push_down_projection.rs
+++ b/datafusion/optimizer/src/push_down_projection.rs
@@ -101,6 +101,10 @@
for e in projection.expr.iter() {
expr_to_columns(e, &mut push_columns)?;
}
+
+ // Columns read by the parent projection, captured before the join keys are added below
+ let output_columns = push_columns.clone();
+
for (l, r) in join.on.iter() {
expr_to_columns(l, &mut push_columns)?;
expr_to_columns(r, &mut push_columns)?;
@@ -119,9 +123,39 @@
join.right.schema(),
join.right.clone(),
)?;
- let new_join = child_plan.with_new_inputs(&[new_left, new_right])?;
- generate_plan!(projection_is_empty, plan, new_join)
+ let mut join = join.clone();
+ join.left = Arc::new(new_left);
+ join.right = Arc::new(new_right);
+
+ // Resolve the columns read by the parent projection to their positions in
+ // the current join schema. `output_columns` is a hash set, so collecting it
+ // directly would give the join a nondeterministic output column order.
+ let mut output_indices = output_columns
+ .iter()
+ .map(|col| join.schema.index_of_column(col))
+ .collect::<Result<Vec<_>>>()?;
+ output_indices.sort_unstable();
+ output_indices.dedup();
+ // With no columns requested, keep the existing output rather than building
+ // a zero-column join.
+ if !output_indices.is_empty() {
+ join.projection = Some(
+ output_indices
+ .into_iter()
+ .map(|index| join.schema.field(index).qualified_column())
+ .collect(),
+ );
+ }
+ // Rebuild the schema so it agrees with the pruned inputs and the projection.
+ join.schema = Arc::new(datafusion_expr::logical_plan::build_join_schema(
+ join.left.schema(),
+ join.right.schema(),
+ &join.join_type,
+ join.projection.as_ref(),
+ )?);
+
+ generate_plan!(projection_is_empty, plan, LogicalPlan::Join(join))
}
LogicalPlan::CrossJoin(join) => {
// collect column in on/filter in join and projection.