[SPARK-48065][SQL] SPJ: allowJoinKeysSubsetOfPartitionKeys is too strict

### What changes were proposed in this pull request?
If spark.sql.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled is true, change KeyGroupedPartitioning.satisfies0(distribution) check from all clustering keys (here, join keys)  being in partition keys, to the two sets overlapping.

  ### Why are the changes needed?
If spark.sql.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled is true, then SPJ no longer triggers if there are more join keys than partition keys. But SPJ is supported in this case if flag is false.

  ### Does this PR introduce _any_ user-facing change?
No

  ### How was this patch tested?
Added tests in KeyGroupedPartitioningSuite

 ### Was this patch authored or co-authored using generative AI tooling?
No

Closes #46325 from szehon-ho/fix_spj_less_join_key.

Authored-by: Szehon Ho <szehon.apache@gmail.com>
Signed-off-by: Chao Sun <chao@openai.com>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 2364130..43aba47 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -385,8 +385,9 @@
             val attributes = expressions.flatMap(_.collectLeaves())
 
             if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
-              // check that all join keys (required clustering keys) contained in partitioning
-              requiredClustering.forall(x => attributes.exists(_.semanticEquals(x))) &&
+              // check that join keys (required clustering keys)
+              // overlap with partition keys (KeyGroupedPartitioning attributes)
+              requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
                   expressions.forall(_.collectLeaves().size == 1)
             } else {
               attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index ec275fe..10a3244 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -1227,6 +1227,66 @@
     }
   }
 
+  test("SPARK-48065: SPJ: allowJoinKeysSubsetOfPartitionKeys is too strict") {
+    val table1 = "tab1e1"
+    val table2 = "table2"
+    val partition = Array(identity("id"))
+    createTable(table1, columns, partition)
+    sql(s"INSERT INTO testcat.ns.$table1 VALUES " +
+        "(1, 'aa', cast('2020-01-01' as timestamp)), " +
+        "(2, 'bb', cast('2020-01-01' as timestamp)), " +
+        "(2, 'cc', cast('2020-01-01' as timestamp)), " +
+        "(3, 'dd', cast('2020-01-01' as timestamp)), " +
+        "(3, 'dd', cast('2020-01-01' as timestamp)), " +
+        "(3, 'ee', cast('2020-01-01' as timestamp)), " +
+        "(3, 'ee', cast('2020-01-01' as timestamp))")
+
+    createTable(table2, columns, partition)
+    sql(s"INSERT INTO testcat.ns.$table2 VALUES " +
+        "(4, 'zz', cast('2020-01-01' as timestamp)), " +
+        "(4, 'zz', cast('2020-01-01' as timestamp)), " +
+        "(3, 'dd', cast('2020-01-01' as timestamp)), " +
+        "(3, 'dd', cast('2020-01-01' as timestamp)), " +
+        "(3, 'xx', cast('2020-01-01' as timestamp)), " +
+        "(3, 'xx', cast('2020-01-01' as timestamp)), " +
+        "(2, 'ww', cast('2020-01-01' as timestamp))")
+
+    Seq(true, false).foreach { pushDownValues =>
+      Seq(true, false).foreach { partiallyClustered =>
+        withSQLConf(
+          SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+          SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
+          SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key ->
+            partiallyClustered.toString,
+          SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
+          val df = sql(
+            s"""
+               |${selectWithMergeJoinHint("t1", "t2")}
+               |t1.id AS id, t1.data AS t1data, t2.data AS t2data
+               |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2
+               |ON t1.id = t2.id AND t1.data = t2.data ORDER BY t1.id, t1data, t2data
+               |""".stripMargin)
+          val shuffles = collectShuffles(df.queryExecution.executedPlan)
+          assert(shuffles.isEmpty, "SPJ should be triggered")
+
+          val scans = collectScans(df.queryExecution.executedPlan)
+            .map(_.inputRDD.partitions.length)
+          if (partiallyClustered) {
+            assert(scans == Seq(8, 8))
+          } else {
+            assert(scans == Seq(4, 4))
+          }
+          checkAnswer(df, Seq(
+            Row(3, "dd", "dd"),
+            Row(3, "dd", "dd"),
+            Row(3, "dd", "dd"),
+            Row(3, "dd", "dd")
+          ))
+        }
+      }
+    }
+  }
+
   test("SPARK-44647: test join key is subset of cluster key " +
       "with push values and partially-clustered") {
     val table1 = "tab1e1"