[SPARK-48065][SQL] SPJ: allowJoinKeysSubsetOfPartitionKeys is too strict
### What changes were proposed in this pull request?
If spark.sql.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled is true, change KeyGroupedPartitioning.satisfies0(distribution) check from all clustering keys (here, join keys) being in partition keys, to the two sets overlapping.
### Why are the changes needed?
If spark.sql.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled is true, then SPJ no longer triggers if there are more join keys than partition keys. But SPJ is supported in this case if flag is false.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added tests in KeyGroupedPartitioningSuite
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #46325 from szehon-ho/fix_spj_less_join_key.
Authored-by: Szehon Ho <szehon.apache@gmail.com>
Signed-off-by: Chao Sun <chao@openai.com>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 2364130..43aba47 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -385,8 +385,9 @@
val attributes = expressions.flatMap(_.collectLeaves())
if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
- // check that all join keys (required clustering keys) contained in partitioning
- requiredClustering.forall(x => attributes.exists(_.semanticEquals(x))) &&
+ // check that join keys (required clustering keys)
+ // overlap with partition keys (KeyGroupedPartitioning attributes)
+ requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
expressions.forall(_.collectLeaves().size == 1)
} else {
attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index ec275fe..10a3244 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -1227,6 +1227,66 @@
}
}
+ test("SPARK-48065: SPJ: allowJoinKeysSubsetOfPartitionKeys is too strict") {
+ val table1 = "tab1e1"
+ val table2 = "table2"
+ val partition = Array(identity("id"))
+ createTable(table1, columns, partition)
+ sql(s"INSERT INTO testcat.ns.$table1 VALUES " +
+ "(1, 'aa', cast('2020-01-01' as timestamp)), " +
+ "(2, 'bb', cast('2020-01-01' as timestamp)), " +
+ "(2, 'cc', cast('2020-01-01' as timestamp)), " +
+ "(3, 'dd', cast('2020-01-01' as timestamp)), " +
+ "(3, 'dd', cast('2020-01-01' as timestamp)), " +
+ "(3, 'ee', cast('2020-01-01' as timestamp)), " +
+ "(3, 'ee', cast('2020-01-01' as timestamp))")
+
+ createTable(table2, columns, partition)
+ sql(s"INSERT INTO testcat.ns.$table2 VALUES " +
+ "(4, 'zz', cast('2020-01-01' as timestamp)), " +
+ "(4, 'zz', cast('2020-01-01' as timestamp)), " +
+ "(3, 'dd', cast('2020-01-01' as timestamp)), " +
+ "(3, 'dd', cast('2020-01-01' as timestamp)), " +
+ "(3, 'xx', cast('2020-01-01' as timestamp)), " +
+ "(3, 'xx', cast('2020-01-01' as timestamp)), " +
+ "(2, 'ww', cast('2020-01-01' as timestamp))")
+
+ Seq(true, false).foreach { pushDownValues =>
+ Seq(true, false).foreach { partiallyClustered =>
+ withSQLConf(
+ SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+ SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
+ SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key ->
+ partiallyClustered.toString,
+ SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
+ val df = sql(
+ s"""
+ |${selectWithMergeJoinHint("t1", "t2")}
+ |t1.id AS id, t1.data AS t1data, t2.data AS t2data
+ |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2
+ |ON t1.id = t2.id AND t1.data = t2.data ORDER BY t1.id, t1data, t2data
+ |""".stripMargin)
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ assert(shuffles.isEmpty, "SPJ should be triggered")
+
+ val scans = collectScans(df.queryExecution.executedPlan)
+ .map(_.inputRDD.partitions.length)
+ if (partiallyClustered) {
+ assert(scans == Seq(8, 8))
+ } else {
+ assert(scans == Seq(4, 4))
+ }
+ checkAnswer(df, Seq(
+ Row(3, "dd", "dd"),
+ Row(3, "dd", "dd"),
+ Row(3, "dd", "dd"),
+ Row(3, "dd", "dd")
+ ))
+ }
+ }
+ }
+ }
+
test("SPARK-44647: test join key is subset of cluster key " +
"with push values and partially-clustered") {
val table1 = "tab1e1"