blob: e102f08e3f85ace23eb9152328e1da7b5fe393f5 [file] [log] [blame]
diff --git a/pom.xml b/pom.xml
index a4b1b2c3c9f..6a532749978 100644
--- a/pom.xml
+++ b/pom.xml
@@ -147,6 +147,8 @@
<chill.version>0.10.0</chill.version>
<ivy.version>2.5.2</ivy.version>
<oro.version>2.0.8</oro.version>
+ <spark.version.short>4.0</spark.version.short>
+ <comet.version>0.7.0-SNAPSHOT</comet.version>
<!--
If you change codahale.metrics.version, you also need to change
the link to metrics.dropwizard.io in docs/monitoring.md.
@@ -2848,6 +2850,25 @@
<artifactId>arpack</artifactId>
<version>${netlib.ludovic.dev.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.datafusion</groupId>
+ <artifactId>comet-spark-spark${spark.version.short}_${scala.binary.version}</artifactId>
+ <version>${comet.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_${scala.binary.version}</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
<!-- SPARK-16484 add `datasketches-java` for support Datasketches HllSketch -->
<dependency>
<groupId>org.apache.datasketches</groupId>
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 19f6303be36..6c0e77882e6 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -77,6 +77,10 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-tags_${scala.binary.version}</artifactId>
</dependency>
+ <dependency>
+ <groupId>org.apache.datafusion</groupId>
+ <artifactId>comet-spark-spark${spark.version.short}_${scala.binary.version}</artifactId>
+ </dependency>
<!--
This spark-tags test-dep is needed even though it isn't used in this module, otherwise testing-cmds that exclude
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 466e4cf8131..798f118464c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -1376,6 +1376,14 @@ object SparkSession extends Logging {
}
}
+ private def loadCometExtension(sparkContext: SparkContext): Seq[String] = {
+ if (sparkContext.getConf.getBoolean("spark.comet.enabled", isCometEnabled)) {
+ Seq("org.apache.comet.CometSparkSessionExtensions")
+ } else {
+ Seq.empty
+ }
+ }
+
/**
* Initialize extensions specified in [[StaticSQLConf]]. The classes will be applied to the
* extensions passed into this function.
@@ -1385,7 +1393,8 @@ object SparkSession extends Logging {
extensions: SparkSessionExtensions): SparkSessionExtensions = {
val extensionConfClassNames = sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
.getOrElse(Seq.empty)
- extensionConfClassNames.foreach { extensionConfClassName =>
+ val extensionClassNames = extensionConfClassNames ++ loadCometExtension(sparkContext)
+ extensionClassNames.foreach { extensionConfClassName =>
try {
val extensionConfClass = Utils.classForName(extensionConfClassName)
val extensionConf = extensionConfClass.getConstructor().newInstance()
@@ -1420,4 +1429,12 @@ object SparkSession extends Logging {
}
}
}
+
+ /**
+ * Whether Comet extension is enabled
+ */
+ def isCometEnabled: Boolean = {
+ val v = System.getenv("ENABLE_COMET")
+ v == null || v.toBoolean
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index 7c45b02ee84..9f2b608c9f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.comet.CometScanExec
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
@@ -67,6 +68,7 @@ private[execution] object SparkPlanInfo {
// dump the file scan metadata (e.g file path) to event log
val metadata = plan match {
case fileScan: FileSourceScanLike => fileScan.metadata
+ case cometScan: CometScanExec => cometScan.metadata
case _ => Map[String, String]()
}
new SparkPlanInfo(
diff --git a/sql/core/src/test/resources/sql-tests/inputs/collations.sql b/sql/core/src/test/resources/sql-tests/inputs/collations.sql
index 619eb4470e9..8465382a007 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/collations.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/collations.sql
@@ -1,5 +1,8 @@
-- test cases for collation support
+-- TODO: https://github.com/apache/datafusion-comet/issues/551
+--SET spark.comet.enabled = false
+
-- Create a test table with data
create table t1(utf8_binary string collate utf8_binary, utf8_binary_lcase string collate utf8_binary_lcase) using parquet;
insert into t1 values('aaa', 'aaa');
diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql b/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql
index 7aef901da4f..f3d6e18926d 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql
@@ -2,3 +2,4 @@
--SET spark.sql.adaptive.enabled=true
--SET spark.sql.maxMetadataStringLength = 500
+--SET spark.comet.enabled = false
diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain-cbo.sql b/sql/core/src/test/resources/sql-tests/inputs/explain-cbo.sql
index eeb2180f7a5..afd1b5ec289 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/explain-cbo.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/explain-cbo.sql
@@ -1,5 +1,6 @@
--SET spark.sql.cbo.enabled=true
--SET spark.sql.maxMetadataStringLength = 500
+--SET spark.comet.enabled = false
CREATE TABLE explain_temp1(a INT, b INT) USING PARQUET;
CREATE TABLE explain_temp2(c INT, d INT) USING PARQUET;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain.sql b/sql/core/src/test/resources/sql-tests/inputs/explain.sql
index 698ca009b4f..57d774a3617 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/explain.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/explain.sql
@@ -1,6 +1,7 @@
--SET spark.sql.codegen.wholeStage = true
--SET spark.sql.adaptive.enabled = false
--SET spark.sql.maxMetadataStringLength = 500
+--SET spark.comet.enabled = false
-- Test tables
CREATE table explain_temp1 (key int, val int) USING PARQUET;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
index 3a409eea348..26e9aaf215c 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
@@ -6,6 +6,9 @@
-- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/int4.sql
--
+-- TODO: https://github.com/apache/datafusion-comet/issues/551
+--SET spark.comet.enabled = false
+
CREATE TABLE INT4_TBL(f1 int) USING parquet;
-- [SPARK-28023] Trim the string when cast string type to other types
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
index fac23b4a26f..98b12ae5ccc 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
@@ -6,6 +6,10 @@
-- Test int8 64-bit integers.
-- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/int8.sql
--
+
+-- TODO: https://github.com/apache/datafusion-comet/issues/551
+--SET spark.comet.enabled = false
+
CREATE TABLE INT8_TBL(q1 bigint, q2 bigint) USING parquet;
-- PostgreSQL implicitly casts string literals to data with integral types, but
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/select_having.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/select_having.sql
index 0efe0877e9b..f9df0400c99 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/select_having.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/select_having.sql
@@ -6,6 +6,9 @@
-- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/select_having.sql
--
+-- TODO: https://github.com/apache/datafusion-comet/issues/551
+--SET spark.comet.enabled = false
+
-- load test data
CREATE TABLE test_having (a int, b int, c string, d string) USING parquet;
INSERT INTO test_having VALUES (0, 1, 'XXXX', 'A');
diff --git a/sql/core/src/test/resources/sql-tests/inputs/view-schema-binding-config.sql b/sql/core/src/test/resources/sql-tests/inputs/view-schema-binding-config.sql
index e803254ea64..74db78aee38 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/view-schema-binding-config.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/view-schema-binding-config.sql
@@ -1,6 +1,9 @@
-- This test suits check the spark.sql.viewSchemaBindingMode configuration.
-- It can be DISABLED and COMPENSATION
+-- TODO: https://github.com/apache/datafusion-comet/issues/551
+--SET spark.comet.enabled = false
+
-- Verify the default binding is true
SET spark.sql.legacy.viewSchemaBindingMode;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/view-schema-compensation.sql b/sql/core/src/test/resources/sql-tests/inputs/view-schema-compensation.sql
index 21a3ce1e122..f4762ab98f0 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/view-schema-compensation.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/view-schema-compensation.sql
@@ -1,5 +1,9 @@
-- This test suite checks the WITH SCHEMA COMPENSATION clause
-- Disable ANSI mode to ensure we are forcing it explicitly in the CASTS
+
+-- TODO: https://github.com/apache/datafusion-comet/issues/551
+--SET spark.comet.enabled = false
+
SET spark.sql.ansi.enabled = false;
-- In COMPENSATION views get invalidated if the type can't cast
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index d023fb82185..0f4f03bda6c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants
import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, RDDScanExec, SparkPlan, SparkPlanInfo}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEPropagateEmptyRelation}
import org.apache.spark.sql.execution.columnar._
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -519,7 +519,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
df.collect()
}
assert(
- collect(df.queryExecution.executedPlan) { case e: ShuffleExchangeExec => e }.size == expected)
+ collect(df.queryExecution.executedPlan) {
+ case _: ShuffleExchangeLike => 1 }.size == expected)
}
test("A cached table preserves the partitioning and ordering of its cached SparkPlan") {
@@ -1658,7 +1659,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
_.nodeName.contains("TableCacheQueryStage"))
val aqeNode = findNodeInSparkPlanInfo(inMemoryScanNode.get,
_.nodeName.contains("AdaptiveSparkPlan"))
- aqeNode.get.children.head.nodeName == "AQEShuffleRead"
+ aqeNode.get.children.head.nodeName == "AQEShuffleRead" ||
+ (aqeNode.get.children.head.nodeName.contains("WholeStageCodegen") &&
+ aqeNode.get.children.head.children.head.nodeName == "ColumnarToRow" &&
+ aqeNode.get.children.head.children.head.children.head.nodeName == "InputAdapter" &&
+ aqeNode.get.children.head.children.head.children.head.children.head.nodeName ==
+ "AQEShuffleRead")
}
withTempView("t0", "t1", "t2") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 620ee430cab..9d383a4bff9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.util.AUTO_GENERATED_ALIAS
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -813,7 +813,7 @@ class DataFrameAggregateSuite extends QueryTest
assert(objHashAggPlans.nonEmpty)
val exchangePlans = collect(aggPlan) {
- case shuffle: ShuffleExchangeExec => shuffle
+ case shuffle: ShuffleExchangeLike => shuffle
}
assert(exchangePlans.length == 1)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index f6fd6b501d7..11870c85d82 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -435,7 +435,9 @@ class DataFrameJoinSuite extends QueryTest
withTempDatabase { dbName =>
withTable(table1Name, table2Name) {
- withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ withSQLConf(
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ "spark.comet.enabled" -> "false") {
spark.range(50).write.saveAsTable(s"$dbName.$table1Name")
spark.range(100).write.saveAsTable(s"$dbName.$table2Name")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 760ee802608..b77133ffd37 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -36,11 +36,12 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference,
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LocalRelation, LogicalPlan, OneRowRelation}
+import org.apache.spark.sql.comet.CometBroadcastExchangeExec
import org.apache.spark.sql.connector.FakeV2Provider
import org.apache.spark.sql.execution.{FilterExec, LogicalRDD, QueryExecution, SortExec, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.expressions.{Aggregator, Window}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -1401,7 +1402,7 @@ class DataFrameSuite extends QueryTest
fail("Should not have back to back Aggregates")
}
atFirstAgg = true
- case e: ShuffleExchangeExec => atFirstAgg = false
+ case e: ShuffleExchangeLike => atFirstAgg = false
case _ =>
}
}
@@ -1591,7 +1592,7 @@ class DataFrameSuite extends QueryTest
checkAnswer(join, df)
assert(
collect(join.queryExecution.executedPlan) {
- case e: ShuffleExchangeExec => true }.size === 1)
+ case _: ShuffleExchangeLike => true }.size === 1)
assert(
collect(join.queryExecution.executedPlan) { case e: ReusedExchangeExec => true }.size === 1)
val broadcasted = broadcast(join)
@@ -1599,10 +1600,12 @@ class DataFrameSuite extends QueryTest
checkAnswer(join2, df)
assert(
collect(join2.queryExecution.executedPlan) {
- case e: ShuffleExchangeExec => true }.size == 1)
+ case _: ShuffleExchangeLike => true }.size == 1)
assert(
collect(join2.queryExecution.executedPlan) {
- case e: BroadcastExchangeExec => true }.size === 1)
+ case e: BroadcastExchangeExec => true
+ case _: CometBroadcastExchangeExec => true
+ }.size === 1)
assert(
collect(join2.queryExecution.executedPlan) { case e: ReusedExchangeExec => true }.size == 4)
}
@@ -2000,7 +2003,7 @@ class DataFrameSuite extends QueryTest
// Assert that no extra shuffle introduced by cogroup.
val exchanges = collect(df3.queryExecution.executedPlan) {
- case h: ShuffleExchangeExec => h
+ case h: ShuffleExchangeLike => h
}
assert(exchanges.size == 2)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 16a493b5290..3f0b70e2d59 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.trees.DataFrameQueryContext
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
@@ -2360,7 +2360,7 @@ class DatasetSuite extends QueryTest
// Assert that no extra shuffle introduced by cogroup.
val exchanges = collect(df3.queryExecution.executedPlan) {
- case h: ShuffleExchangeExec => h
+ case h: ShuffleExchangeLike => h
}
assert(exchanges.size == 2)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
index 2c24cc7d570..65163e35666 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
@@ -22,6 +22,7 @@ import org.scalatest.GivenWhenThen
import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression}
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._
import org.apache.spark.sql.catalyst.plans.ExistenceJoin
+import org.apache.spark.sql.comet.CometScanExec
import org.apache.spark.sql.connector.catalog.{InMemoryTableCatalog, InMemoryTableWithV2FilterCatalog}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive._
@@ -262,6 +263,9 @@ abstract class DynamicPartitionPruningSuiteBase
case s: BatchScanExec => s.runtimeFilters.collect {
case d: DynamicPruningExpression => d.child
}
+ case s: CometScanExec => s.partitionFilters.collect {
+ case d: DynamicPruningExpression => d.child
+ }
case _ => Nil
}
}
@@ -665,7 +669,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}
- test("partition pruning in broadcast hash joins with aliases") {
+ test("partition pruning in broadcast hash joins with aliases",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
Given("alias with simple join condition, using attribute names only")
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
val df = sql(
@@ -755,7 +760,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}
- test("partition pruning in broadcast hash joins") {
+ test("partition pruning in broadcast hash joins",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
Given("disable broadcast pruning and disable subquery duplication")
withSQLConf(
SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true",
@@ -990,7 +996,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}
- test("different broadcast subqueries with identical children") {
+ test("different broadcast subqueries with identical children",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
withTable("fact", "dim") {
spark.range(100).select(
@@ -1027,7 +1034,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}
- test("avoid reordering broadcast join keys to match input hash partitioning") {
+ test("avoid reordering broadcast join keys to match input hash partitioning",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withTable("large", "dimTwo", "dimThree") {
@@ -1187,7 +1195,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}
- test("Make sure dynamic pruning works on uncorrelated queries") {
+ test("Make sure dynamic pruning works on uncorrelated queries",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
val df = sql(
"""
@@ -1215,7 +1224,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
test("SPARK-32509: Unused Dynamic Pruning filter shouldn't affect " +
- "canonicalization and exchange reuse") {
+ "canonicalization and exchange reuse",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
val df = sql(
@@ -1238,7 +1248,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}
- test("Plan broadcast pruning only when the broadcast can be reused") {
+ test("Plan broadcast pruning only when the broadcast can be reused",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
Given("dynamic pruning filter on the build side")
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
val df = sql(
@@ -1279,7 +1290,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}
- test("SPARK-32659: Fix the data issue when pruning DPP on non-atomic type") {
+ test("SPARK-32659: Fix the data issue when pruning DPP on non-atomic type",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
Seq(NO_CODEGEN, CODEGEN_ONLY).foreach { mode =>
Seq(true, false).foreach { pruning =>
withSQLConf(
@@ -1311,7 +1323,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}
- test("SPARK-32817: DPP throws error when the broadcast side is empty") {
+ test("SPARK-32817: DPP throws error when the broadcast side is empty",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withSQLConf(
SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true",
@@ -1424,7 +1437,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}
- test("SPARK-34637: DPP side broadcast query stage is created firstly") {
+ test("SPARK-34637: DPP side broadcast query stage is created firstly",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
val df = sql(
""" WITH v as (
@@ -1455,7 +1469,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}
- test("SPARK-35568: Fix UnsupportedOperationException when enabling both AQE and DPP") {
+ test("SPARK-35568: Fix UnsupportedOperationException when enabling both AQE and DPP",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
val df = sql(
"""
|SELECT s.store_id, f.product_id
@@ -1471,7 +1486,8 @@ abstract class DynamicPartitionPruningSuiteBase
checkAnswer(df, Row(3, 2) :: Row(3, 2) :: Row(3, 2) :: Row(3, 2) :: Nil)
}
- test("SPARK-36444: Remove OptimizeSubqueries from batch of PartitionPruning") {
+ test("SPARK-36444: Remove OptimizeSubqueries from batch of PartitionPruning",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") {
val df = sql(
"""
@@ -1486,7 +1502,7 @@ abstract class DynamicPartitionPruningSuiteBase
}
test("SPARK-38148: Do not add dynamic partition pruning if there exists static partition " +
- "pruning") {
+ "pruning", IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") {
Seq(
"f.store_id = 1" -> false,
@@ -1558,7 +1574,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}
- test("SPARK-38674: Remove useless deduplicate in SubqueryBroadcastExec") {
+ test("SPARK-38674: Remove useless deduplicate in SubqueryBroadcastExec",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withTable("duplicate_keys") {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") {
Seq[(Int, String)]((1, "NL"), (1, "NL"), (3, "US"), (3, "US"), (3, "US"))
@@ -1589,7 +1606,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}
- test("SPARK-39338: Remove dynamic pruning subquery if pruningKey's references is empty") {
+ test("SPARK-39338: Remove dynamic pruning subquery if pruningKey's references is empty",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") {
val df = sql(
"""
@@ -1618,7 +1636,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}
- test("SPARK-39217: Makes DPP support the pruning side has Union") {
+ test("SPARK-39217: Makes DPP support the pruning side has Union",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") {
val df = sql(
"""
@@ -1730,6 +1749,8 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat
case s: BatchScanExec =>
// we use f1 col for v2 tables due to schema pruning
s.output.exists(_.exists(_.argString(maxFields = 100).contains("f1")))
+ case s: CometScanExec =>
+ s.output.exists(_.exists(_.argString(maxFields = 100).contains("fid")))
case _ => false
}
assert(scanOption.isDefined)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
index b2aaaceb26a..625522f36ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
@@ -467,7 +467,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
}
}
- test("Explain formatted output for scan operator for datasource V2") {
+ test("Explain formatted output for scan operator for datasource V2",
+ IgnoreComet("Comet explain output is different")) {
withTempDir { dir =>
Seq("parquet", "orc", "csv", "json").foreach { fmt =>
val basePath = dir.getCanonicalPath + "/" + fmt
@@ -545,7 +546,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
}
}
-class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuite {
+// Ignored when Comet is enabled. Comet changes expected query plans.
+class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuite
+ with IgnoreCometSuite {
import testImplicits._
test("SPARK-35884: Explain Formatted") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index 49a33d1c925..9a540abd0c2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterTha
import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt}
import org.apache.spark.sql.catalyst.plans.logical.Filter
import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.comet.{CometBatchScanExec, CometScanExec, CometSortMergeJoinExec}
import org.apache.spark.sql.execution.{ExplainMode, FileSourceScanLike, SimpleMode}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.FilePartition
@@ -951,6 +952,7 @@ class FileBasedDataSourceSuite extends QueryTest
assert(bJoinExec.isEmpty)
val smJoinExec = collect(joinedDF.queryExecution.executedPlan) {
case smJoin: SortMergeJoinExec => smJoin
+ case smJoin: CometSortMergeJoinExec => smJoin
}
assert(smJoinExec.nonEmpty)
}
@@ -1011,6 +1013,7 @@ class FileBasedDataSourceSuite extends QueryTest
val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: FileScan, _, _, _, _) => f
+ case CometBatchScanExec(BatchScanExec(_, f: FileScan, _, _, _, _), _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.nonEmpty)
@@ -1052,6 +1055,7 @@ class FileBasedDataSourceSuite extends QueryTest
val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: FileScan, _, _, _, _) => f
+ case CometBatchScanExec(BatchScanExec(_, f: FileScan, _, _, _, _), _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.isEmpty)
@@ -1236,6 +1240,8 @@ class FileBasedDataSourceSuite extends QueryTest
val filters = df.queryExecution.executedPlan.collect {
case f: FileSourceScanLike => f.dataFilters
case b: BatchScanExec => b.scan.asInstanceOf[FileScan].dataFilters
+ case b: CometScanExec => b.dataFilters
+ case b: CometBatchScanExec => b.scan.asInstanceOf[FileScan].dataFilters
}.flatten
assert(filters.contains(GreaterThan(scan.logicalPlan.output.head, Literal(5L))))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IgnoreComet.scala b/sql/core/src/test/scala/org/apache/spark/sql/IgnoreComet.scala
new file mode 100644
index 00000000000..4b31bea33de
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IgnoreComet.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.scalactic.source.Position
+import org.scalatest.Tag
+
+import org.apache.spark.sql.test.SQLTestUtils
+
+/**
+ * Tests with this tag will be ignored when Comet is enabled (e.g., via `ENABLE_COMET`).
+ */
+case class IgnoreComet(reason: String) extends Tag("DisableComet")
+
+/**
+ * Helper trait that disables Comet for all tests regardless of default config values.
+ */
+trait IgnoreCometSuite extends SQLTestUtils {
+ override protected def test(testName: String, testTags: Tag*)(testFun: => Any)
+ (implicit pos: Position): Unit = {
+ if (isCometEnabled) {
+ ignore(testName + " (disabled when Comet is on)", testTags: _*)(testFun)
+ } else {
+ super.test(testName, testTags: _*)(testFun)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
index 027477a8291..f2568916e88 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
@@ -442,7 +442,8 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
}
test("Runtime bloom filter join: do not add bloom filter if dpp filter exists " +
- "on the same column") {
+ "on the same column",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
assertDidNotRewriteWithBloomFilter("select * from bf5part join bf2 on " +
@@ -451,7 +452,8 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
}
test("Runtime bloom filter join: add bloom filter if dpp filter exists on " +
- "a different column") {
+ "a different column",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
assertRewroteWithBloomFilter("select * from bf5part join bf2 on " +
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
index 53e47f428c3..a55d8f0c161 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.comet.{CometHashJoinExec, CometSortMergeJoinExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.internal.SQLConf
@@ -362,6 +363,7 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP
val executedPlan = df.queryExecution.executedPlan
val shuffleHashJoins = collect(executedPlan) {
case s: ShuffledHashJoinExec => s
+ case c: CometHashJoinExec => c.originalPlan.asInstanceOf[ShuffledHashJoinExec]
}
assert(shuffleHashJoins.size == 1)
assert(shuffleHashJoins.head.buildSide == buildSide)
@@ -371,6 +373,7 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP
val executedPlan = df.queryExecution.executedPlan
val shuffleMergeJoins = collect(executedPlan) {
case s: SortMergeJoinExec => s
+ case c: CometSortMergeJoinExec => c
}
assert(shuffleMergeJoins.size == 1)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index fcb937d82ba..fafe8e8d08b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -29,7 +29,8 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrder}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, JoinSelectionHelper}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, Join, JoinHint, NO_BROADCAST_AND_REPLICATION}
-import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.comet._
+import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, FilterExec, InputAdapter, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins._
@@ -805,7 +806,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
}
}
- test("test SortMergeJoin (with spill)") {
+ test("test SortMergeJoin (with spill)",
+ IgnoreComet("TODO: Comet SMJ doesn't support spill yet")) {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1",
SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "0",
SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD.key -> "1") {
@@ -931,10 +933,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
val physical = df.queryExecution.sparkPlan
val physicalJoins = physical.collect {
case j: SortMergeJoinExec => j
+ case j: CometSortMergeJoinExec => j.originalPlan.asInstanceOf[SortMergeJoinExec]
}
val executed = df.queryExecution.executedPlan
val executedJoins = collect(executed) {
case j: SortMergeJoinExec => j
+ case j: CometSortMergeJoinExec => j.originalPlan.asInstanceOf[SortMergeJoinExec]
}
// This only applies to the above tested queries, in which a child SortMergeJoin always
// contains the SortOrder required by its parent SortMergeJoin. Thus, SortExec should never
@@ -1180,9 +1184,11 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
val plan = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", joinType)
.groupBy($"k1").count()
.queryExecution.executedPlan
- assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size === 1)
+ assert(collect(plan) {
+ case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size === 1)
// No extra shuffle before aggregate
- assert(collect(plan) { case _: ShuffleExchangeExec => true }.size === 2)
+ assert(collect(plan) {
+ case _: ShuffleExchangeLike => true }.size === 2)
})
}
@@ -1199,10 +1205,11 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
.join(df4.hint("SHUFFLE_MERGE"), $"k1" === $"k4", joinType)
.queryExecution
.executedPlan
- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 2)
+ assert(collect(plan) {
+ case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size === 2)
assert(collect(plan) { case _: BroadcastHashJoinExec => true }.size === 1)
// No extra sort before last sort merge join
- assert(collect(plan) { case _: SortExec => true }.size === 3)
+ assert(collect(plan) { case _: SortExec | _: CometSortExec => true }.size === 3)
})
// Test shuffled hash join
@@ -1212,10 +1219,13 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
.join(df4.hint("SHUFFLE_MERGE"), $"k1" === $"k4", joinType)
.queryExecution
.executedPlan
- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 2)
- assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size === 1)
+ assert(collect(plan) {
+ case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size === 2)
+ assert(collect(plan) {
+ case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size === 1)
// No extra sort before last sort merge join
- assert(collect(plan) { case _: SortExec => true }.size === 3)
+ assert(collect(plan) {
+ case _: SortExec | _: CometSortExec => true }.size === 3)
})
}
@@ -1306,12 +1316,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
inputDFs.foreach { case (df1, df2, joinExprs) =>
val smjDF = df1.join(df2.hint("SHUFFLE_MERGE"), joinExprs, "full")
assert(collect(smjDF.queryExecution.executedPlan) {
- case _: SortMergeJoinExec => true }.size === 1)
+ case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size === 1)
val smjResult = smjDF.collect()
val shjDF = df1.join(df2.hint("SHUFFLE_HASH"), joinExprs, "full")
assert(collect(shjDF.queryExecution.executedPlan) {
- case _: ShuffledHashJoinExec => true }.size === 1)
+ case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size === 1)
// Same result between shuffled hash join and sort merge join
checkAnswer(shjDF, smjResult)
}
@@ -1370,12 +1380,14 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
val smjDF = df1.hint("SHUFFLE_MERGE").join(df2, joinExprs, "leftouter")
assert(collect(smjDF.queryExecution.executedPlan) {
case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
}.size === 1)
val smjResult = smjDF.collect()
val shjDF = df1.hint("SHUFFLE_HASH").join(df2, joinExprs, "leftouter")
assert(collect(shjDF.queryExecution.executedPlan) {
case _: ShuffledHashJoinExec => true
+ case _: CometHashJoinExec => true
}.size === 1)
// Same result between shuffled hash join and sort merge join
checkAnswer(shjDF, smjResult)
@@ -1386,12 +1398,14 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
val smjDF = df2.join(df1.hint("SHUFFLE_MERGE"), joinExprs, "rightouter")
assert(collect(smjDF.queryExecution.executedPlan) {
case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
}.size === 1)
val smjResult = smjDF.collect()
val shjDF = df2.join(df1.hint("SHUFFLE_HASH"), joinExprs, "rightouter")
assert(collect(shjDF.queryExecution.executedPlan) {
case _: ShuffledHashJoinExec => true
+ case _: CometHashJoinExec => true
}.size === 1)
// Same result between shuffled hash join and sort merge join
checkAnswer(shjDF, smjResult)
@@ -1435,13 +1449,19 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
assert(shjCodegenDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
case WholeStageCodegenExec(ProjectExec(_, _ : ShuffledHashJoinExec)) => true
+ case WholeStageCodegenExec(ColumnarToRowExec(InputAdapter(_: CometHashJoinExec))) =>
+ true
+ case WholeStageCodegenExec(ColumnarToRowExec(
+ InputAdapter(CometProjectExec(_, _, _, _, _: CometHashJoinExec, _)))) => true
}.size === 1)
checkAnswer(shjCodegenDF, Seq.empty)
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
val shjNonCodegenDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", joinType)
assert(shjNonCodegenDF.queryExecution.executedPlan.collect {
- case _: ShuffledHashJoinExec => true }.size === 1)
+ case _: ShuffledHashJoinExec => true
+ case _: CometHashJoinExec => true
+ }.size === 1)
checkAnswer(shjNonCodegenDF, Seq.empty)
}
}
@@ -1489,7 +1509,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
val plan = sql(getAggQuery(selectExpr, joinType)).queryExecution.executedPlan
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
// Have shuffle before aggregation
- assert(collect(plan) { case _: ShuffleExchangeExec => true }.size === 1)
+ assert(collect(plan) {
+ case _: ShuffleExchangeLike => true }.size === 1)
}
def getJoinQuery(selectExpr: String, joinType: String): String = {
@@ -1518,9 +1539,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
}
val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 3)
+ assert(collect(plan) {
+ case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
+ }.size === 3)
// No extra sort on left side before last sort merge join
- assert(collect(plan) { case _: SortExec => true }.size === 5)
+ assert(collect(plan) { case _: SortExec | _: CometSortExec => true }.size === 5)
}
// Test output ordering is not preserved
@@ -1529,9 +1553,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
val selectExpr = "/*+ BROADCAST(left_t) */ k1 as k0"
val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 3)
+ assert(collect(plan) {
+ case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
+ }.size === 3)
// Have sort on left side before last sort merge join
- assert(collect(plan) { case _: SortExec => true }.size === 6)
+ assert(collect(plan) { case _: SortExec | _: CometSortExec => true }.size === 6)
}
// Test singe partition
@@ -1541,7 +1568,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
|FROM range(0, 10, 1, 1) t1 FULL OUTER JOIN range(0, 10, 1, 1) t2
|""".stripMargin)
val plan = fullJoinDF.queryExecution.executedPlan
- assert(collect(plan) { case _: ShuffleExchangeExec => true}.size == 1)
+ assert(collect(plan) {
+ case _: ShuffleExchangeLike => true}.size == 1)
checkAnswer(fullJoinDF, Row(100))
}
}
@@ -1586,6 +1614,9 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
Seq(semiJoinDF, antiJoinDF).foreach { df =>
assert(collect(df.queryExecution.executedPlan) {
case j: ShuffledHashJoinExec if j.ignoreDuplicatedKey == ignoreDuplicatedKey => true
+ case j: CometHashJoinExec
+ if j.originalPlan.asInstanceOf[ShuffledHashJoinExec].ignoreDuplicatedKey ==
+ ignoreDuplicatedKey => true
}.size == 1)
}
}
@@ -1630,14 +1661,20 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
test("SPARK-43113: Full outer join with duplicate stream-side references in condition (SMJ)") {
def check(plan: SparkPlan): Unit = {
- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 1)
+ assert(collect(plan) {
+ case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
+ }.size === 1)
}
dupStreamSideColTest("MERGE", check)
}
test("SPARK-43113: Full outer join with duplicate stream-side references in condition (SHJ)") {
def check(plan: SparkPlan): Unit = {
- assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size === 1)
+ assert(collect(plan) {
+ case _: ShuffledHashJoinExec => true
+ case _: CometHashJoinExec => true
+ }.size === 1)
}
dupStreamSideColTest("SHUFFLE_HASH", check)
}
@@ -1773,7 +1810,8 @@ class ThreadLeakInSortMergeJoinSuite
sparkConf.set(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD, 20))
}
- test("SPARK-47146: thread leak when doing SortMergeJoin (with spill)") {
+ test("SPARK-47146: thread leak when doing SortMergeJoin (with spill)",
+ IgnoreComet("Comet SMJ doesn't spill yet")) {
withSQLConf(
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
index 34c6c49bc49..f5dea07a213 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
@@ -69,7 +69,7 @@ import org.apache.spark.tags.ExtendedSQLTest
* }}}
*/
// scalastyle:on line.size.limit
-trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite {
+trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite with IgnoreCometSuite {
protected val baseResourcePath = {
// use the same way as `SQLQueryTestSuite` to get the resource path
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 56c364e2084..fc3abd7cdc4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1510,7 +1510,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
checkAnswer(sql("select -0.001"), Row(BigDecimal("-0.001")))
}
- test("external sorting updates peak execution memory") {
+ test("external sorting updates peak execution memory",
+ IgnoreComet("TODO: native CometSort does not update peak execution memory")) {
AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") {
sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect()
}
@@ -4454,7 +4455,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
test("SPARK-39166: Query context of binary arithmetic should be serialized to executors" +
- " when WSCG is off") {
+ " when WSCG is off",
+ IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/551")) {
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
SQLConf.ANSI_ENABLED.key -> "true") {
withTable("t") {
@@ -4475,7 +4477,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
test("SPARK-39175: Query context of Cast should be serialized to executors" +
- " when WSCG is off") {
+ " when WSCG is off",
+ IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/551")) {
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
SQLConf.ANSI_ENABLED.key -> "true") {
withTable("t") {
@@ -4502,7 +4505,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
test("SPARK-39190,SPARK-39208,SPARK-39210: Query context of decimal overflow error should " +
- "be serialized to executors when WSCG is off") {
+ "be serialized to executors when WSCG is off",
+ IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/551")) {
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
SQLConf.ANSI_ENABLED.key -> "true") {
withTable("t") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 4d38e360f43..3c272af0b62 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -223,6 +223,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
withSession(extensions) { session =>
session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED, true)
session.conf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1")
+ // https://github.com/apache/datafusion-comet/issues/1197
+ session.conf.set("spark.comet.enabled", false)
assert(session.sessionState.columnarRules.contains(
MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
import session.sqlContext.implicits._
@@ -281,6 +283,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
}
withSession(extensions) { session =>
session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED, enableAQE)
+ // https://github.com/apache/datafusion-comet/issues/1197
+ session.conf.set("spark.comet.enabled", false)
assert(session.sessionState.columnarRules.contains(
MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
import session.sqlContext.implicits._
@@ -319,6 +323,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
val session = SparkSession.builder()
.master("local[1]")
.config(COLUMN_BATCH_SIZE.key, 2)
+ // https://github.com/apache/datafusion-comet/issues/1197
+ .config("spark.comet.enabled", false)
.withExtensions { extensions =>
extensions.injectColumnar(session =>
MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())) }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 68f14f13bbd..174636cefb5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -22,10 +22,11 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LogicalPlan, Project, Sort, Union}
+import org.apache.spark.sql.comet.CometScanExec
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
import org.apache.spark.sql.execution.datasources.FileScanRDD
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -1541,6 +1542,12 @@ class SubquerySuite extends QueryTest
fs.inputRDDs().forall(
_.asInstanceOf[FileScanRDD].filePartitions.forall(
_.files.forall(_.urlEncodedPath.contains("p=0"))))
+ case WholeStageCodegenExec(ColumnarToRowExec(InputAdapter(
+ fs @ CometScanExec(_, _, _, partitionFilters, _, _, _, _, _, _)))) =>
+ partitionFilters.exists(ExecSubqueryExpression.hasSubquery) &&
+ fs.inputRDDs().forall(
+ _.asInstanceOf[FileScanRDD].filePartitions.forall(
+ _.files.forall(_.urlEncodedPath.contains("p=0"))))
case _ => false
})
}
@@ -2106,7 +2113,7 @@ class SubquerySuite extends QueryTest
df.collect()
val exchanges = collect(df.queryExecution.executedPlan) {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}
assert(exchanges.size === 1)
}
@@ -2672,18 +2679,26 @@ class SubquerySuite extends QueryTest
def checkFileSourceScan(query: String, answer: Seq[Row]): Unit = {
val df = sql(query)
checkAnswer(df, answer)
- val fileSourceScanExec = collect(df.queryExecution.executedPlan) {
- case f: FileSourceScanExec => f
+ val dataSourceScanExec = collect(df.queryExecution.executedPlan) {
+ case f: FileSourceScanLike => f
+ case c: CometScanExec => c
}
sparkContext.listenerBus.waitUntilEmpty()
- assert(fileSourceScanExec.size === 1)
- val scalarSubquery = fileSourceScanExec.head.dataFilters.flatMap(_.collect {
- case s: ScalarSubquery => s
- })
+ assert(dataSourceScanExec.size === 1)
+ val scalarSubquery = dataSourceScanExec.head match {
+ case f: FileSourceScanLike =>
+ f.dataFilters.flatMap(_.collect {
+ case s: ScalarSubquery => s
+ })
+ case c: CometScanExec =>
+ c.dataFilters.flatMap(_.collect {
+ case s: ScalarSubquery => s
+ })
+ }
assert(scalarSubquery.length === 1)
assert(scalarSubquery.head.plan.isInstanceOf[ReusedSubqueryExec])
- assert(fileSourceScanExec.head.metrics("numFiles").value === 1)
- assert(fileSourceScanExec.head.metrics("numOutputRows").value === answer.size)
+ assert(dataSourceScanExec.head.metrics("numFiles").value === 1)
+ assert(dataSourceScanExec.head.metrics("numOutputRows").value === answer.size)
}
withTable("t1", "t2") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
index 1de535df246..cc7ffc4eeb3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
@@ -26,6 +26,7 @@ import test.org.apache.spark.sql.connector._
import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.comet.CometSortExec
import org.apache.spark.sql.connector.catalog.{PartitionInternalRow, SupportsRead, Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, NamedReference, NullOrdering, SortDirection, SortOrder, Transform}
@@ -36,7 +37,7 @@ import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning,
import org.apache.spark.sql.execution.SortExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation}
-import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
@@ -278,13 +279,13 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
val groupByColJ = df.groupBy($"j").agg(sum($"i"))
checkAnswer(groupByColJ, Seq(Row(2, 8), Row(4, 2), Row(6, 5)))
assert(collectFirst(groupByColJ.queryExecution.executedPlan) {
- case e: ShuffleExchangeExec => e
+ case e: ShuffleExchangeLike => e
}.isDefined)
val groupByIPlusJ = df.groupBy($"i" + $"j").agg(count("*"))
checkAnswer(groupByIPlusJ, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1)))
assert(collectFirst(groupByIPlusJ.queryExecution.executedPlan) {
- case e: ShuffleExchangeExec => e
+ case e: ShuffleExchangeLike => e
}.isDefined)
}
}
@@ -344,10 +345,11 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
val (shuffleExpected, sortExpected) = groupByExpects
assert(collectFirst(groupBy.queryExecution.executedPlan) {
- case e: ShuffleExchangeExec => e
+ case e: ShuffleExchangeLike => e
}.isDefined === shuffleExpected)
assert(collectFirst(groupBy.queryExecution.executedPlan) {
case e: SortExec => e
+ case c: CometSortExec => c
}.isDefined === sortExpected)
}
@@ -362,10 +364,11 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
val (shuffleExpected, sortExpected) = windowFuncExpects
assert(collectFirst(windowPartByColIOrderByColJ.queryExecution.executedPlan) {
- case e: ShuffleExchangeExec => e
+ case e: ShuffleExchangeLike => e
}.isDefined === shuffleExpected)
assert(collectFirst(windowPartByColIOrderByColJ.queryExecution.executedPlan) {
case e: SortExec => e
+ case c: CometSortExec => c
}.isDefined === sortExpected)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala
index c6060dcdd51..78fd88be9bb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.comet.CometScanExec
import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
@@ -188,7 +189,11 @@ class FileDataSourceV2FallBackSuite extends QueryTest with SharedSparkSession {
val df = spark.read.format(format).load(path.getCanonicalPath)
checkAnswer(df, inputData.toDF())
assert(
- df.queryExecution.executedPlan.exists(_.isInstanceOf[FileSourceScanExec]))
+ df.queryExecution.executedPlan.exists {
+ case _: FileSourceScanExec | _: CometScanExec => true
+ case _ => false
+ }
+ )
}
} finally {
spark.listenerManager.unregister(listener)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 10a32441b6c..5e5d763ee70 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, TransformExpression}
import org.apache.spark.sql.catalyst.plans.physical
+import org.apache.spark.sql.comet.CometSortMergeJoinExec
import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryTableCatalog}
import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.connector.distributions.Distributions
@@ -31,7 +32,7 @@ import org.apache.spark.sql.connector.expressions.Expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf._
@@ -298,13 +299,14 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
Row("bbb", 20, 250.0), Row("bbb", 20, 350.0), Row("ccc", 30, 400.50)))
}
- private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeExec] = {
+ private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeLike] = {
// here we skip collecting shuffle operators that are not associated with SMJ
collect(plan) {
case s: SortMergeJoinExec => s
+ case c: CometSortMergeJoinExec => c.originalPlan
}.flatMap(smj =>
collect(smj) {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
})
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
index 12d5f13df01..816d1518c5b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
@@ -21,7 +21,7 @@ package org.apache.spark.sql.connector
import java.sql.Date
import java.util.Collections
-import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, Row}
+import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, IgnoreCometSuite, Row}
import org.apache.spark.sql.catalyst.expressions.{ApplyFunctionExpression, Cast, Literal}
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.physical
@@ -45,7 +45,8 @@ import org.apache.spark.sql.util.QueryExecutionListener
import org.apache.spark.tags.SlowSQLTest
@SlowSQLTest
-class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase {
+class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
+ with IgnoreCometSuite {
import testImplicits._
before {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
index 8238eabc7fe..c960fd75a9e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
@@ -31,7 +31,7 @@ import org.mockito.Mockito.{mock, spy, when}
import org.scalatest.time.SpanSugar._
import org.apache.spark._
-import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, Encoder, KryoData, QueryTest, Row, SaveMode}
+import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, Encoder, IgnoreComet, KryoData, QueryTest, Row, SaveMode}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.{NamedParameter, UnresolvedGenerator}
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
@@ -266,7 +266,8 @@ class QueryExecutionErrorsSuite
}
test("INCONSISTENT_BEHAVIOR_CROSS_VERSION: " +
- "compatibility with Spark 2.4/3.2 in reading/writing dates") {
+ "compatibility with Spark 2.4/3.2 in reading/writing dates",
+ IgnoreComet("Comet doesn't completely support datetime rebase mode yet")) {
// Fail to read ancient datetime values.
withSQLConf(SQLConf.PARQUET_REBASE_MODE_IN_READ.key -> EXCEPTION.toString) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
index 418ca3430bb..eb8267192f8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
@@ -23,7 +23,7 @@ import scala.util.Random
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkConf
-import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.{DataFrame, IgnoreComet, QueryTest}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
import org.apache.spark.sql.internal.SQLConf
@@ -195,7 +195,7 @@ class DataSourceV2ScanExecRedactionSuite extends DataSourceScanRedactionTest {
}
}
- test("FileScan description") {
+ test("FileScan description", IgnoreComet("Comet doesn't use BatchScan")) {
Seq("json", "orc", "parquet").foreach { format =>
withTempPath { path =>
val dir = path.getCanonicalPath
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
index 743ec41dbe7..9f30d6c8e04 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
@@ -53,6 +53,10 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite with DisableAdaptiv
case ColumnarToRowExec(i: InputAdapter) => isScanPlanTree(i.child)
case p: ProjectExec => isScanPlanTree(p.child)
case f: FilterExec => isScanPlanTree(f.child)
+ // Comet produces scan plan tree like:
+ // ColumnarToRow
+ // +- ReusedExchange
+ case _: ReusedExchangeExec => false
case _: LeafExecNode => true
case _ => false
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 15de4c5cc5b..6a85dfb6883 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, DataFrame, Row}
+import org.apache.spark.sql.{execution, DataFrame, IgnoreCometSuite, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
@@ -36,7 +36,9 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
-class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
+// Ignore this suite when Comet is enabled. This suite tests the Spark planner and Comet planner
+// comes out with too many difference. Simply ignoring this suite for now.
+class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper with IgnoreCometSuite {
import testImplicits._
setupTestData()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
index 3608e7c9207..6a05de2b9ac 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
import scala.collection.mutable
import scala.io.Source
-import org.apache.spark.sql.{AnalysisException, Dataset, ExtendedExplainGenerator, FastOperator}
+import org.apache.spark.sql.{AnalysisException, Dataset, ExtendedExplainGenerator, FastOperator, IgnoreComet}
import org.apache.spark.sql.catalyst.{QueryPlanningTracker, QueryPlanningTrackerCallback}
import org.apache.spark.sql.catalyst.analysis.CurrentNamespace
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
@@ -383,7 +383,7 @@ class QueryExecutionSuite extends SharedSparkSession {
}
}
- test("SPARK-47289: extended explain info") {
+ test("SPARK-47289: extended explain info", IgnoreComet("Comet plan extended info is different")) {
val concat = new PlanStringConcat()
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala
index b5bac8079c4..a3731888e12 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala
@@ -17,7 +17,8 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.{DataFrame, QueryTest, Row}
+import org.apache.spark.sql.{DataFrame, IgnoreComet, QueryTest, Row}
+import org.apache.spark.sql.comet.CometProjectExec
import org.apache.spark.sql.connector.SimpleWritableDataSource
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.internal.SQLConf
@@ -34,7 +35,10 @@ abstract class RemoveRedundantProjectsSuiteBase
private def assertProjectExecCount(df: DataFrame, expected: Int): Unit = {
withClue(df.queryExecution) {
val plan = df.queryExecution.executedPlan
- val actual = collectWithSubqueries(plan) { case p: ProjectExec => p }.size
+ val actual = collectWithSubqueries(plan) {
+ case p: ProjectExec => p
+ case p: CometProjectExec => p
+ }.size
assert(actual == expected)
}
}
@@ -112,7 +116,8 @@ abstract class RemoveRedundantProjectsSuiteBase
assertProjectExec(query, 1, 3)
}
- test("join with ordering requirement") {
+ test("join with ordering requirement",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
val query = "select * from (select key, a, c, b from testView) as t1 join " +
"(select key, a, b, c from testView) as t2 on t1.key = t2.key where t2.a > 50"
assertProjectExec(query, 2, 2)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
index 005e764cc30..92ec088efab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.sql.catalyst.plans.physical.{RangePartitioning, UnknownPartitioning}
+import org.apache.spark.sql.comet.CometSortExec
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.execution.joins.ShuffledJoin
import org.apache.spark.sql.internal.SQLConf
@@ -33,7 +34,7 @@ abstract class RemoveRedundantSortsSuiteBase
private def checkNumSorts(df: DataFrame, count: Int): Unit = {
val plan = df.queryExecution.executedPlan
- assert(collectWithSubqueries(plan) { case s: SortExec => s }.length == count)
+ assert(collectWithSubqueries(plan) { case _: SortExec | _: CometSortExec => 1 }.length == count)
}
private def checkSorts(query: String, enabledCount: Int, disabledCount: Int): Unit = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala
index 47679ed7865..9ffbaecb98e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.comet.CometHashAggregateExec
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.internal.SQLConf
@@ -31,7 +32,7 @@ abstract class ReplaceHashWithSortAggSuiteBase
private def checkNumAggs(df: DataFrame, hashAggCount: Int, sortAggCount: Int): Unit = {
val plan = df.queryExecution.executedPlan
assert(collectWithSubqueries(plan) {
- case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec) => s
+ case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec | _: CometHashAggregateExec ) => s
}.length == hashAggCount)
assert(collectWithSubqueries(plan) { case s: SortAggregateExec => s }.length == sortAggCount)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
index 966f4e74712..8017e22d7f8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.Deduplicate
+import org.apache.spark.sql.comet.CometColumnarToRowExec
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -134,7 +135,10 @@ class SparkPlanSuite extends QueryTest with SharedSparkSession {
spark.range(1).write.parquet(path.getAbsolutePath)
val df = spark.read.parquet(path.getAbsolutePath)
val columnarToRowExec =
- df.queryExecution.executedPlan.collectFirst { case p: ColumnarToRowExec => p }.get
+ df.queryExecution.executedPlan.collectFirst {
+ case p: ColumnarToRowExec => p
+ case p: CometColumnarToRowExec => p
+ }.get
try {
spark.range(1).foreach { _ =>
columnarToRowExec.canonicalized
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 3aaf61ffba4..4130ece2283 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.rdd.MapPartitionsWithEvaluatorRDD
import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode}
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAndComment, CodeGenerator}
+import org.apache.spark.sql.comet.{CometHashJoinExec, CometSortExec, CometSortMergeJoinExec}
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
@@ -172,6 +173,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val oneJoinDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2")
assert(oneJoinDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
+ case _: CometHashJoinExec => true
}.size === 1)
checkAnswer(oneJoinDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4)))
@@ -180,6 +182,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
.join(df3.hint("SHUFFLE_HASH"), $"k1" === $"k3")
assert(twoJoinsDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
+ case _: CometHashJoinExec => true
}.size === 2)
checkAnswer(twoJoinsDF,
Seq(Row(0, 0, 0), Row(1, 1, 1), Row(2, 2, 2), Row(3, 3, 3), Row(4, 4, 4)))
@@ -206,6 +209,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(joinUniqueDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(joinUniqueDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4),
Row(null, 5), Row(null, 6), Row(null, 7), Row(null, 8), Row(null, 9)))
@@ -216,6 +221,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(joinNonUniqueDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(joinNonUniqueDF, Seq(Row(0, 0), Row(0, 3), Row(0, 6), Row(0, 9), Row(1, 1),
Row(1, 4), Row(1, 7), Row(2, 2), Row(2, 5), Row(2, 8), Row(3, null), Row(4, null)))
@@ -226,6 +233,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(joinWithNonEquiDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(joinWithNonEquiDF, Seq(Row(0, 0), Row(0, 6), Row(0, 9), Row(1, 1),
Row(1, 7), Row(2, 2), Row(2, 8), Row(3, null), Row(4, null), Row(null, 3), Row(null, 4),
@@ -237,6 +246,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(twoJoinsDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 2)
checkAnswer(twoJoinsDF,
Seq(Row(0, 0, 0), Row(1, 1, null), Row(2, 2, 2), Row(3, 3, null), Row(4, 4, null),
@@ -258,6 +269,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(rightJoinUniqueDf.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(rightJoinUniqueDf, Seq(Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4),
Row(null, 5), Row(null, 6), Row(null, 7), Row(null, 8), Row(null, 9),
@@ -269,6 +282,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(leftJoinUniqueDf.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(leftJoinUniqueDf, Seq(Row(0, null), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4)))
assert(leftJoinUniqueDf.count() === 5)
@@ -278,6 +293,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(rightJoinNonUniqueDf.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(rightJoinNonUniqueDf, Seq(Row(0, 3), Row(0, 6), Row(0, 9), Row(1, 1),
Row(1, 4), Row(1, 7), Row(1, 10), Row(2, 2), Row(2, 5), Row(2, 8)))
@@ -287,6 +304,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(leftJoinNonUniqueDf.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(leftJoinNonUniqueDf, Seq(Row(0, 3), Row(0, 6), Row(0, 9), Row(1, 1),
Row(1, 4), Row(1, 7), Row(1, 10), Row(2, 2), Row(2, 5), Row(2, 8), Row(3, null),
@@ -298,6 +317,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(rightJoinWithNonEquiDf.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(rightJoinWithNonEquiDf, Seq(Row(0, 6), Row(0, 9), Row(1, 1), Row(1, 7),
Row(1, 10), Row(2, 2), Row(2, 8), Row(null, 3), Row(null, 4), Row(null, 5)))
@@ -308,6 +329,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(leftJoinWithNonEquiDf.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(leftJoinWithNonEquiDf, Seq(Row(0, 6), Row(0, 9), Row(1, 1), Row(1, 7),
Row(1, 10), Row(2, 2), Row(2, 8), Row(3, null), Row(4, null)))
@@ -318,6 +341,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(twoRightJoinsDf.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 2)
checkAnswer(twoRightJoinsDf, Seq(Row(2, 2, 2), Row(3, 3, 3), Row(4, 4, 4)))
@@ -327,6 +352,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(twoLeftJoinsDf.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 2)
checkAnswer(twoLeftJoinsDf,
Seq(Row(0, null, null), Row(1, 1, null), Row(2, 2, 2), Row(3, 3, 3), Row(4, 4, 4)))
@@ -343,6 +370,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val oneLeftOuterJoinDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_outer")
assert(oneLeftOuterJoinDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : SortMergeJoinExec) => true
+ case _: CometSortMergeJoinExec => true
}.size === 1)
checkAnswer(oneLeftOuterJoinDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, null),
Row(5, null), Row(6, null), Row(7, null), Row(8, null), Row(9, null)))
@@ -351,6 +379,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val oneRightOuterJoinDF = df2.join(df3.hint("SHUFFLE_MERGE"), $"k2" === $"k3", "right_outer")
assert(oneRightOuterJoinDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : SortMergeJoinExec) => true
+ case _: CometSortMergeJoinExec => true
}.size === 1)
checkAnswer(oneRightOuterJoinDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(null, 4),
Row(null, 5)))
@@ -360,6 +389,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
.join(df1.hint("SHUFFLE_MERGE"), $"k3" === $"k1", "right_outer")
assert(twoJoinsDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : SortMergeJoinExec) => true
+ case _: CometSortMergeJoinExec => true
}.size === 2)
checkAnswer(twoJoinsDF,
Seq(Row(0, 0, 0), Row(1, 1, 1), Row(2, 2, 2), Row(3, 3, 3), Row(4, null, 4), Row(5, null, 5),
@@ -375,6 +405,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val oneJoinDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_semi")
assert(oneJoinDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) => true
+ case _: CometSortMergeJoinExec => true
}.size === 1)
checkAnswer(oneJoinDF, Seq(Row(0), Row(1), Row(2), Row(3)))
@@ -382,8 +413,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val twoJoinsDF = df3.join(df2.hint("SHUFFLE_MERGE"), $"k3" === $"k2", "left_semi")
.join(df1.hint("SHUFFLE_MERGE"), $"k3" === $"k1", "left_semi")
assert(twoJoinsDF.queryExecution.executedPlan.collect {
- case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) |
- WholeStageCodegenExec(_ : SortMergeJoinExec) => true
+ case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
}.size === 2)
checkAnswer(twoJoinsDF, Seq(Row(0), Row(1), Row(2), Row(3)))
}
@@ -397,6 +428,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val oneJoinDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_anti")
assert(oneJoinDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) => true
+ case _: CometSortMergeJoinExec => true
}.size === 1)
checkAnswer(oneJoinDF, Seq(Row(4), Row(5), Row(6), Row(7), Row(8), Row(9)))
@@ -404,8 +436,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val twoJoinsDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_anti")
.join(df3.hint("SHUFFLE_MERGE"), $"k1" === $"k3", "left_anti")
assert(twoJoinsDF.queryExecution.executedPlan.collect {
- case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) |
- WholeStageCodegenExec(_ : SortMergeJoinExec) => true
+ case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
}.size === 2)
checkAnswer(twoJoinsDF, Seq(Row(6), Row(7), Row(8), Row(9)))
}
@@ -538,7 +570,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val plan = df.queryExecution.executedPlan
assert(plan.exists(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec]))
+ p.asInstanceOf[WholeStageCodegenExec].collect {
+ case _: SortExec => true
+ case _: CometSortExec => true
+ }.nonEmpty))
assert(df.collect() === Array(Row(1), Row(2), Row(3)))
}
@@ -718,7 +753,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
.write.mode(SaveMode.Overwrite).parquet(path)
withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255",
- SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> "true") {
+ SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> "true",
+ // Disable Comet native execution because this checks wholestage codegen.
+ "spark.comet.exec.enabled" -> "false") {
val projection = Seq.tabulate(columnNum)(i => s"c$i + c$i as newC$i")
val df = spark.read.parquet(path).selectExpr(projection: _*)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index a7efd0aa75e..baae0967a2a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -28,11 +28,13 @@ import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.sql.{DataFrame, Dataset, QueryTest, Row, SparkSession, Strategy}
+import org.apache.spark.sql.{DataFrame, Dataset, IgnoreComet, QueryTest, Row, SparkSession, Strategy}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.comet._
+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution.{CollectLimitExec, ColumnarToRowExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnaryExecNode, UnionExec}
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.columnar.{InMemoryTableScanExec, InMemoryTableScanLike}
@@ -120,6 +122,7 @@ class AdaptiveQueryExecSuite
private def findTopLevelBroadcastHashJoin(plan: SparkPlan): Seq[BroadcastHashJoinExec] = {
collect(plan) {
case j: BroadcastHashJoinExec => j
+ case j: CometBroadcastHashJoinExec => j.originalPlan.asInstanceOf[BroadcastHashJoinExec]
}
}
@@ -132,30 +135,39 @@ class AdaptiveQueryExecSuite
private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = {
collect(plan) {
case j: SortMergeJoinExec => j
+ case j: CometSortMergeJoinExec =>
+ assert(j.originalPlan.isInstanceOf[SortMergeJoinExec])
+ j.originalPlan.asInstanceOf[SortMergeJoinExec]
}
}
private def findTopLevelShuffledHashJoin(plan: SparkPlan): Seq[ShuffledHashJoinExec] = {
collect(plan) {
case j: ShuffledHashJoinExec => j
+ case j: CometHashJoinExec => j.originalPlan.asInstanceOf[ShuffledHashJoinExec]
}
}
private def findTopLevelBaseJoin(plan: SparkPlan): Seq[BaseJoinExec] = {
collect(plan) {
case j: BaseJoinExec => j
+ case c: CometHashJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec]
+ case c: CometSortMergeJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec]
+ case c: CometBroadcastHashJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec]
}
}
private def findTopLevelSort(plan: SparkPlan): Seq[SortExec] = {
collect(plan) {
case s: SortExec => s
+ case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec]
}
}
private def findTopLevelAggregate(plan: SparkPlan): Seq[BaseAggregateExec] = {
collect(plan) {
case agg: BaseAggregateExec => agg
+ case agg: CometHashAggregateExec => agg.originalPlan.asInstanceOf[BaseAggregateExec]
}
}
@@ -205,6 +217,7 @@ class AdaptiveQueryExecSuite
val parts = rdd.partitions
assert(parts.forall(rdd.preferredLocations(_).nonEmpty))
}
+
assert(numShuffles === (numLocalReads.length + numShufflesWithoutLocalRead))
}
@@ -213,7 +226,7 @@ class AdaptiveQueryExecSuite
val plan = df.queryExecution.executedPlan
assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}
assert(shuffle.size == 1)
assert(shuffle(0).outputPartitioning.numPartitions == numPartition)
@@ -229,7 +242,8 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
- checkNumLocalShuffleReads(adaptivePlan)
+ // Comet shuffle changes shuffle metrics
+ // checkNumLocalShuffleReads(adaptivePlan)
}
}
@@ -256,7 +270,8 @@ class AdaptiveQueryExecSuite
}
}
- test("Reuse the parallelism of coalesced shuffle in local shuffle read") {
+ test("Reuse the parallelism of coalesced shuffle in local shuffle read",
+ IgnoreComet("Comet shuffle changes shuffle partition size")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
@@ -288,7 +303,8 @@ class AdaptiveQueryExecSuite
}
}
- test("Reuse the default parallelism in local shuffle read") {
+ test("Reuse the default parallelism in local shuffle read",
+ IgnoreComet("Comet shuffle changes shuffle partition size")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
@@ -302,7 +318,8 @@ class AdaptiveQueryExecSuite
val localReads = collect(adaptivePlan) {
case read: AQEShuffleReadExec if read.isLocalRead => read
}
- assert(localReads.length == 2)
+ // Comet shuffle changes shuffle metrics
+ assert(localReads.length == 1)
val localShuffleRDD0 = localReads(0).execute().asInstanceOf[ShuffledRowRDD]
val localShuffleRDD1 = localReads(1).execute().asInstanceOf[ShuffledRowRDD]
// the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2
@@ -327,7 +344,9 @@ class AdaptiveQueryExecSuite
.groupBy($"a").count()
checkAnswer(testDf, Seq())
val plan = testDf.queryExecution.executedPlan
- assert(find(plan)(_.isInstanceOf[SortMergeJoinExec]).isDefined)
+ assert(find(plan) { case p =>
+ p.isInstanceOf[SortMergeJoinExec] || p.isInstanceOf[CometSortMergeJoinExec]
+ }.isDefined)
val coalescedReads = collect(plan) {
case r: AQEShuffleReadExec => r
}
@@ -341,7 +360,9 @@ class AdaptiveQueryExecSuite
.groupBy($"a").count()
checkAnswer(testDf, Seq())
val plan = testDf.queryExecution.executedPlan
- assert(find(plan)(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
+ assert(find(plan) { case p =>
+ p.isInstanceOf[BroadcastHashJoinExec] || p.isInstanceOf[CometBroadcastHashJoinExec]
+ }.isDefined)
val coalescedReads = collect(plan) {
case r: AQEShuffleReadExec => r
}
@@ -351,7 +372,7 @@ class AdaptiveQueryExecSuite
}
}
- test("Scalar subquery") {
+ test("Scalar subquery", IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
@@ -366,7 +387,7 @@ class AdaptiveQueryExecSuite
}
}
- test("Scalar subquery in later stages") {
+ test("Scalar subquery in later stages", IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
@@ -382,7 +403,7 @@ class AdaptiveQueryExecSuite
}
}
- test("multiple joins") {
+ test("multiple joins", IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
@@ -427,7 +448,7 @@ class AdaptiveQueryExecSuite
}
}
- test("multiple joins with aggregate") {
+ test("multiple joins with aggregate", IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
@@ -472,7 +493,7 @@ class AdaptiveQueryExecSuite
}
}
- test("multiple joins with aggregate 2") {
+ test("multiple joins with aggregate 2", IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
@@ -518,7 +539,7 @@ class AdaptiveQueryExecSuite
}
}
- test("Exchange reuse") {
+ test("Exchange reuse", IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
@@ -537,7 +558,7 @@ class AdaptiveQueryExecSuite
}
}
- test("Exchange reuse with subqueries") {
+ test("Exchange reuse with subqueries", IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
@@ -568,7 +589,9 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
- checkNumLocalShuffleReads(adaptivePlan)
+ // Comet shuffle changes shuffle metrics,
+ // so we can't check the number of local shuffle reads.
+ // checkNumLocalShuffleReads(adaptivePlan)
// Even with local shuffle read, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.nonEmpty)
@@ -589,7 +612,9 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
- checkNumLocalShuffleReads(adaptivePlan)
+ // Comet shuffle changes shuffle metrics,
+ // so we can't check the number of local shuffle reads.
+ // checkNumLocalShuffleReads(adaptivePlan)
// Even with local shuffle read, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.isEmpty)
@@ -598,7 +623,8 @@ class AdaptiveQueryExecSuite
}
}
- test("Broadcast exchange reuse across subqueries") {
+ test("Broadcast exchange reuse across subqueries",
+ IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000",
@@ -693,7 +719,8 @@ class AdaptiveQueryExecSuite
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
// There is still a SMJ, and its two shuffles can't apply local read.
- checkNumLocalShuffleReads(adaptivePlan, 2)
+ // Comet shuffle changes shuffle metrics
+ // checkNumLocalShuffleReads(adaptivePlan, 2)
}
}
@@ -815,7 +842,8 @@ class AdaptiveQueryExecSuite
}
}
- test("SPARK-29544: adaptive skew join with different join types") {
+ test("SPARK-29544: adaptive skew join with different join types",
+ IgnoreComet("Comet shuffle has different partition metrics")) {
Seq("SHUFFLE_MERGE", "SHUFFLE_HASH").foreach { joinHint =>
def getJoinNode(plan: SparkPlan): Seq[ShuffledJoin] = if (joinHint == "SHUFFLE_MERGE") {
findTopLevelSortMergeJoin(plan)
@@ -1123,7 +1151,8 @@ class AdaptiveQueryExecSuite
}
}
- test("metrics of the shuffle read") {
+ test("metrics of the shuffle read",
+ IgnoreComet("Comet shuffle changes the metrics")) {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT key FROM testData GROUP BY key")
@@ -1718,7 +1747,7 @@ class AdaptiveQueryExecSuite
val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT id FROM v1 GROUP BY id DISTRIBUTE BY id")
assert(collect(adaptivePlan) {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}.length == 1)
}
}
@@ -1798,7 +1827,8 @@ class AdaptiveQueryExecSuite
}
}
- test("SPARK-33551: Do not use AQE shuffle read for repartition") {
+ test("SPARK-33551: Do not use AQE shuffle read for repartition",
+ IgnoreComet("Comet shuffle changes partition size")) {
def hasRepartitionShuffle(plan: SparkPlan): Boolean = {
find(plan) {
case s: ShuffleExchangeLike =>
@@ -1983,6 +2013,9 @@ class AdaptiveQueryExecSuite
def checkNoCoalescePartitions(ds: Dataset[Row], origin: ShuffleOrigin): Unit = {
assert(collect(ds.queryExecution.executedPlan) {
case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s
+ case c: CometShuffleExchangeExec
+ if c.originalPlan.shuffleOrigin == origin &&
+ c.originalPlan.numPartitions == 2 => c
}.size == 1)
ds.collect()
val plan = ds.queryExecution.executedPlan
@@ -1991,6 +2024,9 @@ class AdaptiveQueryExecSuite
}.isEmpty)
assert(collect(plan) {
case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s
+ case c: CometShuffleExchangeExec
+ if c.originalPlan.shuffleOrigin == origin &&
+ c.originalPlan.numPartitions == 2 => c
}.size == 1)
checkAnswer(ds, testData)
}
@@ -2147,7 +2183,8 @@ class AdaptiveQueryExecSuite
}
}
- test("SPARK-35264: Support AQE side shuffled hash join formula") {
+ test("SPARK-35264: Support AQE side shuffled hash join formula",
+ IgnoreComet("Comet shuffle changes the partition size")) {
withTempView("t1", "t2") {
def checkJoinStrategy(shouldShuffleHashJoin: Boolean): Unit = {
Seq("100", "100000").foreach { size =>
@@ -2233,7 +2270,8 @@ class AdaptiveQueryExecSuite
}
}
- test("SPARK-35725: Support optimize skewed partitions in RebalancePartitions") {
+ test("SPARK-35725: Support optimize skewed partitions in RebalancePartitions",
+ IgnoreComet("Comet shuffle changes shuffle metrics")) {
withTempView("v") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
@@ -2332,7 +2370,7 @@ class AdaptiveQueryExecSuite
runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " +
s"JOIN skewData2 ON key1 = key2 GROUP BY key1")
val shuffles1 = collect(adaptive1) {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}
assert(shuffles1.size == 3)
// shuffles1.head is the top-level shuffle under the Aggregate operator
@@ -2345,7 +2383,7 @@ class AdaptiveQueryExecSuite
runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " +
s"JOIN skewData2 ON key1 = key2")
val shuffles2 = collect(adaptive2) {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}
if (hasRequiredDistribution) {
assert(shuffles2.size == 3)
@@ -2379,7 +2417,8 @@ class AdaptiveQueryExecSuite
}
}
- test("SPARK-35794: Allow custom plugin for cost evaluator") {
+ test("SPARK-35794: Allow custom plugin for cost evaluator",
+ IgnoreComet("Comet shuffle changes shuffle metrics")) {
CostEvaluator.instantiate(
classOf[SimpleShuffleSortCostEvaluator].getCanonicalName, spark.sparkContext.getConf)
intercept[IllegalArgumentException] {
@@ -2510,7 +2549,8 @@ class AdaptiveQueryExecSuite
}
test("SPARK-48037: Fix SortShuffleWriter lacks shuffle write related metrics " +
- "resulting in potentially inaccurate data") {
+ "resulting in potentially inaccurate data",
+ IgnoreComet("too many shuffle partitions causes Java heap OOM")) {
withTable("t3") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
@@ -2545,6 +2585,7 @@ class AdaptiveQueryExecSuite
val (_, adaptive) = runAdaptiveAndVerifyResult(query)
assert(adaptive.collect {
case sort: SortExec => sort
+ case sort: CometSortExec => sort
}.size == 1)
val read = collect(adaptive) {
case read: AQEShuffleReadExec => read
@@ -2562,7 +2603,8 @@ class AdaptiveQueryExecSuite
}
}
- test("SPARK-37357: Add small partition factor for rebalance partitions") {
+ test("SPARK-37357: Add small partition factor for rebalance partitions",
+ IgnoreComet("Comet shuffle changes shuffle metrics")) {
withTempView("v") {
withSQLConf(
SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true",
@@ -2674,7 +2716,7 @@ class AdaptiveQueryExecSuite
runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " +
"JOIN skewData3 ON value2 = value3")
val shuffles1 = collect(adaptive1) {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}
assert(shuffles1.size == 4)
val smj1 = findTopLevelSortMergeJoin(adaptive1)
@@ -2685,7 +2727,7 @@ class AdaptiveQueryExecSuite
runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " +
"JOIN skewData3 ON value1 = value3")
val shuffles2 = collect(adaptive2) {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}
assert(shuffles2.size == 4)
val smj2 = findTopLevelSortMergeJoin(adaptive2)
@@ -2911,6 +2953,7 @@ class AdaptiveQueryExecSuite
}.size == (if (firstAccess) 1 else 0))
assert(collect(initialExecutedPlan) {
case s: SortExec => s
+ case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec]
}.size == (if (firstAccess) 2 else 0))
assert(collect(initialExecutedPlan) {
case i: InMemoryTableScanLike => i
@@ -2923,6 +2966,7 @@ class AdaptiveQueryExecSuite
}.isEmpty)
assert(collect(finalExecutedPlan) {
case s: SortExec => s
+ case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec]
}.isEmpty)
assert(collect(initialExecutedPlan) {
case i: InMemoryTableScanLike => i
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index 0a0b23d1e60..5685926250f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Concat
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.Expand
import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.comet.CometScanExec
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions._
@@ -868,6 +869,7 @@ abstract class SchemaPruningSuite
val fileSourceScanSchemata =
collect(df.queryExecution.executedPlan) {
case scan: FileSourceScanExec => scan.requiredSchema
+ case scan: CometScanExec => scan.requiredSchema
}
assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size,
s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " +
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
index 04a7b4834f4..8cab62ce4ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
@@ -17,9 +17,10 @@
package org.apache.spark.sql.execution.datasources
-import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.{IgnoreComet, QueryTest, Row}
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, NullsFirst, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sort}
+import org.apache.spark.sql.comet.CometSortExec
import org.apache.spark.sql.execution.{QueryExecution, SortExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.internal.SQLConf
@@ -225,6 +226,7 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write
// assert the outer most sort in the executed plan
assert(plan.collectFirst {
case s: SortExec => s
+ case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec]
}.exists {
case SortExec(Seq(
SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _),
@@ -272,6 +274,7 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write
// assert the outer most sort in the executed plan
assert(plan.collectFirst {
case s: SortExec => s
+ case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec]
}.exists {
case SortExec(Seq(
SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _),
@@ -305,7 +308,8 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write
}
}
- test("v1 write with AQE changing SMJ to BHJ") {
+ test("v1 write with AQE changing SMJ to BHJ",
+ IgnoreComet("TODO: Comet SMJ to BHJ by AQE")) {
withPlannedWrite { enabled =>
withTable("t") {
sql(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
index 5c118ac12b7..fede1a94488 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
@@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, GlobFilter, Path}
import org.mockito.Mockito.{mock, when}
import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
-import org.apache.spark.sql.{DataFrame, QueryTest, Row}
+import org.apache.spark.sql.{DataFrame, IgnoreCometSuite, QueryTest, Row}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.functions.col
@@ -38,7 +38,9 @@ import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
-class BinaryFileFormatSuite extends QueryTest with SharedSparkSession {
+// For some reason this suite is flaky w/ or w/o Comet when running in Github workflow.
+// Since it isn't related to Comet, we disable it for now.
+class BinaryFileFormatSuite extends QueryTest with SharedSparkSession with IgnoreCometSuite {
import BinaryFileFormat._
private var testDir: String = _
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
index cd6f41b4ef4..4b6a17344bc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
@@ -28,7 +28,7 @@ import org.apache.parquet.hadoop.ParquetOutputFormat
import org.apache.spark.TestUtils
import org.apache.spark.memory.MemoryMode
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{IgnoreComet, Row}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -201,7 +201,8 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSparkSess
}
}
- test("parquet v2 pages - rle encoding for boolean value columns") {
+ test("parquet v2 pages - rle encoding for boolean value columns",
+ IgnoreComet("Comet doesn't support RLE encoding yet")) {
val extraOptions = Map[String, String](
ParquetOutputFormat.WRITER_VERSION -> ParquetProperties.WriterVersion.PARQUET_2_0.toString
)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 795e9f46a8d..5306c94a686 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -1100,7 +1100,11 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
// When a filter is pushed to Parquet, Parquet can apply it to every row.
// So, we can check the number of rows returned from the Parquet
// to make sure our filter pushdown work.
- assert(stripSparkFilter(df).count() == 1)
+ // Similar to Spark's vectorized reader, Comet doesn't do row-level filtering but relies
+ // on Spark to apply the data filters after columnar batches are returned
+ if (!isCometEnabled) {
+ assert(stripSparkFilter(df).count() == 1)
+ }
}
}
}
@@ -1585,7 +1589,11 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
// than the total length but should not be a single record.
// Note that, if record level filtering is enabled, it should be a single record.
// If no filter is pushed down to Parquet, it should be the total length of data.
- assert(actual > 1 && actual < data.length)
+ // Only enable Comet test iff it's scan only, since with native execution
+ // `stripSparkFilter` can't remove the native filter
+ if (!isCometEnabled || isCometScanOnly) {
+ assert(actual > 1 && actual < data.length)
+ }
}
}
}
@@ -1612,7 +1620,11 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
// than the total length but should not be a single record.
// Note that, if record level filtering is enabled, it should be a single record.
// If no filter is pushed down to Parquet, it should be the total length of data.
- assert(actual > 1 && actual < data.length)
+ // Only enable Comet test iff it's scan only, since with native execution
+ // `stripSparkFilter` can't remove the native filter
+ if (!isCometEnabled || isCometScanOnly) {
+ assert(actual > 1 && actual < data.length)
+ }
}
}
}
@@ -1748,7 +1760,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
}
}
- test("SPARK-17091: Convert IN predicate to Parquet filter push-down") {
+ test("SPARK-17091: Convert IN predicate to Parquet filter push-down",
+ IgnoreComet("IN predicate is not yet supported in Comet, see issue #36")) {
val schema = StructType(Seq(
StructField("a", IntegerType, nullable = false)
))
@@ -1991,7 +2004,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
}
}
- test("Support Parquet column index") {
+ test("Support Parquet column index",
+ IgnoreComet("Comet doesn't support Parquet column index yet")) {
// block 1:
// null count min max
// page-0 0 0 99
@@ -2301,7 +2315,11 @@ class ParquetV1FilterSuite extends ParquetFilterSuite {
assert(pushedParquetFilters.exists(_.getClass === filterClass),
s"${pushedParquetFilters.map(_.getClass).toList} did not contain ${filterClass}.")
- checker(stripSparkFilter(query), expected)
+ // Similar to Spark's vectorized reader, Comet doesn't do row-level filtering but relies
+ // on Spark to apply the data filters after columnar batches are returned
+ if (!isCometEnabled) {
+ checker(stripSparkFilter(query), expected)
+ }
} else {
assert(selectedFilters.isEmpty, "There is filter pushed down")
}
@@ -2362,7 +2380,11 @@ class ParquetV2FilterSuite extends ParquetFilterSuite {
assert(pushedParquetFilters.exists(_.getClass === filterClass),
s"${pushedParquetFilters.map(_.getClass).toList} did not contain ${filterClass}.")
- checker(stripSparkFilter(query), expected)
+ // Similar to Spark's vectorized reader, Comet doesn't do row-level filtering but relies
+ // on Spark to apply the data filters after columnar batches are returned
+ if (!isCometEnabled) {
+ checker(stripSparkFilter(query), expected)
+ }
case _ => assert(false, "Can not match ParquetTable in the query.")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 4fb8faa43a3..984fd1a9892 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -1297,7 +1297,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
}
}
- test("SPARK-40128 read DELTA_LENGTH_BYTE_ARRAY encoded strings") {
+ test("SPARK-40128 read DELTA_LENGTH_BYTE_ARRAY encoded strings",
+ IgnoreComet("Comet doesn't support DELTA encoding yet")) {
withAllParquetReaders {
checkAnswer(
// "fruit" column in this file is encoded using DELTA_LENGTH_BYTE_ARRAY.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index a329d3fdc3c..d29523a41f7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -1042,7 +1042,8 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
checkAnswer(readParquet(schema2, path), df)
}
- withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false",
+ "spark.comet.enabled" -> "false") {
val schema1 = "a DECIMAL(3, 2), b DECIMAL(18, 3), c DECIMAL(37, 3)"
checkAnswer(readParquet(schema1, path), df)
val schema2 = "a DECIMAL(3, 0), b DECIMAL(18, 1), c DECIMAL(37, 1)"
@@ -1066,7 +1067,8 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
val df = sql(s"SELECT 1 a, 123456 b, ${Int.MaxValue.toLong * 10} c, CAST('1.2' AS BINARY) d")
df.write.parquet(path.toString)
- withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false",
+ "spark.comet.enabled" -> "false") {
checkAnswer(readParquet("a DECIMAL(3, 2)", path), sql("SELECT 1.00"))
checkAnswer(readParquet("a DECIMAL(11, 2)", path), sql("SELECT 1.00"))
checkAnswer(readParquet("b DECIMAL(3, 2)", path), Row(null))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala
index 6d9092391a9..6da095120d1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala
@@ -21,7 +21,7 @@ import java.nio.file.{Files, Paths, StandardCopyOption}
import java.sql.{Date, Timestamp}
import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkUpgradeException}
-import org.apache.spark.sql.{QueryTest, Row, SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY}
+import org.apache.spark.sql.{IgnoreCometSuite, QueryTest, Row, SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY}
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
import org.apache.spark.sql.internal.LegacyBehaviorPolicy.{CORRECTED, EXCEPTION, LEGACY}
@@ -30,9 +30,11 @@ import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType.{INT96,
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.tags.SlowSQLTest
+// Comet is disabled for this suite because it doesn't support datetime rebase mode
abstract class ParquetRebaseDatetimeSuite
extends QueryTest
with ParquetTest
+ with IgnoreCometSuite
with SharedSparkSession {
import testImplicits._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala
index 95378d94674..0c915fdc634 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala
@@ -27,6 +27,7 @@ import org.apache.parquet.hadoop.ParquetWriter.DEFAULT_BLOCK_SIZE
import org.apache.spark.SparkException
import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.comet.{CometBatchScanExec, CometScanExec}
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
@@ -245,6 +246,12 @@ class ParquetRowIndexSuite extends QueryTest with SharedSparkSession {
case f: FileSourceScanExec =>
numPartitions += f.inputRDD.partitions.length
numOutputRows += f.metrics("numOutputRows").value
+ case b: CometScanExec =>
+ numPartitions += b.inputRDD.partitions.length
+ numOutputRows += b.metrics("numOutputRows").value
+ case b: CometBatchScanExec =>
+ numPartitions += b.inputRDD.partitions.length
+ numOutputRows += b.metrics("numOutputRows").value
case _ =>
}
assert(numPartitions > 0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala
index 5c0b7def039..151184bc98c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet
import org.apache.spark.SparkConf
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.comet.CometBatchScanExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.SchemaPruningSuite
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
@@ -56,6 +57,7 @@ class ParquetV2SchemaPruningSuite extends ParquetSchemaPruningSuite {
val fileSourceScanSchemata =
collect(df.queryExecution.executedPlan) {
case scan: BatchScanExec => scan.scan.asInstanceOf[ParquetScan].readDataSchema
+ case scan: CometBatchScanExec => scan.scan.asInstanceOf[ParquetScan].readDataSchema
}
assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size,
s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " +
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
index 25f6af1cc33..37b40cb5524 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
@@ -27,7 +27,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.apache.parquet.schema.Type._
import org.apache.spark.SparkException
-import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.{AnalysisException, IgnoreComet, Row}
import org.apache.spark.sql.catalyst.expressions.Cast.toSQLType
import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException
import org.apache.spark.sql.functions.desc
@@ -1037,7 +1037,8 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
e
}
- test("schema mismatch failure error message for parquet reader") {
+ test("schema mismatch failure error message for parquet reader",
+ IgnoreComet("Comet doesn't work with vectorizedReaderEnabled = false")) {
withTempPath { dir =>
val e = testSchemaMismatch(dir.getCanonicalPath, vectorizedReaderEnabled = false)
val expectedMessage = "Encountered error while reading file"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
index 4bd35e0789b..6544d86dbe0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
@@ -65,7 +65,9 @@ class ParquetTypeWideningSuite
withClue(
s"with dictionary encoding '$dictionaryEnabled' with timestamp rebase mode " +
s"'$timestampRebaseMode''") {
- withAllParquetWriters {
+ // TODO: Comet cannot read DELTA_BINARY_PACKED created by V2 writer
+ // https://github.com/apache/datafusion-comet/issues/574
+ // withAllParquetWriters {
withTempDir { dir =>
val expected =
writeParquetFiles(dir, values, fromType, dictionaryEnabled, timestampRebaseMode)
@@ -86,7 +88,7 @@ class ParquetTypeWideningSuite
}
}
}
- }
+ // }
}
}
@@ -190,7 +192,8 @@ class ParquetTypeWideningSuite
(Seq("1", "2", Short.MinValue.toString), ShortType, DoubleType),
(Seq("1", "2", Int.MinValue.toString), IntegerType, DoubleType),
(Seq("1.23", "10.34"), FloatType, DoubleType),
- (Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, TimestampNTZType)
+ // TODO: Comet cannot handle older than "1582-10-15"
+ (Seq("2020-01-01", "2020-01-02"/* , "1312-02-27" */), DateType, TimestampNTZType)
)
}
test(s"parquet widening conversion $fromType -> $toType") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
index b8f3ea3c6f3..bbd44221288 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.debug
import java.io.ByteArrayOutputStream
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.IgnoreComet
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
@@ -125,7 +126,8 @@ class DebuggingSuite extends DebuggingSuiteBase with DisableAdaptiveExecutionSui
| id LongType: {}""".stripMargin))
}
- test("SPARK-28537: DebugExec cannot debug columnar related queries") {
+ test("SPARK-28537: DebugExec cannot debug columnar related queries",
+ IgnoreComet("Comet does not use FileScan")) {
withTempPath { workDir =>
val workDirPath = workDir.getAbsolutePath
val input = spark.range(5).toDF("id")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 45c775e6c46..41240792038 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -46,8 +46,10 @@ import org.apache.spark.sql.util.QueryExecutionListener
import org.apache.spark.util.{AccumulatorContext, JsonProtocol}
// Disable AQE because metric info is different with AQE on/off
+// This test suite runs tests against the metrics of physical operators.
+// Disabling it for Comet because the metrics are different with Comet enabled.
class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
- with DisableAdaptiveExecutionSuite {
+ with DisableAdaptiveExecutionSuite with IgnoreCometSuite {
import testImplicits._
/**
@@ -765,7 +767,8 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
}
}
- test("SPARK-26327: FileSourceScanExec metrics") {
+ test("SPARK-26327: FileSourceScanExec metrics",
+ IgnoreComet("Spark uses row-based Parquet reader while Comet is vectorized")) {
withTable("testDataForScan") {
spark.range(10).selectExpr("id", "id % 3 as p")
.write.partitionBy("p").saveAsTable("testDataForScan")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
index 0ab8691801d..d9125f658ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.python
import org.apache.spark.sql.catalyst.plans.logical.{ArrowEvalPython, BatchEvalPython, Limit, LocalLimit}
+import org.apache.spark.sql.comet._
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
@@ -108,6 +109,7 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession {
val scanNodes = query.queryExecution.executedPlan.collect {
case scan: FileSourceScanExec => scan
+ case scan: CometScanExec => scan
}
assert(scanNodes.length == 1)
assert(scanNodes.head.output.map(_.name) == Seq("a"))
@@ -120,11 +122,16 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession {
val scanNodes = query.queryExecution.executedPlan.collect {
case scan: FileSourceScanExec => scan
+ case scan: CometScanExec => scan
}
assert(scanNodes.length == 1)
// $"a" is not null and $"a" > 1
- assert(scanNodes.head.dataFilters.length == 2)
- assert(scanNodes.head.dataFilters.flatMap(_.references.map(_.name)).distinct == Seq("a"))
+ val dataFilters = scanNodes.head match {
+ case scan: FileSourceScanExec => scan.dataFilters
+ case scan: CometScanExec => scan.dataFilters
+ }
+ assert(dataFilters.length == 2)
+ assert(dataFilters.flatMap(_.references.map(_.name)).distinct == Seq("a"))
}
}
}
@@ -145,6 +152,7 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession {
val scanNodes = query.queryExecution.executedPlan.collect {
case scan: BatchScanExec => scan
+ case scan: CometBatchScanExec => scan
}
assert(scanNodes.length == 1)
assert(scanNodes.head.output.map(_.name) == Seq("a"))
@@ -157,6 +165,7 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession {
val scanNodes = query.queryExecution.executedPlan.collect {
case scan: BatchScanExec => scan
+ case scan: CometBatchScanExec => scan
}
assert(scanNodes.length == 1)
// $"a" is not null and $"a" > 1
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala
index 6ff07449c0c..9f95cff99e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala
@@ -37,8 +37,10 @@ import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException,
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.util.Utils
+// For some reason this suite is flaky w/ or w/o Comet when running in Github workflow.
+// Since it isn't related to Comet, we disable it for now.
class AsyncProgressTrackingMicroBatchExecutionSuite
- extends StreamTest with BeforeAndAfter with Matchers {
+ extends StreamTest with BeforeAndAfter with Matchers with IgnoreCometSuite {
import testImplicits._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 3573bafe482..11d387110ea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -25,10 +25,11 @@ import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.execution.{FileSourceScanExec, SortExec, SparkPlan}
+import org.apache.spark.sql.comet._
+import org.apache.spark.sql.execution.{ColumnarToRowExec, FileSourceScanExec, SortExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper}
import org.apache.spark.sql.execution.datasources.BucketingUtils
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -102,12 +103,20 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
}
}
- private def getFileScan(plan: SparkPlan): FileSourceScanExec = {
- val fileScan = collect(plan) { case f: FileSourceScanExec => f }
+ private def getFileScan(plan: SparkPlan): SparkPlan = {
+ val fileScan = collect(plan) {
+ case f: FileSourceScanExec => f
+ case f: CometScanExec => f
+ }
assert(fileScan.nonEmpty, plan)
fileScan.head
}
+ private def getBucketScan(plan: SparkPlan): Boolean = getFileScan(plan) match {
+ case fs: FileSourceScanExec => fs.bucketedScan
+ case bs: CometScanExec => bs.bucketedScan
+ }
+
// To verify if the bucket pruning works, this function checks two conditions:
// 1) Check if the pruned buckets (before filtering) are empty.
// 2) Verify the final result is the same as the expected one
@@ -156,7 +165,8 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
val planWithoutBucketedScan = bucketedDataFrame.filter(filterCondition)
.queryExecution.executedPlan
val fileScan = getFileScan(planWithoutBucketedScan)
- assert(!fileScan.bucketedScan, s"except no bucketed scan but found\n$fileScan")
+ val bucketedScan = getBucketScan(planWithoutBucketedScan)
+ assert(!bucketedScan, s"except no bucketed scan but found\n$fileScan")
val bucketColumnType = bucketedDataFrame.schema.apply(bucketColumnIndex).dataType
val rowsWithInvalidBuckets = fileScan.execute().filter(row => {
@@ -452,28 +462,49 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
val joinOperator = if (joined.sparkSession.sessionState.conf.adaptiveExecutionEnabled) {
val executedPlan =
joined.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
- assert(executedPlan.isInstanceOf[SortMergeJoinExec])
- executedPlan.asInstanceOf[SortMergeJoinExec]
+ executedPlan match {
+ case s: SortMergeJoinExec => s
+ case b: CometSortMergeJoinExec =>
+ b.originalPlan match {
+ case s: SortMergeJoinExec => s
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
} else {
val executedPlan = joined.queryExecution.executedPlan
- assert(executedPlan.isInstanceOf[SortMergeJoinExec])
- executedPlan.asInstanceOf[SortMergeJoinExec]
+ executedPlan match {
+ case s: SortMergeJoinExec => s
+ case ColumnarToRowExec(child) =>
+ child.asInstanceOf[CometSortMergeJoinExec].originalPlan match {
+ case s: SortMergeJoinExec => s
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
+ case CometColumnarToRowExec(child) =>
+ child.asInstanceOf[CometSortMergeJoinExec].originalPlan match {
+ case s: SortMergeJoinExec => s
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
}
// check existence of shuffle
assert(
- joinOperator.left.exists(_.isInstanceOf[ShuffleExchangeExec]) == shuffleLeft,
+ joinOperator.left.exists(op => op.isInstanceOf[ShuffleExchangeLike]) == shuffleLeft,
s"expected shuffle in plan to be $shuffleLeft but found\n${joinOperator.left}")
assert(
- joinOperator.right.exists(_.isInstanceOf[ShuffleExchangeExec]) == shuffleRight,
+ joinOperator.right.exists(op => op.isInstanceOf[ShuffleExchangeLike]) == shuffleRight,
s"expected shuffle in plan to be $shuffleRight but found\n${joinOperator.right}")
// check existence of sort
assert(
- joinOperator.left.exists(_.isInstanceOf[SortExec]) == sortLeft,
+ joinOperator.left.exists(op => op.isInstanceOf[SortExec] || op.isInstanceOf[CometExec] &&
+ op.asInstanceOf[CometExec].originalPlan.isInstanceOf[SortExec]) == sortLeft,
s"expected sort in the left child to be $sortLeft but found\n${joinOperator.left}")
assert(
- joinOperator.right.exists(_.isInstanceOf[SortExec]) == sortRight,
+ joinOperator.right.exists(op => op.isInstanceOf[SortExec] || op.isInstanceOf[CometExec] &&
+ op.asInstanceOf[CometExec].originalPlan.isInstanceOf[SortExec]) == sortRight,
s"expected sort in the right child to be $sortRight but found\n${joinOperator.right}")
// check the output partitioning
@@ -836,11 +867,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
val scanDF = spark.table("bucketed_table").select("j")
- assert(!getFileScan(scanDF.queryExecution.executedPlan).bucketedScan)
+ assert(!getBucketScan(scanDF.queryExecution.executedPlan))
checkAnswer(scanDF, df1.select("j"))
val aggDF = spark.table("bucketed_table").groupBy("j").agg(max("k"))
- assert(!getFileScan(aggDF.queryExecution.executedPlan).bucketedScan)
+ assert(!getBucketScan(aggDF.queryExecution.executedPlan))
checkAnswer(aggDF, df1.groupBy("j").agg(max("k")))
}
}
@@ -1029,15 +1060,21 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
Seq(true, false).foreach { aqeEnabled =>
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled.toString) {
val plan = sql(query).queryExecution.executedPlan
- val shuffles = collect(plan) { case s: ShuffleExchangeExec => s }
+ val shuffles = collect(plan) { case s: ShuffleExchangeLike => s }
assert(shuffles.length == expectedNumShuffles)
val scans = collect(plan) {
case f: FileSourceScanExec if f.optionalNumCoalescedBuckets.isDefined => f
+ case b: CometScanExec if b.optionalNumCoalescedBuckets.isDefined => b
}
if (expectedCoalescedNumBuckets.isDefined) {
assert(scans.length == 1)
- assert(scans.head.optionalNumCoalescedBuckets == expectedCoalescedNumBuckets)
+ scans.head match {
+ case f: FileSourceScanExec =>
+ assert(f.optionalNumCoalescedBuckets == expectedCoalescedNumBuckets)
+ case b: CometScanExec =>
+ assert(b.optionalNumCoalescedBuckets == expectedCoalescedNumBuckets)
+ }
} else {
assert(scans.isEmpty)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
index 6f897a9c0b7..b0723634f68 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.sources
import java.io.File
import org.apache.spark.SparkException
+import org.apache.spark.sql.IgnoreCometSuite
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTableType}
import org.apache.spark.sql.catalyst.parser.ParseException
@@ -27,7 +28,10 @@ import org.apache.spark.sql.internal.SQLConf.BUCKETING_MAX_BUCKETS
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.Utils
-class CreateTableAsSelectSuite extends DataSourceTest with SharedSparkSession {
+// For some reason this suite is flaky w/ or w/o Comet when running in Github workflow.
+// Since it isn't related to Comet, we disable it for now.
+class CreateTableAsSelectSuite extends DataSourceTest with SharedSparkSession
+ with IgnoreCometSuite {
import testImplicits._
protected override lazy val sql = spark.sql _
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala
index c5c56f081d8..197cd241f48 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.sources
import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.comet.CometScanExec
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.internal.SQLConf
@@ -68,7 +69,10 @@ abstract class DisableUnnecessaryBucketedScanSuite
def checkNumBucketedScan(query: String, expectedNumBucketedScan: Int): Unit = {
val plan = sql(query).queryExecution.executedPlan
- val bucketedScan = collect(plan) { case s: FileSourceScanExec if s.bucketedScan => s }
+ val bucketedScan = collect(plan) {
+ case s: FileSourceScanExec if s.bucketedScan => s
+ case s: CometScanExec if s.bucketedScan => s
+ }
assert(bucketedScan.length == expectedNumBucketedScan)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
index 04193d5189a..d83d03f8e0d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
@@ -34,6 +34,7 @@ import org.apache.spark.paths.SparkPath
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.{AnalysisException, DataFrame}
import org.apache.spark.sql.catalyst.util.stringToFile
+import org.apache.spark.sql.comet.CometBatchScanExec
import org.apache.spark.sql.execution.DataSourceScanExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, FileScan, FileTable}
@@ -749,6 +750,8 @@ class FileStreamSinkV2Suite extends FileStreamSinkSuite {
val fileScan = df.queryExecution.executedPlan.collect {
case batch: BatchScanExec if batch.scan.isInstanceOf[FileScan] =>
batch.scan.asInstanceOf[FileScan]
+ case batch: CometBatchScanExec if batch.scan.isInstanceOf[FileScan] =>
+ batch.scan.asInstanceOf[FileScan]
}.headOption.getOrElse {
fail(s"No FileScan in query\n${df.queryExecution}")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 1fce992126b..6d3ea74e0fc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Range, RepartitionByExpressi
import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRelationV2}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.comet.CometLocalLimitExec
import org.apache.spark.sql.execution.{LocalLimitExec, SimpleMode, SparkPlan}
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.streaming._
@@ -1116,11 +1117,12 @@ class StreamSuite extends StreamTest {
val localLimits = execPlan.collect {
case l: LocalLimitExec => l
case l: StreamingLocalLimitExec => l
+ case l: CometLocalLimitExec => l
}
require(
localLimits.size == 1,
- s"Cant verify local limit optimization with this plan:\n$execPlan")
+ s"Cant verify local limit optimization ${localLimits.size} with this plan:\n$execPlan")
if (expectStreamingLimit) {
assert(
@@ -1128,7 +1130,8 @@ class StreamSuite extends StreamTest {
s"Local limit was not StreamingLocalLimitExec:\n$execPlan")
} else {
assert(
- localLimits.head.isInstanceOf[LocalLimitExec],
+ localLimits.head.isInstanceOf[LocalLimitExec] ||
+ localLimits.head.isInstanceOf[CometLocalLimitExec],
s"Local limit was not LocalLimitExec:\n$execPlan")
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala
index b4c4ec7acbf..20579284856 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala
@@ -23,6 +23,7 @@ import org.apache.commons.io.FileUtils
import org.scalatest.Assertions
import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
+import org.apache.spark.sql.comet.CometHashAggregateExec
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.streaming.{MemoryStream, StateStoreRestoreExec, StateStoreSaveExec}
import org.apache.spark.sql.functions.count
@@ -67,6 +68,7 @@ class StreamingAggregationDistributionSuite extends StreamTest
// verify aggregations in between, except partial aggregation
val allAggregateExecs = query.lastExecution.executedPlan.collect {
case a: BaseAggregateExec => a
+ case c: CometHashAggregateExec => c.originalPlan
}
val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter {
@@ -201,6 +203,7 @@ class StreamingAggregationDistributionSuite extends StreamTest
// verify aggregations in between, except partial aggregation
val allAggregateExecs = executedPlan.collect {
case a: BaseAggregateExec => a
+ case c: CometHashAggregateExec => c.originalPlan
}
val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
index e05cb4d3c35..dc65a4fe18e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec, StreamingSymmetricHashJoinHelper}
import org.apache.spark.sql.execution.streaming.state.{RocksDBStateStoreProvider, StateStore, StateStoreProviderId}
import org.apache.spark.sql.functions._
@@ -620,14 +620,28 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite {
val numPartitions = spark.sessionState.conf.getConf(SQLConf.SHUFFLE_PARTITIONS)
- assert(query.lastExecution.executedPlan.collect {
- case j @ StreamingSymmetricHashJoinExec(_, _, _, _, _, _, _, _, _,
- ShuffleExchangeExec(opA: HashPartitioning, _, _, _),
- ShuffleExchangeExec(opB: HashPartitioning, _, _, _))
- if partitionExpressionsColumns(opA.expressions) === Seq("a", "b")
- && partitionExpressionsColumns(opB.expressions) === Seq("a", "b")
- && opA.numPartitions == numPartitions && opB.numPartitions == numPartitions => j
- }.size == 1)
+ val join = query.lastExecution.executedPlan.collect {
+ case j: StreamingSymmetricHashJoinExec => j
+ }.head
+ val opA = join.left.collect {
+ case s: ShuffleExchangeLike
+ if s.outputPartitioning.isInstanceOf[HashPartitioning] &&
+ partitionExpressionsColumns(
+ s.outputPartitioning
+ .asInstanceOf[HashPartitioning].expressions) === Seq("a", "b") =>
+ s.outputPartitioning
+ .asInstanceOf[HashPartitioning]
+ }.head
+ val opB = join.right.collect {
+ case s: ShuffleExchangeLike
+ if s.outputPartitioning.isInstanceOf[HashPartitioning] &&
+ partitionExpressionsColumns(
+ s.outputPartitioning
+ .asInstanceOf[HashPartitioning].expressions) === Seq("a", "b") =>
+ s.outputPartitioning
+ .asInstanceOf[HashPartitioning]
+ }.head
+ assert(opA.numPartitions == numPartitions && opB.numPartitions == numPartitions)
})
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
index af07aceaed1..ed0b5e6d9be 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
@@ -22,7 +22,7 @@ import java.util
import org.scalatest.BeforeAndAfter
-import org.apache.spark.sql.{AnalysisException, Row, SaveMode}
+import org.apache.spark.sql.{AnalysisException, IgnoreComet, Row, SaveMode}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
@@ -334,7 +334,8 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter {
}
}
- test("explain with table on DSv1 data source") {
+ test("explain with table on DSv1 data source",
+ IgnoreComet("Comet explain output is different")) {
val tblSourceName = "tbl_src"
val tblTargetName = "tbl_target"
val tblSourceQualified = s"default.$tblSourceName"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 5fbf379644f..47e0f4a2c9e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.PlanTestBase
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.comet._
import org.apache.spark.sql.execution.FilterExec
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution
import org.apache.spark.sql.execution.datasources.DataSourceUtils
@@ -127,7 +128,11 @@ private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with
}
}
} else {
- super.test(testName, testTags: _*)(testFun)
+ if (isCometEnabled && testTags.exists(_.isInstanceOf[IgnoreComet])) {
+ ignore(testName + " (disabled when Comet is on)", testTags: _*)(testFun)
+ } else {
+ super.test(testName, testTags: _*)(testFun)
+ }
}
}
@@ -243,6 +248,29 @@ private[sql] trait SQLTestUtilsBase
protected override def _sqlContext: SQLContext = self.spark.sqlContext
}
+ /**
+ * Whether Comet extension is enabled
+ */
+ protected def isCometEnabled: Boolean = SparkSession.isCometEnabled
+
+ /**
+ * Whether to enable ansi mode This is only effective when
+ * [[isCometEnabled]] returns true.
+ */
+ protected def enableCometAnsiMode: Boolean = {
+ val v = System.getenv("ENABLE_COMET_ANSI_MODE")
+ v != null && v.toBoolean
+ }
+
+ /**
+ * Whether Spark should only apply Comet scan optimization. This is only effective when
+ * [[isCometEnabled]] returns true.
+ */
+ protected def isCometScanOnly: Boolean = {
+ val v = System.getenv("ENABLE_COMET_SCAN_ONLY")
+ v != null && v.toBoolean
+ }
+
protected override def withSQLConf[T](pairs: (String, String)*)(f: => T): T = {
SparkSession.setActiveSession(spark)
super.withSQLConf(pairs: _*)(f)
@@ -434,6 +462,8 @@ private[sql] trait SQLTestUtilsBase
val schema = df.schema
val withoutFilters = df.queryExecution.executedPlan.transform {
case FilterExec(_, child) => child
+ case CometFilterExec(_, _, _, _, child, _) => child
+ case CometProjectExec(_, _, _, _, CometFilterExec(_, _, _, _, child, _), _) => child
}
spark.internalCreateDataFrame(withoutFilters.execute(), schema)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
index ed2e309fa07..71ba6533c9d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
@@ -74,6 +74,31 @@ trait SharedSparkSessionBase
// this rule may potentially block testing of other optimization rules such as
// ConstantPropagation etc.
.set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)
+ // Enable Comet if `ENABLE_COMET` environment variable is set
+ if (isCometEnabled) {
+ conf
+ .set("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions")
+ .set("spark.comet.enabled", "true")
+
+ if (!isCometScanOnly) {
+ conf
+ .set("spark.comet.exec.enabled", "true")
+ .set("spark.shuffle.manager",
+ "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
+ .set("spark.comet.exec.shuffle.enabled", "true")
+ .set("spark.comet.memoryOverhead", "10g")
+ } else {
+ conf
+ .set("spark.comet.exec.enabled", "false")
+ .set("spark.comet.exec.shuffle.enabled", "false")
+ }
+
+ if (enableCometAnsiMode) {
+ conf
+ .set("spark.sql.ansi.enabled", "true")
+ .set("spark.comet.ansi.enabled", "true")
+ }
+ }
conf.set(
StaticSQLConf.WAREHOUSE_PATH,
conf.get(StaticSQLConf.WAREHOUSE_PATH) + "/" + getClass.getCanonicalName)
diff --git a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala
index 1b7909534a0..45b90ef10a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala
@@ -46,7 +46,7 @@ class SqlResourceWithActualMetricsSuite
import testImplicits._
// Exclude nodes which may not have the metrics
- val excludedNodes = List("WholeStageCodegen", "Project", "SerializeFromObject")
+ val excludedNodes = List("WholeStageCodegen", "Project", "SerializeFromObject", "RowToColumnar")
implicit val formats: DefaultFormats = new DefaultFormats {
override def dateFormatter = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala
index 52abd248f3a..7a199931a08 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression}
+import org.apache.spark.sql.comet._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.hive.execution.HiveTableScanExec
@@ -35,6 +36,9 @@ abstract class DynamicPartitionPruningHiveScanSuiteBase
case s: FileSourceScanExec => s.partitionFilters.collect {
case d: DynamicPruningExpression => d.child
}
+ case s: CometScanExec => s.partitionFilters.collect {
+ case d: DynamicPruningExpression => d.child
+ }
case h: HiveTableScanExec => h.partitionPruningPred.collect {
case d: DynamicPruningExpression => d.child
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 3f8de93b330..53417076481 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -53,24 +53,55 @@ object TestHive
new SparkContext(
System.getProperty("spark.sql.test.master", "local[1]"),
"TestSQLContext",
- new SparkConf()
- .set("spark.sql.test", "")
- .set(SQLConf.CODEGEN_FALLBACK.key, "false")
- .set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString)
- .set(HiveUtils.HIVE_METASTORE_BARRIER_PREFIXES.key,
- "org.apache.spark.sql.hive.execution.PairSerDe")
- .set(WAREHOUSE_PATH.key, TestHiveContext.makeWarehouseDir().toURI.getPath)
- // SPARK-8910
- .set(UI_ENABLED, false)
- .set(config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true)
- // Hive changed the default of hive.metastore.disallow.incompatible.col.type.changes
- // from false to true. For details, see the JIRA HIVE-12320 and HIVE-17764.
- .set("spark.hadoop.hive.metastore.disallow.incompatible.col.type.changes", "false")
- // Disable ConvertToLocalRelation for better test coverage. Test cases built on
- // LocalRelation will exercise the optimization rules better by disabling it as
- // this rule may potentially block testing of other optimization rules such as
- // ConstantPropagation etc.
- .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName))) {
+ {
+ val conf = new SparkConf()
+ .set("spark.sql.test", "")
+ .set(SQLConf.CODEGEN_FALLBACK.key, "false")
+ .set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString)
+ .set(HiveUtils.HIVE_METASTORE_BARRIER_PREFIXES.key,
+ "org.apache.spark.sql.hive.execution.PairSerDe")
+ .set(WAREHOUSE_PATH.key, TestHiveContext.makeWarehouseDir().toURI.getPath)
+ // SPARK-8910
+ .set(UI_ENABLED, false)
+ .set(config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true)
+ // Hive changed the default of hive.metastore.disallow.incompatible.col.type.changes
+ // from false to true. For details, see the JIRA HIVE-12320 and HIVE-17764.
+ .set("spark.hadoop.hive.metastore.disallow.incompatible.col.type.changes", "false")
+ // Disable ConvertToLocalRelation for better test coverage. Test cases built on
+ // LocalRelation will exercise the optimization rules better by disabling it as
+ // this rule may potentially block testing of other optimization rules such as
+ // ConstantPropagation etc.
+ .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)
+
+ if (SparkSession.isCometEnabled) {
+ conf
+ .set("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions")
+ .set("spark.comet.enabled", "true")
+
+ val v = System.getenv("ENABLE_COMET_SCAN_ONLY")
+ if (v == null || !v.toBoolean) {
+ conf
+ .set("spark.comet.exec.enabled", "true")
+ .set("spark.shuffle.manager",
+ "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
+ .set("spark.comet.exec.shuffle.enabled", "true")
+ } else {
+ conf
+ .set("spark.comet.exec.enabled", "false")
+ .set("spark.comet.exec.shuffle.enabled", "false")
+ }
+
+ val a = System.getenv("ENABLE_COMET_ANSI_MODE")
+ if (a != null && a.toBoolean) {
+ conf
+ .set("spark.sql.ansi.enabled", "true")
+ .set("spark.comet.ansi.enabled", "true")
+ }
+ }
+
+ conf
+ }
+ )) {
override def conf: SQLConf = sparkSession.sessionState.conf
}