blob: 3c81e01a3c52fe07ce0e0cebe491d8ab10dbb2dd [file]
diff --git a/pom.xml b/pom.xml
index edd2ad57880..77a975ea48f 100644
--- a/pom.xml
+++ b/pom.xml
@@ -152,6 +152,8 @@
-->
<ivy.version>2.5.1</ivy.version>
<oro.version>2.0.8</oro.version>
+ <spark.version.short>3.5</spark.version.short>
+ <comet.version>0.14.0-SNAPSHOT</comet.version>
<!--
If you changes codahale.metrics.version, you also need to change
the link to metrics.dropwizard.io in docs/monitoring.md.
@@ -2840,6 +2842,25 @@
<artifactId>okio</artifactId>
<version>${okio.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.datafusion</groupId>
+ <artifactId>comet-spark-spark${spark.version.short}_${scala.binary.version}</artifactId>
+ <version>${comet.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_${scala.binary.version}</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
</dependencies>
</dependencyManagement>
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index bc00c448b80..82068d7a2eb 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -77,6 +77,10 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-tags_${scala.binary.version}</artifactId>
</dependency>
+ <dependency>
+ <groupId>org.apache.datafusion</groupId>
+ <artifactId>comet-spark-spark${spark.version.short}_${scala.binary.version}</artifactId>
+ </dependency>
<!--
This spark-tags test-dep is needed even though it isn't used in this module, otherwise testing-cmds that exclude
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 27ae10b3d59..78e69902dfd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -1353,6 +1353,14 @@ object SparkSession extends Logging {
}
}
+ private def loadCometExtension(sparkContext: SparkContext): Seq[String] = {
+ if (sparkContext.getConf.getBoolean("spark.comet.enabled", isCometEnabled)) {
+ Seq("org.apache.comet.CometSparkSessionExtensions")
+ } else {
+ Seq.empty
+ }
+ }
+
/**
* Initialize extensions specified in [[StaticSQLConf]]. The classes will be applied to the
* extensions passed into this function.
@@ -1362,6 +1370,7 @@ object SparkSession extends Logging {
extensions: SparkSessionExtensions): SparkSessionExtensions = {
val extensionConfClassNames = sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
.getOrElse(Seq.empty)
+ val extensionClassNames = extensionConfClassNames ++ loadCometExtension(sparkContext)
extensionConfClassNames.foreach { extensionConfClassName =>
try {
val extensionConfClass = Utils.classForName(extensionConfClassName)
@@ -1396,4 +1405,12 @@ object SparkSession extends Logging {
}
}
}
+
+ /**
+ * Whether Comet extension is enabled
+ */
+ def isCometEnabled: Boolean = {
+ val v = System.getenv("ENABLE_COMET")
+ v == null || v.toBoolean
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index db587dd9868..aac7295a53d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.comet.CometScanExec
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
@@ -67,6 +68,7 @@ private[execution] object SparkPlanInfo {
// dump the file scan metadata (e.g file path) to event log
val metadata = plan match {
case fileScan: FileSourceScanExec => fileScan.metadata
+ case cometScan: CometScanExec => cometScan.metadata
case _ => Map[String, String]()
}
new SparkPlanInfo(
diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql b/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql
index 7aef901da4f..f3d6e18926d 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql
@@ -2,3 +2,4 @@
--SET spark.sql.adaptive.enabled=true
--SET spark.sql.maxMetadataStringLength = 500
+--SET spark.comet.enabled = false
diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain-cbo.sql b/sql/core/src/test/resources/sql-tests/inputs/explain-cbo.sql
index eeb2180f7a5..afd1b5ec289 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/explain-cbo.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/explain-cbo.sql
@@ -1,5 +1,6 @@
--SET spark.sql.cbo.enabled=true
--SET spark.sql.maxMetadataStringLength = 500
+--SET spark.comet.enabled = false
CREATE TABLE explain_temp1(a INT, b INT) USING PARQUET;
CREATE TABLE explain_temp2(c INT, d INT) USING PARQUET;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain.sql b/sql/core/src/test/resources/sql-tests/inputs/explain.sql
index 698ca009b4f..57d774a3617 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/explain.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/explain.sql
@@ -1,6 +1,7 @@
--SET spark.sql.codegen.wholeStage = true
--SET spark.sql.adaptive.enabled = false
--SET spark.sql.maxMetadataStringLength = 500
+--SET spark.comet.enabled = false
-- Test tables
CREATE table explain_temp1 (key int, val int) USING PARQUET;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql
index 1152d77da0c..f77493f690b 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql
@@ -7,6 +7,9 @@
-- avoid bit-exact output here because operations may not be bit-exact.
-- SET extra_float_digits = 0;
+-- Disable Comet exec due to floating point precision difference
+--SET spark.comet.exec.enabled = false
+
-- Test aggregate operator with codegen on and off.
--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql
index 41fd4de2a09..44cd244d3b0 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql
@@ -5,6 +5,9 @@
-- AGGREGATES [Part 3]
-- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/aggregates.sql#L352-L605
+-- Disable Comet exec due to floating point precision difference
+--SET spark.comet.exec.enabled = false
+
-- Test aggregate operator with codegen on and off.
--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
index 3a409eea348..38fed024c98 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
@@ -69,6 +69,8 @@ SELECT '' AS one, i.* FROM INT4_TBL i WHERE (i.f1 % smallint('2')) = smallint('1
-- any evens
SELECT '' AS three, i.* FROM INT4_TBL i WHERE (i.f1 % int('2')) = smallint('0');
+-- https://github.com/apache/datafusion-comet/issues/2215
+--SET spark.comet.exec.enabled=false
-- [SPARK-28024] Incorrect value when out of range
SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
index fac23b4a26f..2b73732c33f 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
@@ -1,6 +1,10 @@
--
-- Portions Copyright (c) 1996-2019, PostgreSQL Global Development Group
--
+
+-- Disable Comet exec due to floating point precision difference
+--SET spark.comet.exec.enabled = false
+
--
-- INT8
-- Test int8 64-bit integers.
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/select_having.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/select_having.sql
index 0efe0877e9b..423d3b3d76d 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/select_having.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/select_having.sql
@@ -1,6 +1,10 @@
--
-- Portions Copyright (c) 1996-2019, PostgreSQL Global Development Group
--
+
+-- Disable Comet exec due to floating point precision difference
+--SET spark.comet.exec.enabled = false
+
--
-- SELECT_HAVING
-- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/select_having.sql
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index e5494726695..00937f025c2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants
import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, RDDScanExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEPropagateEmptyRelation}
import org.apache.spark.sql.execution.columnar._
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -519,7 +519,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
df.collect()
}
assert(
- collect(df.queryExecution.executedPlan) { case e: ShuffleExchangeExec => e }.size == expected)
+ collect(df.queryExecution.executedPlan) {
+ case _: ShuffleExchangeLike => 1 }.size == expected)
}
test("A cached table preserves the partitioning and ordering of its cached SparkPlan") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 9e8d77c53f3..855e3ada7d1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -790,7 +790,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
}
}
- test("input_file_name, input_file_block_start, input_file_block_length - FileScanRDD") {
+ test("input_file_name, input_file_block_start, input_file_block_length - FileScanRDD",
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3312")) {
withTempPath { dir =>
val data = sparkContext.parallelize(0 to 10).toDF("id")
data.write.parquet(dir.getCanonicalPath)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 6f3090d8908..c08a60fb0c2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Expand
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -793,7 +793,7 @@ class DataFrameAggregateSuite extends QueryTest
assert(objHashAggPlans.nonEmpty)
val exchangePlans = collect(aggPlan) {
- case shuffle: ShuffleExchangeExec => shuffle
+ case shuffle: ShuffleExchangeLike => shuffle
}
assert(exchangePlans.length == 1)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 56e9520fdab..917932336df 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -435,7 +435,9 @@ class DataFrameJoinSuite extends QueryTest
withTempDatabase { dbName =>
withTable(table1Name, table2Name) {
- withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ withSQLConf(
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ "spark.comet.enabled" -> "false") {
spark.range(50).write.saveAsTable(s"$dbName.$table1Name")
spark.range(100).write.saveAsTable(s"$dbName.$table2Name")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 7ee18df3756..d09f70e5d99 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -40,11 +40,12 @@ import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LocalRelation, LogicalPlan, OneRowRelation, Statistics}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.comet.CometBroadcastExchangeExec
import org.apache.spark.sql.connector.FakeV2Provider
import org.apache.spark.sql.execution.{FilterExec, LogicalRDD, QueryExecution, SortExec, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.expressions.{Aggregator, Window}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -2006,7 +2007,7 @@ class DataFrameSuite extends QueryTest
fail("Should not have back to back Aggregates")
}
atFirstAgg = true
- case e: ShuffleExchangeExec => atFirstAgg = false
+ case e: ShuffleExchangeLike => atFirstAgg = false
case _ =>
}
}
@@ -2330,7 +2331,7 @@ class DataFrameSuite extends QueryTest
checkAnswer(join, df)
assert(
collect(join.queryExecution.executedPlan) {
- case e: ShuffleExchangeExec => true }.size === 1)
+ case _: ShuffleExchangeLike => true }.size === 1)
assert(
collect(join.queryExecution.executedPlan) { case e: ReusedExchangeExec => true }.size === 1)
val broadcasted = broadcast(join)
@@ -2338,10 +2339,12 @@ class DataFrameSuite extends QueryTest
checkAnswer(join2, df)
assert(
collect(join2.queryExecution.executedPlan) {
- case e: ShuffleExchangeExec => true }.size == 1)
+ case _: ShuffleExchangeLike => true }.size == 1)
assert(
collect(join2.queryExecution.executedPlan) {
- case e: BroadcastExchangeExec => true }.size === 1)
+ case e: BroadcastExchangeExec => true
+ case _: CometBroadcastExchangeExec => true
+ }.size === 1)
assert(
collect(join2.queryExecution.executedPlan) { case e: ReusedExchangeExec => true }.size == 4)
}
@@ -2901,7 +2904,7 @@ class DataFrameSuite extends QueryTest
// Assert that no extra shuffle introduced by cogroup.
val exchanges = collect(df3.queryExecution.executedPlan) {
- case h: ShuffleExchangeExec => h
+ case h: ShuffleExchangeLike => h
}
assert(exchanges.size == 2)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index a1d5d579338..c201d39cc78 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression
import org.apache.spark.sql.catalyst.optimizer.TransposeWindow
import org.apache.spark.sql.catalyst.plans.logical.{Window => LogicalWindow}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, Exchange, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, Exchange, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
import org.apache.spark.sql.functions._
@@ -1187,10 +1188,12 @@ class DataFrameWindowFunctionsSuite extends QueryTest
}
def isShuffleExecByRequirement(
- plan: ShuffleExchangeExec,
+ plan: ShuffleExchangeLike,
desiredClusterColumns: Seq[String]): Boolean = plan match {
case ShuffleExchangeExec(op: HashPartitioning, _, ENSURE_REQUIREMENTS, _) =>
partitionExpressionsColumns(op.expressions) === desiredClusterColumns
+ case CometShuffleExchangeExec(op: HashPartitioning, _, _, ENSURE_REQUIREMENTS, _, _) =>
+ partitionExpressionsColumns(op.expressions) === desiredClusterColumns
case _ => false
}
@@ -1213,7 +1216,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest
val shuffleByRequirement = windowed.queryExecution.executedPlan.exists {
case w: WindowExec =>
w.child.exists {
- case s: ShuffleExchangeExec => isShuffleExecByRequirement(s, Seq("key1", "key2"))
+ case s: ShuffleExchangeLike => isShuffleExecByRequirement(s, Seq("key1", "key2"))
case _ => false
}
case _ => false
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index c4fb4fa943c..a04b23870a8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
@@ -2288,7 +2288,7 @@ class DatasetSuite extends QueryTest
// Assert that no extra shuffle introduced by cogroup.
val exchanges = collect(df3.queryExecution.executedPlan) {
- case h: ShuffleExchangeExec => h
+ case h: ShuffleExchangeLike => h
}
assert(exchanges.size == 2)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
index f33432ddb6f..42eb9fd1cb7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
@@ -22,6 +22,7 @@ import org.scalatest.GivenWhenThen
import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression}
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._
import org.apache.spark.sql.catalyst.plans.ExistenceJoin
+import org.apache.spark.sql.comet.CometScanExec
import org.apache.spark.sql.connector.catalog.{InMemoryTableCatalog, InMemoryTableWithV2FilterCatalog}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive._
@@ -262,6 +263,9 @@ abstract class DynamicPartitionPruningSuiteBase
case s: BatchScanExec => s.runtimeFilters.collect {
case d: DynamicPruningExpression => d.child
}
+ case s: CometScanExec => s.partitionFilters.collect {
+ case d: DynamicPruningExpression => d.child
+ }
case _ => Nil
}
}
@@ -1027,7 +1031,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}
- test("avoid reordering broadcast join keys to match input hash partitioning") {
+ test("avoid reordering broadcast join keys to match input hash partitioning",
+ IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/1839")) {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withTable("large", "dimTwo", "dimThree") {
@@ -1215,7 +1220,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
test("SPARK-32509: Unused Dynamic Pruning filter shouldn't affect " +
- "canonicalization and exchange reuse") {
+ "canonicalization and exchange reuse",
+ IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/1839")) {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
val df = sql(
@@ -1423,7 +1429,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}
- test("SPARK-34637: DPP side broadcast query stage is created firstly") {
+ test("SPARK-34637: DPP side broadcast query stage is created firstly",
+ IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/1839")) {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
val df = sql(
""" WITH v as (
@@ -1698,7 +1705,8 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat
* Check the static scan metrics with and without DPP
*/
test("static scan metrics",
- DisableAdaptiveExecution("DPP in AQE must reuse broadcast")) {
+ DisableAdaptiveExecution("DPP in AQE must reuse broadcast"),
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3313")) {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false",
SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") {
@@ -1729,6 +1737,8 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat
case s: BatchScanExec =>
// we use f1 col for v2 tables due to schema pruning
s.output.exists(_.exists(_.argString(maxFields = 100).contains("f1")))
+ case s: CometScanExec =>
+ s.output.exists(_.exists(_.argString(maxFields = 100).contains("fid")))
case _ => false
}
assert(scanOption.isDefined)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
index a206e97c353..79813d8e259 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
@@ -280,7 +280,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
}
}
- test("explain formatted - check presence of subquery in case of DPP") {
+ test("explain formatted - check presence of subquery in case of DPP",
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3313")) {
withTable("df1", "df2") {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false",
@@ -467,7 +468,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
}
}
- test("Explain formatted output for scan operator for datasource V2") {
+ test("Explain formatted output for scan operator for datasource V2",
+ IgnoreComet("Comet explain output is different")) {
withTempDir { dir =>
Seq("parquet", "orc", "csv", "json").foreach { fmt =>
val basePath = dir.getCanonicalPath + "/" + fmt
@@ -545,7 +547,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
}
}
-class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuite {
+// Ignored when Comet is enabled. Comet changes expected query plans.
+class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuite
+ with IgnoreCometSuite {
import testImplicits._
test("SPARK-35884: Explain Formatted") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index 93275487f29..510e3087e0f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -23,6 +23,7 @@ import java.nio.file.{Files, StandardOpenOption}
import scala.collection.mutable
+import org.apache.comet.CometConf
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{LocalFileSystem, Path}
@@ -33,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterTha
import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt}
import org.apache.spark.sql.catalyst.plans.logical.Filter
import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.comet.{CometBatchScanExec, CometNativeScanExec, CometScanExec, CometSortMergeJoinExec}
import org.apache.spark.sql.execution.{FileSourceScanLike, SimpleMode}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.FilePartition
@@ -250,6 +252,8 @@ class FileBasedDataSourceSuite extends QueryTest
case "" => "_LEGACY_ERROR_TEMP_2062"
case _ => "_LEGACY_ERROR_TEMP_2055"
}
+ // native_datafusion Parquet scan cannot throw a SparkFileNotFoundException
+ assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION)
checkErrorMatchPVals(
exception = intercept[SparkException] {
testIgnoreMissingFiles(options)
@@ -639,7 +643,8 @@ class FileBasedDataSourceSuite extends QueryTest
}
Seq("parquet", "orc").foreach { format =>
- test(s"Spark native readers should respect spark.sql.caseSensitive - ${format}") {
+ test(s"Spark native readers should respect spark.sql.caseSensitive - ${format}",
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3311")) {
withTempDir { dir =>
val tableName = s"spark_25132_${format}_native"
val tableDir = dir.getCanonicalPath + s"/$tableName"
@@ -955,6 +960,7 @@ class FileBasedDataSourceSuite extends QueryTest
assert(bJoinExec.isEmpty)
val smJoinExec = collect(joinedDF.queryExecution.executedPlan) {
case smJoin: SortMergeJoinExec => smJoin
+ case smJoin: CometSortMergeJoinExec => smJoin
}
assert(smJoinExec.nonEmpty)
}
@@ -1015,6 +1021,7 @@ class FileBasedDataSourceSuite extends QueryTest
val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: FileScan, _, _, _, _) => f
+ case CometBatchScanExec(BatchScanExec(_, f: FileScan, _, _, _, _), _, _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.nonEmpty)
@@ -1056,6 +1063,7 @@ class FileBasedDataSourceSuite extends QueryTest
val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: FileScan, _, _, _, _) => f
+ case CometBatchScanExec(BatchScanExec(_, f: FileScan, _, _, _, _), _, _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.isEmpty)
@@ -1240,6 +1248,9 @@ class FileBasedDataSourceSuite extends QueryTest
val filters = df.queryExecution.executedPlan.collect {
case f: FileSourceScanLike => f.dataFilters
case b: BatchScanExec => b.scan.asInstanceOf[FileScan].dataFilters
+ case b: CometScanExec => b.dataFilters
+ case b: CometNativeScanExec => b.dataFilters
+ case b: CometBatchScanExec => b.scan.asInstanceOf[FileScan].dataFilters
}.flatten
assert(filters.contains(GreaterThan(scan.logicalPlan.output.head, Literal(5L))))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IgnoreComet.scala b/sql/core/src/test/scala/org/apache/spark/sql/IgnoreComet.scala
new file mode 100644
index 00000000000..1ee842b6f62
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IgnoreComet.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.scalactic.source.Position
+import org.scalatest.Tag
+
+import org.apache.spark.sql.test.SQLTestUtils
+
+/**
+ * Tests with this tag will be ignored when Comet is enabled (e.g., via `ENABLE_COMET`).
+ */
+case class IgnoreComet(reason: String) extends Tag("DisableComet")
+case class IgnoreCometNativeIcebergCompat(reason: String) extends Tag("DisableComet")
+case class IgnoreCometNativeDataFusion(reason: String) extends Tag("DisableComet")
+case class IgnoreCometNativeScan(reason: String) extends Tag("DisableComet")
+
+/**
+ * Helper trait that disables Comet for all tests regardless of default config values.
+ */
+trait IgnoreCometSuite extends SQLTestUtils {
+ override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
+ pos: Position): Unit = {
+ if (isCometEnabled) {
+ ignore(testName + " (disabled when Comet is on)", testTags: _*)(testFun)
+ } else {
+ super.test(testName, testTags: _*)(testFun)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
index 7af826583bd..3c3def1eb67 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.comet.{CometHashJoinExec, CometSortMergeJoinExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.internal.SQLConf
@@ -362,6 +363,7 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP
val executedPlan = df.queryExecution.executedPlan
val shuffleHashJoins = collect(executedPlan) {
case s: ShuffledHashJoinExec => s
+ case c: CometHashJoinExec => c.originalPlan.asInstanceOf[ShuffledHashJoinExec]
}
assert(shuffleHashJoins.size == 1)
assert(shuffleHashJoins.head.buildSide == buildSide)
@@ -371,6 +373,7 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP
val executedPlan = df.queryExecution.executedPlan
val shuffleMergeJoins = collect(executedPlan) {
case s: SortMergeJoinExec => s
+ case c: CometSortMergeJoinExec => c
}
assert(shuffleMergeJoins.size == 1)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 44c8cb92fc3..f098beeca26 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -31,7 +31,8 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrder}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, Join, JoinHint, NO_BROADCAST_AND_REPLICATION}
-import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.comet._
+import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, FilterExec, InputAdapter, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins._
@@ -802,7 +803,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
}
}
- test("test SortMergeJoin (with spill)") {
+ test("test SortMergeJoin (with spill)",
+ IgnoreComet("TODO: Comet SMJ doesn't support spill yet")) {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1",
SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "0",
SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD.key -> "1") {
@@ -928,10 +930,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
val physical = df.queryExecution.sparkPlan
val physicalJoins = physical.collect {
case j: SortMergeJoinExec => j
+ case j: CometSortMergeJoinExec => j.originalPlan.asInstanceOf[SortMergeJoinExec]
}
val executed = df.queryExecution.executedPlan
val executedJoins = collect(executed) {
case j: SortMergeJoinExec => j
+ case j: CometSortMergeJoinExec => j.originalPlan.asInstanceOf[SortMergeJoinExec]
}
// This only applies to the above tested queries, in which a child SortMergeJoin always
// contains the SortOrder required by its parent SortMergeJoin. Thus, SortExec should never
@@ -1177,9 +1181,11 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
val plan = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", joinType)
.groupBy($"k1").count()
.queryExecution.executedPlan
- assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size === 1)
+ assert(collect(plan) {
+ case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size === 1)
// No extra shuffle before aggregate
- assert(collect(plan) { case _: ShuffleExchangeExec => true }.size === 2)
+ assert(collect(plan) {
+ case _: ShuffleExchangeLike => true }.size === 2)
})
}
@@ -1196,10 +1202,11 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
.join(df4.hint("SHUFFLE_MERGE"), $"k1" === $"k4", joinType)
.queryExecution
.executedPlan
- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 2)
+ assert(collect(plan) {
+ case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size === 2)
assert(collect(plan) { case _: BroadcastHashJoinExec => true }.size === 1)
// No extra sort before last sort merge join
- assert(collect(plan) { case _: SortExec => true }.size === 3)
+ assert(collect(plan) { case _: SortExec | _: CometSortExec => true }.size === 3)
})
// Test shuffled hash join
@@ -1209,10 +1216,13 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
.join(df4.hint("SHUFFLE_MERGE"), $"k1" === $"k4", joinType)
.queryExecution
.executedPlan
- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 2)
- assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size === 1)
+ assert(collect(plan) {
+ case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size === 2)
+ assert(collect(plan) {
+ case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size === 1)
// No extra sort before last sort merge join
- assert(collect(plan) { case _: SortExec => true }.size === 3)
+ assert(collect(plan) {
+ case _: SortExec | _: CometSortExec => true }.size === 3)
})
}
@@ -1303,12 +1313,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
inputDFs.foreach { case (df1, df2, joinExprs) =>
val smjDF = df1.join(df2.hint("SHUFFLE_MERGE"), joinExprs, "full")
assert(collect(smjDF.queryExecution.executedPlan) {
- case _: SortMergeJoinExec => true }.size === 1)
+ case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size === 1)
val smjResult = smjDF.collect()
val shjDF = df1.join(df2.hint("SHUFFLE_HASH"), joinExprs, "full")
assert(collect(shjDF.queryExecution.executedPlan) {
- case _: ShuffledHashJoinExec => true }.size === 1)
+ case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size === 1)
// Same result between shuffled hash join and sort merge join
checkAnswer(shjDF, smjResult)
}
@@ -1367,12 +1377,14 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
val smjDF = df1.hint("SHUFFLE_MERGE").join(df2, joinExprs, "leftouter")
assert(collect(smjDF.queryExecution.executedPlan) {
case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
}.size === 1)
val smjResult = smjDF.collect()
val shjDF = df1.hint("SHUFFLE_HASH").join(df2, joinExprs, "leftouter")
assert(collect(shjDF.queryExecution.executedPlan) {
case _: ShuffledHashJoinExec => true
+ case _: CometHashJoinExec => true
}.size === 1)
// Same result between shuffled hash join and sort merge join
checkAnswer(shjDF, smjResult)
@@ -1383,12 +1395,14 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
val smjDF = df2.join(df1.hint("SHUFFLE_MERGE"), joinExprs, "rightouter")
assert(collect(smjDF.queryExecution.executedPlan) {
case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
}.size === 1)
val smjResult = smjDF.collect()
val shjDF = df2.join(df1.hint("SHUFFLE_HASH"), joinExprs, "rightouter")
assert(collect(shjDF.queryExecution.executedPlan) {
case _: ShuffledHashJoinExec => true
+ case _: CometHashJoinExec => true
}.size === 1)
// Same result between shuffled hash join and sort merge join
checkAnswer(shjDF, smjResult)
@@ -1432,13 +1446,20 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
assert(shjCodegenDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
case WholeStageCodegenExec(ProjectExec(_, _ : ShuffledHashJoinExec)) => true
+ case WholeStageCodegenExec(ColumnarToRowExec(InputAdapter(_: CometHashJoinExec))) =>
+ true
+ case WholeStageCodegenExec(ColumnarToRowExec(
+ InputAdapter(CometProjectExec(_, _, _, _, _: CometHashJoinExec, _)))) => true
+ case _: CometHashJoinExec => true
}.size === 1)
checkAnswer(shjCodegenDF, Seq.empty)
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
val shjNonCodegenDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", joinType)
assert(shjNonCodegenDF.queryExecution.executedPlan.collect {
- case _: ShuffledHashJoinExec => true }.size === 1)
+ case _: ShuffledHashJoinExec => true
+ case _: CometHashJoinExec => true
+ }.size === 1)
checkAnswer(shjNonCodegenDF, Seq.empty)
}
}
@@ -1486,7 +1507,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
val plan = sql(getAggQuery(selectExpr, joinType)).queryExecution.executedPlan
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
// Have shuffle before aggregation
- assert(collect(plan) { case _: ShuffleExchangeExec => true }.size === 1)
+ assert(collect(plan) {
+ case _: ShuffleExchangeLike => true }.size === 1)
}
def getJoinQuery(selectExpr: String, joinType: String): String = {
@@ -1515,9 +1537,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
}
val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 3)
+ assert(collect(plan) {
+ case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
+ }.size === 3)
// No extra sort on left side before last sort merge join
- assert(collect(plan) { case _: SortExec => true }.size === 5)
+ assert(collect(plan) { case _: SortExec | _: CometSortExec => true }.size === 5)
}
// Test output ordering is not preserved
@@ -1526,9 +1551,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
val selectExpr = "/*+ BROADCAST(left_t) */ k1 as k0"
val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 3)
+ assert(collect(plan) {
+ case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
+ }.size === 3)
// Have sort on left side before last sort merge join
- assert(collect(plan) { case _: SortExec => true }.size === 6)
+ assert(collect(plan) { case _: SortExec | _: CometSortExec => true }.size === 6)
}
// Test singe partition
@@ -1538,7 +1566,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
|FROM range(0, 10, 1, 1) t1 FULL OUTER JOIN range(0, 10, 1, 1) t2
|""".stripMargin)
val plan = fullJoinDF.queryExecution.executedPlan
- assert(collect(plan) { case _: ShuffleExchangeExec => true}.size == 1)
+ assert(collect(plan) {
+ case _: ShuffleExchangeLike => true}.size == 1)
checkAnswer(fullJoinDF, Row(100))
}
}
@@ -1611,6 +1640,9 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
Seq(semiJoinDF, antiJoinDF).foreach { df =>
assert(collect(df.queryExecution.executedPlan) {
case j: ShuffledHashJoinExec if j.ignoreDuplicatedKey == ignoreDuplicatedKey => true
+ case j: CometHashJoinExec
+ if j.originalPlan.asInstanceOf[ShuffledHashJoinExec].ignoreDuplicatedKey ==
+ ignoreDuplicatedKey => true
}.size == 1)
}
}
@@ -1655,14 +1687,20 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
test("SPARK-43113: Full outer join with duplicate stream-side references in condition (SMJ)") {
def check(plan: SparkPlan): Unit = {
- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 1)
+ assert(collect(plan) {
+ case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
+ }.size === 1)
}
dupStreamSideColTest("MERGE", check)
}
test("SPARK-43113: Full outer join with duplicate stream-side references in condition (SHJ)") {
def check(plan: SparkPlan): Unit = {
- assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size === 1)
+ assert(collect(plan) {
+ case _: ShuffledHashJoinExec => true
+ case _: CometHashJoinExec => true
+ }.size === 1)
}
dupStreamSideColTest("SHUFFLE_HASH", check)
}
@@ -1798,7 +1836,8 @@ class ThreadLeakInSortMergeJoinSuite
sparkConf.set(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD, 20))
}
- test("SPARK-47146: thread leak when doing SortMergeJoin (with spill)") {
+ test("SPARK-47146: thread leak when doing SortMergeJoin (with spill)",
+ IgnoreComet("Comet does not support spilling")) {
withSQLConf(
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
index c26757c9cff..d55775f09d7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
@@ -69,7 +69,7 @@ import org.apache.spark.tags.ExtendedSQLTest
* }}}
*/
// scalastyle:on line.size.limit
-trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite {
+trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite with IgnoreCometSuite {
protected val baseResourcePath = {
// use the same way as `SQLQueryTestSuite` to get the resource path
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 3cf2bfd17ab..49728c35c42 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1521,7 +1521,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
checkAnswer(sql("select -0.001"), Row(BigDecimal("-0.001")))
}
- test("external sorting updates peak execution memory") {
+ test("external sorting updates peak execution memory",
+ IgnoreComet("TODO: native CometSort does not update peak execution memory")) {
AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") {
sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect()
}
@@ -4459,7 +4460,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
test("SPARK-39166: Query context of binary arithmetic should be serialized to executors" +
- " when WSCG is off") {
+ " when WSCG is off",
+ IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) {
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
SQLConf.ANSI_ENABLED.key -> "true") {
withTable("t") {
@@ -4480,7 +4482,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
test("SPARK-39175: Query context of Cast should be serialized to executors" +
- " when WSCG is off") {
+ " when WSCG is off",
+ IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) {
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
SQLConf.ANSI_ENABLED.key -> "true") {
withTable("t") {
@@ -4497,14 +4500,19 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
val msg = intercept[SparkException] {
sql(query).collect()
}.getMessage
- assert(msg.contains(query))
+ if (!isCometEnabled) {
+ // Comet's error message does not include the original SQL query
+ // https://github.com/apache/datafusion-comet/issues/2215
+ assert(msg.contains(query))
+ }
}
}
}
}
test("SPARK-39190,SPARK-39208,SPARK-39210: Query context of decimal overflow error should " +
- "be serialized to executors when WSCG is off") {
+ "be serialized to executors when WSCG is off",
+ IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) {
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
SQLConf.ANSI_ENABLED.key -> "true") {
withTable("t") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index fa1a64460fc..1d2e215d6a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import org.apache.comet.CometConf
+
import org.apache.spark.{SPARK_DOC_ROOT, SparkIllegalArgumentException, SparkRuntimeException}
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.execution.FormattedMode
@@ -178,29 +180,31 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
}
test("string regex_replace / regex_extract") {
- val df = Seq(
- ("100-200", "(\\d+)-(\\d+)", "300"),
- ("100-200", "(\\d+)-(\\d+)", "400"),
- ("100-200", "(\\d+)", "400")).toDF("a", "b", "c")
+ withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") {
+ val df = Seq(
+ ("100-200", "(\\d+)-(\\d+)", "300"),
+ ("100-200", "(\\d+)-(\\d+)", "400"),
+ ("100-200", "(\\d+)", "400")).toDF("a", "b", "c")
- checkAnswer(
- df.select(
- regexp_replace($"a", "(\\d+)", "num"),
- regexp_replace($"a", $"b", $"c"),
- regexp_extract($"a", "(\\d+)-(\\d+)", 1)),
- Row("num-num", "300", "100") :: Row("num-num", "400", "100") ::
- Row("num-num", "400-400", "100") :: Nil)
-
- // for testing the mutable state of the expression in code gen.
- // This is a hack way to enable the codegen, thus the codegen is enable by default,
- // it will still use the interpretProjection if projection followed by a LocalRelation,
- // hence we add a filter operator.
- // See the optimizer rule `ConvertToLocalRelation`
- checkAnswer(
- df.filter("isnotnull(a)").selectExpr(
- "regexp_replace(a, b, c)",
- "regexp_extract(a, b, 1)"),
- Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil)
+ checkAnswer(
+ df.select(
+ regexp_replace($"a", "(\\d+)", "num"),
+ regexp_replace($"a", $"b", $"c"),
+ regexp_extract($"a", "(\\d+)-(\\d+)", 1)),
+ Row("num-num", "300", "100") :: Row("num-num", "400", "100") ::
+ Row("num-num", "400-400", "100") :: Nil)
+
+ // for testing the mutable state of the expression in code gen.
+ // This is a hack way to enable the codegen, thus the codegen is enable by default,
+ // it will still use the interpretProjection if projection followed by a LocalRelation,
+ // hence we add a filter operator.
+ // See the optimizer rule `ConvertToLocalRelation`
+ checkAnswer(
+ df.filter("isnotnull(a)").selectExpr(
+ "regexp_replace(a, b, c)",
+ "regexp_extract(a, b, 1)"),
+ Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil)
+ }
}
test("non-matching optional group") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 04702201f82..5ee11f83ecf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -22,10 +22,11 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Join, LogicalPlan, Project, Sort, Union}
+import org.apache.spark.sql.comet.CometScanExec
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
import org.apache.spark.sql.execution.datasources.FileScanRDD
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -1599,6 +1600,12 @@ class SubquerySuite extends QueryTest
fs.inputRDDs().forall(
_.asInstanceOf[FileScanRDD].filePartitions.forall(
_.files.forall(_.urlEncodedPath.contains("p=0"))))
+ case WholeStageCodegenExec(ColumnarToRowExec(InputAdapter(
+ fs @ CometScanExec(_, _, _, _, partitionFilters, _, _, _, _, _, _)))) =>
+ partitionFilters.exists(ExecSubqueryExpression.hasSubquery) &&
+ fs.inputRDDs().forall(
+ _.asInstanceOf[FileScanRDD].filePartitions.forall(
+ _.files.forall(_.urlEncodedPath.contains("p=0"))))
case _ => false
})
}
@@ -2164,7 +2171,7 @@ class SubquerySuite extends QueryTest
df.collect()
val exchanges = collect(df.queryExecution.executedPlan) {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}
assert(exchanges.size === 1)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 9f8e979e3fb..3bc9dab8023 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -87,7 +87,8 @@ class UDFSuite extends QueryTest with SharedSparkSession {
spark.catalog.dropTempView("tmp_table")
}
- test("SPARK-8005 input_file_name") {
+ test("SPARK-8005 input_file_name",
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3312")) {
withTempPath { dir =>
val data = sparkContext.parallelize(0 to 10, 2).toDF("id")
data.write.parquet(dir.getCanonicalPath)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
index d269290e616..13726a31e07 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
@@ -24,6 +24,7 @@ import test.org.apache.spark.sql.connector._
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.comet.CometSortExec
import org.apache.spark.sql.connector.catalog.{PartitionInternalRow, SupportsRead, Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, NamedReference, NullOrdering, SortDirection, SortOrder, Transform}
@@ -34,7 +35,7 @@ import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning,
import org.apache.spark.sql.execution.SortExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation}
-import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
@@ -269,13 +270,13 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
val groupByColJ = df.groupBy($"j").agg(sum($"i"))
checkAnswer(groupByColJ, Seq(Row(2, 8), Row(4, 2), Row(6, 5)))
assert(collectFirst(groupByColJ.queryExecution.executedPlan) {
- case e: ShuffleExchangeExec => e
+ case e: ShuffleExchangeLike => e
}.isDefined)
val groupByIPlusJ = df.groupBy($"i" + $"j").agg(count("*"))
checkAnswer(groupByIPlusJ, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1)))
assert(collectFirst(groupByIPlusJ.queryExecution.executedPlan) {
- case e: ShuffleExchangeExec => e
+ case e: ShuffleExchangeLike => e
}.isDefined)
}
}
@@ -335,10 +336,11 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
val (shuffleExpected, sortExpected) = groupByExpects
assert(collectFirst(groupBy.queryExecution.executedPlan) {
- case e: ShuffleExchangeExec => e
+ case e: ShuffleExchangeLike => e
}.isDefined === shuffleExpected)
assert(collectFirst(groupBy.queryExecution.executedPlan) {
case e: SortExec => e
+ case c: CometSortExec => c
}.isDefined === sortExpected)
}
@@ -353,10 +355,11 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
val (shuffleExpected, sortExpected) = windowFuncExpects
assert(collectFirst(windowPartByColIOrderByColJ.queryExecution.executedPlan) {
- case e: ShuffleExchangeExec => e
+ case e: ShuffleExchangeLike => e
}.isDefined === shuffleExpected)
assert(collectFirst(windowPartByColIOrderByColJ.queryExecution.executedPlan) {
case e: SortExec => e
+ case c: CometSortExec => c
}.isDefined === sortExpected)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala
index cfc8b2cc845..b7c234e1437 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql.connector
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkConf
-import org.apache.spark.sql.{AnalysisException, QueryTest}
+import org.apache.spark.sql.{AnalysisException, IgnoreCometNativeDataFusion, QueryTest}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec}
import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
@@ -152,7 +153,8 @@ class FileDataSourceV2FallBackSuite extends QueryTest with SharedSparkSession {
}
}
- test("Fallback Parquet V2 to V1") {
+ test("Fallback Parquet V2 to V1",
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3315")) {
Seq("parquet", classOf[ParquetDataSourceV2].getCanonicalName).foreach { format =>
withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> format) {
val commands = ArrayBuffer.empty[(String, LogicalPlan)]
@@ -184,7 +186,11 @@ class FileDataSourceV2FallBackSuite extends QueryTest with SharedSparkSession {
val df = spark.read.format(format).load(path.getCanonicalPath)
checkAnswer(df, inputData.toDF())
assert(
- df.queryExecution.executedPlan.exists(_.isInstanceOf[FileSourceScanExec]))
+ df.queryExecution.executedPlan.exists {
+ case _: FileSourceScanExec | _: CometScanExec | _: CometNativeScanExec => true
+ case _ => false
+ }
+ )
}
} finally {
spark.listenerManager.unregister(listener)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 71e030f535e..d5ae6cbf3d5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, TransformExpression}
import org.apache.spark.sql.catalyst.plans.physical
+import org.apache.spark.sql.comet.CometSortMergeJoinExec
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog
import org.apache.spark.sql.connector.catalog.functions._
@@ -31,7 +32,7 @@ import org.apache.spark.sql.connector.expressions.Expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf._
@@ -282,13 +283,14 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
Row("bbb", 20, 250.0), Row("bbb", 20, 350.0), Row("ccc", 30, 400.50)))
}
- private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeExec] = {
+ private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeLike] = {
// here we skip collecting shuffle operators that are not associated with SMJ
collect(plan) {
case s: SortMergeJoinExec => s
+ case c: CometSortMergeJoinExec => c.originalPlan
}.flatMap(smj =>
collect(smj) {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
})
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
index 12007cd94cd..07020f201fb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
@@ -21,7 +21,7 @@ package org.apache.spark.sql.connector
import java.sql.Date
import java.util.Collections
-import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, Row}
+import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, IgnoreCometSuite, Row}
import org.apache.spark.sql.catalyst.expressions.{ApplyFunctionExpression, Cast, Literal}
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.physical
@@ -45,7 +45,8 @@ import org.apache.spark.sql.util.QueryExecutionListener
import org.apache.spark.tags.SlowSQLTest
@SlowSQLTest
-class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase {
+class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
+ with IgnoreCometSuite {
import testImplicits._
before {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
index ae1c0a86a14..1d3b914fd64 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
@@ -27,7 +27,7 @@ import org.apache.hadoop.fs.permission.FsPermission
import org.mockito.Mockito.{mock, spy, when}
import org.apache.spark._
-import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row, SaveMode}
+import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, IgnoreComet, QueryTest, Row, SaveMode}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.{NamedParameter, UnresolvedGenerator}
import org.apache.spark.sql.catalyst.expressions.{Grouping, Literal, RowNumber}
@@ -256,7 +256,8 @@ class QueryExecutionErrorsSuite
}
test("INCONSISTENT_BEHAVIOR_CROSS_VERSION: " +
- "compatibility with Spark 2.4/3.2 in reading/writing dates") {
+ "compatibility with Spark 2.4/3.2 in reading/writing dates",
+ IgnoreComet("Comet doesn't completely support datetime rebase mode yet")) {
// Fail to read ancient datetime values.
withSQLConf(SQLConf.PARQUET_REBASE_MODE_IN_READ.key -> EXCEPTION.toString) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
index 418ca3430bb..eb8267192f8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
@@ -23,7 +23,7 @@ import scala.util.Random
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkConf
-import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.{DataFrame, IgnoreComet, QueryTest}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
import org.apache.spark.sql.internal.SQLConf
@@ -195,7 +195,7 @@ class DataSourceV2ScanExecRedactionSuite extends DataSourceScanRedactionTest {
}
}
- test("FileScan description") {
+ test("FileScan description", IgnoreComet("Comet doesn't use BatchScan")) {
Seq("json", "orc", "parquet").foreach { format =>
withTempPath { path =>
val dir = path.getCanonicalPath
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
index 743ec41dbe7..9f30d6c8e04 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
@@ -53,6 +53,10 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite with DisableAdaptiv
case ColumnarToRowExec(i: InputAdapter) => isScanPlanTree(i.child)
case p: ProjectExec => isScanPlanTree(p.child)
case f: FilterExec => isScanPlanTree(f.child)
+ // Comet produces scan plan tree like:
+ // ColumnarToRow
+ // +- ReusedExchange
+ case _: ReusedExchangeExec => false
case _: LeafExecNode => true
case _ => false
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index de24b8c82b0..1f835481290 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, DataFrame, Row}
+import org.apache.spark.sql.{execution, DataFrame, IgnoreCometSuite, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
@@ -35,7 +35,9 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
-class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
+// Ignore this suite when Comet is enabled. This suite tests the Spark planner and Comet planner
+// comes out with too many difference. Simply ignoring this suite for now.
+class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper with IgnoreCometSuite {
import testImplicits._
setupTestData()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala
index 9e9d717db3b..73de2b84938 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala
@@ -17,7 +17,10 @@
package org.apache.spark.sql.execution
+import org.apache.comet.CometConf
+
import org.apache.spark.sql.{DataFrame, QueryTest, Row}
+import org.apache.spark.sql.comet.CometProjectExec
import org.apache.spark.sql.connector.SimpleWritableDataSource
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.internal.SQLConf
@@ -34,7 +37,10 @@ abstract class RemoveRedundantProjectsSuiteBase
private def assertProjectExecCount(df: DataFrame, expected: Int): Unit = {
withClue(df.queryExecution) {
val plan = df.queryExecution.executedPlan
- val actual = collectWithSubqueries(plan) { case p: ProjectExec => p }.size
+ val actual = collectWithSubqueries(plan) {
+ case p: ProjectExec => p
+ case p: CometProjectExec => p
+ }.size
assert(actual == expected)
}
}
@@ -134,12 +140,26 @@ abstract class RemoveRedundantProjectsSuiteBase
val df = data.selectExpr("a", "b", "key", "explode(array(key, a, b)) as d").filter("d > 0")
df.collect()
val plan = df.queryExecution.executedPlan
- val numProjects = collectWithSubqueries(plan) { case p: ProjectExec => p }.length
+
+ val numProjects = collectWithSubqueries(plan) {
+ case p: ProjectExec => p
+ case p: CometProjectExec => p
+ }.length
// Create a new plan that reverse the GenerateExec output and add a new ProjectExec between
// GenerateExec and its child. This is to test if the ProjectExec is removed, the output of
// the query will be incorrect.
- val newPlan = stripAQEPlan(plan) transform {
+
+ // Comet-specific change to get original Spark plan before applying
+ // a transformation to add a new ProjectExec
+ var sparkPlan: SparkPlan = null
+ withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") {
+ val df = data.selectExpr("a", "b", "key", "explode(array(key, a, b)) as d").filter("d > 0")
+ df.collect()
+ sparkPlan = df.queryExecution.executedPlan
+ }
+
+ val newPlan = stripAQEPlan(sparkPlan) transform {
case g @ GenerateExec(_, requiredChildOutput, _, _, child) =>
g.copy(requiredChildOutput = requiredChildOutput.reverse,
child = ProjectExec(requiredChildOutput.reverse, child))
@@ -151,6 +171,7 @@ abstract class RemoveRedundantProjectsSuiteBase
// The manually added ProjectExec node shouldn't be removed.
assert(collectWithSubqueries(newExecutedPlan) {
case p: ProjectExec => p
+ case p: CometProjectExec => p
}.size == numProjects + 1)
// Check the original plan's output and the new plan's output are the same.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
index 005e764cc30..92ec088efab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.sql.catalyst.plans.physical.{RangePartitioning, UnknownPartitioning}
+import org.apache.spark.sql.comet.CometSortExec
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.execution.joins.ShuffledJoin
import org.apache.spark.sql.internal.SQLConf
@@ -33,7 +34,7 @@ abstract class RemoveRedundantSortsSuiteBase
private def checkNumSorts(df: DataFrame, count: Int): Unit = {
val plan = df.queryExecution.executedPlan
- assert(collectWithSubqueries(plan) { case s: SortExec => s }.length == count)
+ assert(collectWithSubqueries(plan) { case _: SortExec | _: CometSortExec => 1 }.length == count)
}
private def checkSorts(query: String, enabledCount: Int, disabledCount: Int): Unit = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala
index 47679ed7865..9ffbaecb98e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.comet.CometHashAggregateExec
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.internal.SQLConf
@@ -31,7 +32,7 @@ abstract class ReplaceHashWithSortAggSuiteBase
private def checkNumAggs(df: DataFrame, hashAggCount: Int, sortAggCount: Int): Unit = {
val plan = df.queryExecution.executedPlan
assert(collectWithSubqueries(plan) {
- case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec) => s
+ case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec | _: CometHashAggregateExec ) => s
}.length == hashAggCount)
assert(collectWithSubqueries(plan) { case s: SortAggregateExec => s }.length == sortAggCount)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
index a1147c16cc8..c7a29496328 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.{SparkArithmeticException, SparkException, SparkFileNotFoundException}
import org.apache.spark.sql._
+import org.apache.spark.sql.IgnoreCometNativeDataFusion
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Divide}
import org.apache.spark.sql.catalyst.parser.ParseException
@@ -968,7 +969,8 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
}
}
- test("alter temporary view should follow current storeAnalyzedPlanForView config") {
+ test("alter temporary view should follow current storeAnalyzedPlanForView config",
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3314")) {
withTable("t") {
Seq(2, 3, 1).toDF("c1").write.format("parquet").saveAsTable("t")
withView("v1") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
index eec396b2e39..bf3f1c769d6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.TestUtils.assertSpilled
-import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.{AnalysisException, IgnoreComet, QueryTest, Row}
import org.apache.spark.sql.internal.SQLConf.{WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD, WINDOW_EXEC_BUFFER_SPILL_THRESHOLD}
import org.apache.spark.sql.test.SharedSparkSession
@@ -470,7 +470,7 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSparkSession {
Row(1, 3, null) :: Row(2, null, 4) :: Nil)
}
- test("test with low buffer spill threshold") {
+ test("test with low buffer spill threshold", IgnoreComet("Comet does not support spilling")) {
val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y")
nums.createOrReplaceTempView("nums")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
index b14f4a405f6..90bed10eca9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.Deduplicate
+import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometNativeColumnarToRowExec}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -131,7 +132,11 @@ class SparkPlanSuite extends QueryTest with SharedSparkSession {
spark.range(1).write.parquet(path.getAbsolutePath)
val df = spark.read.parquet(path.getAbsolutePath)
val columnarToRowExec =
- df.queryExecution.executedPlan.collectFirst { case p: ColumnarToRowExec => p }.get
+ df.queryExecution.executedPlan.collectFirst {
+ case p: ColumnarToRowExec => p
+ case p: CometColumnarToRowExec => p
+ case p: CometNativeColumnarToRowExec => p
+ }.get
try {
spark.range(1).foreach { _ =>
columnarToRowExec.canonicalized
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 5a413c77754..207b66e1d7b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode}
+import org.apache.spark.sql.{Dataset, IgnoreCometSuite, QueryTest, Row, SaveMode}
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAndComment, CodeGenerator}
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
@@ -30,7 +30,7 @@ import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
// Disable AQE because the WholeStageCodegenExec is added when running QueryStageExec
-class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
+class WholeStageCodegenSuite extends QueryTest with SharedSparkSession with IgnoreCometSuite
with DisableAdaptiveExecutionSuite {
import testImplicits._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 2f8e401e743..a4f94417dcc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -27,9 +27,11 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkException
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
+import org.apache.spark.sql.{Dataset, IgnoreComet, QueryTest, Row, SparkSession, Strategy}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.comet._
+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.columnar.{InMemoryTableScanExec, InMemoryTableScanLike}
@@ -117,6 +119,7 @@ class AdaptiveQueryExecSuite
private def findTopLevelBroadcastHashJoin(plan: SparkPlan): Seq[BroadcastHashJoinExec] = {
collect(plan) {
case j: BroadcastHashJoinExec => j
+ case j: CometBroadcastHashJoinExec => j.originalPlan.asInstanceOf[BroadcastHashJoinExec]
}
}
@@ -129,36 +132,46 @@ class AdaptiveQueryExecSuite
private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = {
collect(plan) {
case j: SortMergeJoinExec => j
+ case j: CometSortMergeJoinExec =>
+ assert(j.originalPlan.isInstanceOf[SortMergeJoinExec])
+ j.originalPlan.asInstanceOf[SortMergeJoinExec]
}
}
private def findTopLevelShuffledHashJoin(plan: SparkPlan): Seq[ShuffledHashJoinExec] = {
collect(plan) {
case j: ShuffledHashJoinExec => j
+ case j: CometHashJoinExec => j.originalPlan.asInstanceOf[ShuffledHashJoinExec]
}
}
private def findTopLevelBaseJoin(plan: SparkPlan): Seq[BaseJoinExec] = {
collect(plan) {
case j: BaseJoinExec => j
+ case c: CometHashJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec]
+ case c: CometSortMergeJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec]
+ case c: CometBroadcastHashJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec]
}
}
private def findTopLevelSort(plan: SparkPlan): Seq[SortExec] = {
collect(plan) {
case s: SortExec => s
+ case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec]
}
}
private def findTopLevelAggregate(plan: SparkPlan): Seq[BaseAggregateExec] = {
collect(plan) {
case agg: BaseAggregateExec => agg
+ case agg: CometHashAggregateExec => agg.originalPlan.asInstanceOf[BaseAggregateExec]
}
}
private def findTopLevelLimit(plan: SparkPlan): Seq[CollectLimitExec] = {
collect(plan) {
case l: CollectLimitExec => l
+ case l: CometCollectLimitExec => l.originalPlan.asInstanceOf[CollectLimitExec]
}
}
@@ -202,6 +215,7 @@ class AdaptiveQueryExecSuite
val parts = rdd.partitions
assert(parts.forall(rdd.preferredLocations(_).nonEmpty))
}
+
assert(numShuffles === (numLocalReads.length + numShufflesWithoutLocalRead))
}
@@ -210,7 +224,7 @@ class AdaptiveQueryExecSuite
val plan = df.queryExecution.executedPlan
assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}
assert(shuffle.size == 1)
assert(shuffle(0).outputPartitioning.numPartitions == numPartition)
@@ -226,7 +240,8 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
- checkNumLocalShuffleReads(adaptivePlan)
+ // Comet shuffle changes shuffle metrics
+ // checkNumLocalShuffleReads(adaptivePlan)
}
}
@@ -253,7 +268,8 @@ class AdaptiveQueryExecSuite
}
}
- test("Reuse the parallelism of coalesced shuffle in local shuffle read") {
+ test("Reuse the parallelism of coalesced shuffle in local shuffle read",
+ IgnoreComet("Comet shuffle changes shuffle partition size")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
@@ -285,7 +301,8 @@ class AdaptiveQueryExecSuite
}
}
- test("Reuse the default parallelism in local shuffle read") {
+ test("Reuse the default parallelism in local shuffle read",
+ IgnoreComet("Comet shuffle changes shuffle partition size")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
@@ -299,7 +316,8 @@ class AdaptiveQueryExecSuite
val localReads = collect(adaptivePlan) {
case read: AQEShuffleReadExec if read.isLocalRead => read
}
- assert(localReads.length == 2)
+ // Comet shuffle changes shuffle metrics
+ assert(localReads.length == 1)
val localShuffleRDD0 = localReads(0).execute().asInstanceOf[ShuffledRowRDD]
val localShuffleRDD1 = localReads(1).execute().asInstanceOf[ShuffledRowRDD]
// the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2
@@ -324,7 +342,9 @@ class AdaptiveQueryExecSuite
.groupBy($"a").count()
checkAnswer(testDf, Seq())
val plan = testDf.queryExecution.executedPlan
- assert(find(plan)(_.isInstanceOf[SortMergeJoinExec]).isDefined)
+ assert(find(plan) { case p =>
+ p.isInstanceOf[SortMergeJoinExec] || p.isInstanceOf[CometSortMergeJoinExec]
+ }.isDefined)
val coalescedReads = collect(plan) {
case r: AQEShuffleReadExec => r
}
@@ -338,7 +358,9 @@ class AdaptiveQueryExecSuite
.groupBy($"a").count()
checkAnswer(testDf, Seq())
val plan = testDf.queryExecution.executedPlan
- assert(find(plan)(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
+ assert(find(plan) { case p =>
+ p.isInstanceOf[BroadcastHashJoinExec] || p.isInstanceOf[CometBroadcastHashJoinExec]
+ }.isDefined)
val coalescedReads = collect(plan) {
case r: AQEShuffleReadExec => r
}
@@ -348,7 +370,7 @@ class AdaptiveQueryExecSuite
}
}
- test("Scalar subquery") {
+ test("Scalar subquery", IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
@@ -363,7 +385,7 @@ class AdaptiveQueryExecSuite
}
}
- test("Scalar subquery in later stages") {
+ test("Scalar subquery in later stages", IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
@@ -379,7 +401,7 @@ class AdaptiveQueryExecSuite
}
}
- test("multiple joins") {
+ test("multiple joins", IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
@@ -424,7 +446,7 @@ class AdaptiveQueryExecSuite
}
}
- test("multiple joins with aggregate") {
+ test("multiple joins with aggregate", IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
@@ -469,7 +491,7 @@ class AdaptiveQueryExecSuite
}
}
- test("multiple joins with aggregate 2") {
+ test("multiple joins with aggregate 2", IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
@@ -515,7 +537,7 @@ class AdaptiveQueryExecSuite
}
}
- test("Exchange reuse") {
+ test("Exchange reuse", IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
@@ -534,7 +556,7 @@ class AdaptiveQueryExecSuite
}
}
- test("Exchange reuse with subqueries") {
+ test("Exchange reuse with subqueries", IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
@@ -565,7 +587,9 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
- checkNumLocalShuffleReads(adaptivePlan)
+ // Comet shuffle changes shuffle metrics,
+ // so we can't check the number of local shuffle reads.
+ // checkNumLocalShuffleReads(adaptivePlan)
// Even with local shuffle read, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.nonEmpty)
@@ -586,7 +610,9 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
- checkNumLocalShuffleReads(adaptivePlan)
+ // Comet shuffle changes shuffle metrics,
+ // so we can't check the number of local shuffle reads.
+ // checkNumLocalShuffleReads(adaptivePlan)
// Even with local shuffle read, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.isEmpty)
@@ -595,7 +621,8 @@ class AdaptiveQueryExecSuite
}
}
- test("Broadcast exchange reuse across subqueries") {
+ test("Broadcast exchange reuse across subqueries",
+ IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000",
@@ -690,7 +717,8 @@ class AdaptiveQueryExecSuite
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
// There is still a SMJ, and its two shuffles can't apply local read.
- checkNumLocalShuffleReads(adaptivePlan, 2)
+ // Comet shuffle changes shuffle metrics
+ // checkNumLocalShuffleReads(adaptivePlan, 2)
}
}
@@ -812,7 +840,8 @@ class AdaptiveQueryExecSuite
}
}
- test("SPARK-29544: adaptive skew join with different join types") {
+ test("SPARK-29544: adaptive skew join with different join types",
+ IgnoreComet("Comet shuffle has different partition metrics")) {
Seq("SHUFFLE_MERGE", "SHUFFLE_HASH").foreach { joinHint =>
def getJoinNode(plan: SparkPlan): Seq[ShuffledJoin] = if (joinHint == "SHUFFLE_MERGE") {
findTopLevelSortMergeJoin(plan)
@@ -1030,7 +1059,8 @@ class AdaptiveQueryExecSuite
}
}
- test("metrics of the shuffle read") {
+ test("metrics of the shuffle read",
+ IgnoreComet("Comet shuffle changes the metrics")) {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT key FROM testData GROUP BY key")
@@ -1625,7 +1655,7 @@ class AdaptiveQueryExecSuite
val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT id FROM v1 GROUP BY id DISTRIBUTE BY id")
assert(collect(adaptivePlan) {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}.length == 1)
}
}
@@ -1705,7 +1735,8 @@ class AdaptiveQueryExecSuite
}
}
- test("SPARK-33551: Do not use AQE shuffle read for repartition") {
+ test("SPARK-33551: Do not use AQE shuffle read for repartition",
+ IgnoreComet("Comet shuffle changes partition size")) {
def hasRepartitionShuffle(plan: SparkPlan): Boolean = {
find(plan) {
case s: ShuffleExchangeLike =>
@@ -1890,6 +1921,9 @@ class AdaptiveQueryExecSuite
def checkNoCoalescePartitions(ds: Dataset[Row], origin: ShuffleOrigin): Unit = {
assert(collect(ds.queryExecution.executedPlan) {
case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s
+ case c: CometShuffleExchangeExec
+ if c.originalPlan.shuffleOrigin == origin &&
+ c.originalPlan.numPartitions == 2 => c
}.size == 1)
ds.collect()
val plan = ds.queryExecution.executedPlan
@@ -1898,6 +1932,9 @@ class AdaptiveQueryExecSuite
}.isEmpty)
assert(collect(plan) {
case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s
+ case c: CometShuffleExchangeExec
+ if c.originalPlan.shuffleOrigin == origin &&
+ c.originalPlan.numPartitions == 2 => c
}.size == 1)
checkAnswer(ds, testData)
}
@@ -2054,7 +2091,8 @@ class AdaptiveQueryExecSuite
}
}
- test("SPARK-35264: Support AQE side shuffled hash join formula") {
+ test("SPARK-35264: Support AQE side shuffled hash join formula",
+ IgnoreComet("Comet shuffle changes the partition size")) {
withTempView("t1", "t2") {
def checkJoinStrategy(shouldShuffleHashJoin: Boolean): Unit = {
Seq("100", "100000").foreach { size =>
@@ -2140,7 +2178,8 @@ class AdaptiveQueryExecSuite
}
}
- test("SPARK-35725: Support optimize skewed partitions in RebalancePartitions") {
+ test("SPARK-35725: Support optimize skewed partitions in RebalancePartitions",
+ IgnoreComet("Comet shuffle changes shuffle metrics")) {
withTempView("v") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
@@ -2239,7 +2278,7 @@ class AdaptiveQueryExecSuite
runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " +
s"JOIN skewData2 ON key1 = key2 GROUP BY key1")
val shuffles1 = collect(adaptive1) {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}
assert(shuffles1.size == 3)
// shuffles1.head is the top-level shuffle under the Aggregate operator
@@ -2252,7 +2291,7 @@ class AdaptiveQueryExecSuite
runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " +
s"JOIN skewData2 ON key1 = key2")
val shuffles2 = collect(adaptive2) {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}
if (hasRequiredDistribution) {
assert(shuffles2.size == 3)
@@ -2286,7 +2325,8 @@ class AdaptiveQueryExecSuite
}
}
- test("SPARK-35794: Allow custom plugin for cost evaluator") {
+ test("SPARK-35794: Allow custom plugin for cost evaluator",
+ IgnoreComet("Comet shuffle changes shuffle metrics")) {
CostEvaluator.instantiate(
classOf[SimpleShuffleSortCostEvaluator].getCanonicalName, spark.sparkContext.getConf)
intercept[IllegalArgumentException] {
@@ -2417,7 +2457,8 @@ class AdaptiveQueryExecSuite
}
test("SPARK-48037: Fix SortShuffleWriter lacks shuffle write related metrics " +
- "resulting in potentially inaccurate data") {
+ "resulting in potentially inaccurate data",
+ IgnoreComet("https://github.com/apache/datafusion-comet/issues/1501")) {
withTable("t3") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
@@ -2452,6 +2493,7 @@ class AdaptiveQueryExecSuite
val (_, adaptive) = runAdaptiveAndVerifyResult(query)
assert(adaptive.collect {
case sort: SortExec => sort
+ case sort: CometSortExec => sort
}.size == 1)
val read = collect(adaptive) {
case read: AQEShuffleReadExec => read
@@ -2469,7 +2511,8 @@ class AdaptiveQueryExecSuite
}
}
- test("SPARK-37357: Add small partition factor for rebalance partitions") {
+ test("SPARK-37357: Add small partition factor for rebalance partitions",
+ IgnoreComet("Comet shuffle changes shuffle metrics")) {
withTempView("v") {
withSQLConf(
SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true",
@@ -2581,7 +2624,7 @@ class AdaptiveQueryExecSuite
runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " +
"JOIN skewData3 ON value2 = value3")
val shuffles1 = collect(adaptive1) {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}
assert(shuffles1.size == 4)
val smj1 = findTopLevelSortMergeJoin(adaptive1)
@@ -2592,7 +2635,7 @@ class AdaptiveQueryExecSuite
runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " +
"JOIN skewData3 ON value1 = value3")
val shuffles2 = collect(adaptive2) {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}
assert(shuffles2.size == 4)
val smj2 = findTopLevelSortMergeJoin(adaptive2)
@@ -2850,6 +2893,7 @@ class AdaptiveQueryExecSuite
}.size == (if (firstAccess) 1 else 0))
assert(collect(initialExecutedPlan) {
case s: SortExec => s
+ case s: CometSortExec => s
}.size == (if (firstAccess) 2 else 0))
assert(collect(initialExecutedPlan) {
case i: InMemoryTableScanLike => i
@@ -2980,7 +3024,9 @@ class AdaptiveQueryExecSuite
val plan = df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec]
assert(plan.inputPlan.isInstanceOf[TakeOrderedAndProjectExec])
- assert(plan.finalPhysicalPlan.isInstanceOf[WindowExec])
+ assert(
+ plan.finalPhysicalPlan.isInstanceOf[WindowExec] ||
+ plan.finalPhysicalPlan.find(_.isInstanceOf[CometWindowExec]).nonEmpty)
plan.inputPlan.output.zip(plan.finalPhysicalPlan.output).foreach { case (o1, o2) =>
assert(o1.semanticEquals(o2), "Different output column order after AQE optimization")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index fd52d038ca6..154c800be67 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Concat
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.Expand
import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec}
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions._
@@ -884,6 +885,8 @@ abstract class SchemaPruningSuite
val fileSourceScanSchemata =
collect(df.queryExecution.executedPlan) {
case scan: FileSourceScanExec => scan.requiredSchema
+ case scan: CometScanExec => scan.requiredSchema
+ case scan: CometNativeScanExec => scan.requiredSchema
}
assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size,
s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " +
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
index 5fd27410dcb..468abb1543a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, NullsFirst, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sort}
+import org.apache.spark.sql.comet.CometSortExec
import org.apache.spark.sql.execution.{QueryExecution, SortExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.internal.SQLConf
@@ -243,6 +244,7 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write
// assert the outer most sort in the executed plan
assert(plan.collectFirst {
case s: SortExec => s
+ case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec]
}.exists {
case SortExec(Seq(
SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _),
@@ -290,6 +292,7 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write
// assert the outer most sort in the executed plan
assert(plan.collectFirst {
case s: SortExec => s
+ case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec]
}.exists {
case SortExec(Seq(
SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
index 0b6fdef4f74..5b18c55da4b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
@@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, GlobFilter, Path}
import org.mockito.Mockito.{mock, when}
import org.apache.spark.SparkException
-import org.apache.spark.sql.{DataFrame, QueryTest, Row}
+import org.apache.spark.sql.{DataFrame, IgnoreCometSuite, QueryTest, Row}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.functions.col
@@ -38,7 +38,9 @@ import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
-class BinaryFileFormatSuite extends QueryTest with SharedSparkSession {
+// For some reason this suite is flaky w/ or w/o Comet when running in Github workflow.
+// Since it isn't related to Comet, we disable it for now.
+class BinaryFileFormatSuite extends QueryTest with SharedSparkSession with IgnoreCometSuite {
import BinaryFileFormat._
private var testDir: String = _
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
index 07e2849ce6f..3e73645b638 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
@@ -28,7 +28,7 @@ import org.apache.parquet.hadoop.ParquetOutputFormat
import org.apache.spark.TestUtils
import org.apache.spark.memory.MemoryMode
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{IgnoreComet, Row}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -201,7 +201,8 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSparkSess
}
}
- test("parquet v2 pages - rle encoding for boolean value columns") {
+ test("parquet v2 pages - rle encoding for boolean value columns",
+ IgnoreComet("Comet doesn't support RLE encoding yet")) {
val extraOptions = Map[String, String](
ParquetOutputFormat.WRITER_VERSION -> ParquetProperties.WriterVersion.PARQUET_2_0.toString
)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 8e88049f51e..49f2001dc6b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -1095,7 +1095,11 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
// When a filter is pushed to Parquet, Parquet can apply it to every row.
// So, we can check the number of rows returned from the Parquet
// to make sure our filter pushdown work.
- assert(stripSparkFilter(df).count == 1)
+ // Similar to Spark's vectorized reader, Comet doesn't do row-level filtering but relies
+ // on Spark to apply the data filters after columnar batches are returned
+ if (!isCometEnabled) {
+ assert(stripSparkFilter(df).count == 1)
+ }
}
}
}
@@ -1498,7 +1502,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
}
}
- test("Filters should be pushed down for vectorized Parquet reader at row group level") {
+ test("Filters should be pushed down for vectorized Parquet reader at row group level",
+ IgnoreCometNativeScan("Native scans do not support the tested accumulator")) {
import testImplicits._
withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true",
@@ -1548,7 +1553,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
}
}
- test("SPARK-31026: Parquet predicate pushdown for fields having dots in the names") {
+ test("SPARK-31026: Parquet predicate pushdown for fields having dots in the names",
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3320")) {
import testImplicits._
withAllParquetReaders {
@@ -1580,13 +1586,18 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
// than the total length but should not be a single record.
// Note that, if record level filtering is enabled, it should be a single record.
// If no filter is pushed down to Parquet, it should be the total length of data.
- assert(actual > 1 && actual < data.length)
+ // Only enable Comet test iff it's scan only, since with native execution
+ // `stripSparkFilter` can't remove the native filter
+ if (!isCometEnabled || isCometScanOnly) {
+ assert(actual > 1 && actual < data.length)
+ }
}
}
}
}
- test("Filters should be pushed down for Parquet readers at row group level") {
+ test("Filters should be pushed down for Parquet readers at row group level",
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3320")) {
import testImplicits._
withSQLConf(
@@ -1607,7 +1618,11 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
// than the total length but should not be a single record.
// Note that, if record level filtering is enabled, it should be a single record.
// If no filter is pushed down to Parquet, it should be the total length of data.
- assert(actual > 1 && actual < data.length)
+ // Only enable Comet test iff it's scan only, since with native execution
+ // `stripSparkFilter` can't remove the native filter
+ if (!isCometEnabled || isCometScanOnly) {
+ assert(actual > 1 && actual < data.length)
+ }
}
}
}
@@ -1699,7 +1714,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
(attr, value) => sources.StringContains(attr, value))
}
- test("filter pushdown - StringPredicate") {
+ test("filter pushdown - StringPredicate", IgnoreCometNativeScan("cannot be pushed down")) {
import testImplicits._
// keep() should take effect on StartsWith/EndsWith/Contains
Seq(
@@ -1743,7 +1758,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
}
}
- test("SPARK-17091: Convert IN predicate to Parquet filter push-down") {
+ test("SPARK-17091: Convert IN predicate to Parquet filter push-down",
+ IgnoreCometNativeScan("Comet has different push-down behavior")) {
val schema = StructType(Seq(
StructField("a", IntegerType, nullable = false)
))
@@ -1933,7 +1949,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
}
}
- test("SPARK-25207: exception when duplicate fields in case-insensitive mode") {
+ test("SPARK-25207: exception when duplicate fields in case-insensitive mode",
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3311")) {
withTempPath { dir =>
val count = 10
val tableName = "spark_25207"
@@ -1984,7 +2001,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
}
}
- test("Support Parquet column index") {
+ test("Support Parquet column index",
+ IgnoreComet("Comet doesn't support Parquet column index yet")) {
// block 1:
// null count min max
// page-0 0 0 99
@@ -2044,7 +2062,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
}
}
- test("SPARK-34562: Bloom filter push down") {
+ test("SPARK-34562: Bloom filter push down",
+ IgnoreCometNativeScan("Native scans do not support the tested accumulator")) {
withTempPath { dir =>
val path = dir.getCanonicalPath
spark.range(100).selectExpr("id * 2 AS id")
@@ -2276,7 +2295,11 @@ class ParquetV1FilterSuite extends ParquetFilterSuite {
assert(pushedParquetFilters.exists(_.getClass === filterClass),
s"${pushedParquetFilters.map(_.getClass).toList} did not contain ${filterClass}.")
- checker(stripSparkFilter(query), expected)
+ // Similar to Spark's vectorized reader, Comet doesn't do row-level filtering but relies
+ // on Spark to apply the data filters after columnar batches are returned
+ if (!isCometEnabled) {
+ checker(stripSparkFilter(query), expected)
+ }
} else {
assert(selectedFilters.isEmpty, "There is filter pushed down")
}
@@ -2336,7 +2359,11 @@ class ParquetV2FilterSuite extends ParquetFilterSuite {
assert(pushedParquetFilters.exists(_.getClass === filterClass),
s"${pushedParquetFilters.map(_.getClass).toList} did not contain ${filterClass}.")
- checker(stripSparkFilter(query), expected)
+ // Similar to Spark's vectorized reader, Comet doesn't do row-level filtering but relies
+ // on Spark to apply the data filters after columnar batches are returned
+ if (!isCometEnabled) {
+ checker(stripSparkFilter(query), expected)
+ }
case _ =>
throw new AnalysisException("Can not match ParquetTable in the query.")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 8ed9ef1630e..f312174b182 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -1064,7 +1064,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
}
}
- test("SPARK-35640: read binary as timestamp should throw schema incompatible error") {
+ test("SPARK-35640: read binary as timestamp should throw schema incompatible error",
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3311")) {
val data = (1 to 4).map(i => Tuple1(i.toString))
val readSchema = StructType(Seq(StructField("_1", DataTypes.TimestampType)))
@@ -1075,7 +1076,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
}
}
- test("SPARK-35640: int as long should throw schema incompatible error") {
+ test("SPARK-35640: int as long should throw schema incompatible error",
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3311")) {
val data = (1 to 4).map(i => Tuple1(i))
val readSchema = StructType(Seq(StructField("_1", DataTypes.LongType)))
@@ -1345,7 +1347,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
}
}
- test("SPARK-40128 read DELTA_LENGTH_BYTE_ARRAY encoded strings") {
+ test("SPARK-40128 read DELTA_LENGTH_BYTE_ARRAY encoded strings",
+ IgnoreComet("Comet doesn't support DELTA encoding yet")) {
withAllParquetReaders {
checkAnswer(
// "fruit" column in this file is encoded using DELTA_LENGTH_BYTE_ARRAY.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index f6472ba3d9d..ce39ebb52e6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -185,7 +185,8 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
}
}
- test("SPARK-36182: can't read TimestampLTZ as TimestampNTZ") {
+ test("SPARK-36182: can't read TimestampLTZ as TimestampNTZ",
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3311")) {
val data = (1 to 1000).map { i =>
val ts = new java.sql.Timestamp(i)
Row(ts)
@@ -998,7 +999,8 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
}
}
- test("SPARK-26677: negated null-safe equality comparison should not filter matched row groups") {
+ test("SPARK-26677: negated null-safe equality comparison should not filter matched row groups",
+ IgnoreCometNativeScan("Native scans had the filter pushed into DF operator, cannot strip")) {
withAllParquetReaders {
withTempPath { path =>
// Repeated values for dictionary encoding.
@@ -1051,7 +1053,8 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
testMigration(fromTsType = "TIMESTAMP_MICROS", toTsType = "INT96")
}
- test("SPARK-34212 Parquet should read decimals correctly") {
+ test("SPARK-34212 Parquet should read decimals correctly",
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3311")) {
def readParquet(schema: String, path: File): DataFrame = {
spark.read.schema(schema).parquet(path.toString)
}
@@ -1067,7 +1070,8 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
checkAnswer(readParquet(schema, path), df)
}
- withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false",
+ "spark.comet.enabled" -> "false") {
val schema1 = "a DECIMAL(3, 2), b DECIMAL(18, 3), c DECIMAL(37, 3)"
checkAnswer(readParquet(schema1, path), df)
val schema2 = "a DECIMAL(3, 0), b DECIMAL(18, 1), c DECIMAL(37, 1)"
@@ -1089,7 +1093,8 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
val df = sql(s"SELECT 1 a, 123456 b, ${Int.MaxValue.toLong * 10} c, CAST('1.2' AS BINARY) d")
df.write.parquet(path.toString)
- withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false",
+ "spark.comet.enabled" -> "false") {
checkAnswer(readParquet("a DECIMAL(3, 2)", path), sql("SELECT 1.00"))
checkAnswer(readParquet("b DECIMAL(3, 2)", path), Row(null))
checkAnswer(readParquet("b DECIMAL(11, 1)", path), sql("SELECT 123456.0"))
@@ -1133,7 +1138,8 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
}
}
- test("row group skipping doesn't overflow when reading into larger type") {
+ test("row group skipping doesn't overflow when reading into larger type",
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3311")) {
withTempPath { path =>
Seq(0).toDF("a").write.parquet(path.toString)
// The vectorized and non-vectorized readers will produce different exceptions, we don't need
@@ -1148,7 +1154,7 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
.where(s"a < ${Long.MaxValue}")
.collect()
}
- assert(exception.getCause.getCause.isInstanceOf[SchemaColumnConvertNotSupportedException])
+ assert(exception.getMessage.contains("Column: [a], Expected: bigint, Found: INT32"))
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala
index 4f906411345..6cc69f7e915 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala
@@ -21,7 +21,7 @@ import java.nio.file.{Files, Paths, StandardCopyOption}
import java.sql.{Date, Timestamp}
import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkUpgradeException}
-import org.apache.spark.sql.{QueryTest, Row, SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY}
+import org.apache.spark.sql.{IgnoreCometSuite, QueryTest, Row, SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY}
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
import org.apache.spark.sql.internal.LegacyBehaviorPolicy.{CORRECTED, EXCEPTION, LEGACY}
@@ -30,9 +30,11 @@ import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType.{INT96,
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.tags.SlowSQLTest
+// Comet is disabled for this suite because it doesn't support datetime rebase mode
abstract class ParquetRebaseDatetimeSuite
extends QueryTest
with ParquetTest
+ with IgnoreCometSuite
with SharedSparkSession {
import testImplicits._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala
index 27c2a2148fd..df04a15fb1f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala
@@ -20,12 +20,14 @@ import java.io.File
import scala.collection.JavaConverters._
+import org.apache.comet.CometConf
import org.apache.hadoop.fs.Path
import org.apache.parquet.column.ParquetProperties._
import org.apache.parquet.hadoop.{ParquetFileReader, ParquetOutputFormat}
import org.apache.parquet.hadoop.ParquetWriter.DEFAULT_BLOCK_SIZE
import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.comet.{CometBatchScanExec, CometScanExec}
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
@@ -172,6 +174,8 @@ class ParquetRowIndexSuite extends QueryTest with SharedSparkSession {
private def testRowIndexGeneration(label: String, conf: RowIndexTestConf): Unit = {
test (s"$label - ${conf.desc}") {
+ // native_datafusion Parquet scan does not support row index generation.
+ assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION)
withSQLConf(conf.sqlConfs: _*) {
withTempPath { path =>
// Read row index using _metadata.row_index if that is supported by the file format.
@@ -243,6 +247,12 @@ class ParquetRowIndexSuite extends QueryTest with SharedSparkSession {
case f: FileSourceScanExec =>
numPartitions += f.inputRDD.partitions.length
numOutputRows += f.metrics("numOutputRows").value
+ case b: CometScanExec =>
+ numPartitions += b.inputRDD.partitions.length
+ numOutputRows += b.metrics("numOutputRows").value
+ case b: CometBatchScanExec =>
+ numPartitions += b.inputRDD.partitions.length
+ numOutputRows += b.metrics("numOutputRows").value
case _ =>
}
assert(numPartitions > 0)
@@ -301,6 +311,8 @@ class ParquetRowIndexSuite extends QueryTest with SharedSparkSession {
val conf = RowIndexTestConf(useDataSourceV2 = useDataSourceV2)
test(s"invalid row index column type - ${conf.desc}") {
+ // native_datafusion Parquet scan does not support row index generation.
+ assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION)
withSQLConf(conf.sqlConfs: _*) {
withTempPath{ path =>
val df = spark.range(0, 10, 1, 1).toDF("id")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala
index 5c0b7def039..151184bc98c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet
import org.apache.spark.SparkConf
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.comet.CometBatchScanExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.SchemaPruningSuite
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
@@ -56,6 +57,7 @@ class ParquetV2SchemaPruningSuite extends ParquetSchemaPruningSuite {
val fileSourceScanSchemata =
collect(df.queryExecution.executedPlan) {
case scan: BatchScanExec => scan.scan.asInstanceOf[ParquetScan].readDataSchema
+ case scan: CometBatchScanExec => scan.scan.asInstanceOf[ParquetScan].readDataSchema
}
assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size,
s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " +
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
index 3f47c5e506f..92a5eafec84 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
@@ -27,6 +27,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.apache.parquet.schema.Type._
import org.apache.spark.SparkException
+import org.apache.spark.sql.{IgnoreComet, IgnoreCometNativeDataFusion}
import org.apache.spark.sql.catalyst.expressions.Cast.toSQLType
import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException
import org.apache.spark.sql.functions.desc
@@ -1036,7 +1037,8 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
e
}
- test("schema mismatch failure error message for parquet reader") {
+ test("schema mismatch failure error message for parquet reader",
+ IgnoreComet("Comet doesn't work with vectorizedReaderEnabled = false")) {
withTempPath { dir =>
val e = testSchemaMismatch(dir.getCanonicalPath, vectorizedReaderEnabled = false)
val expectedMessage = "Encountered error while reading file"
@@ -1046,7 +1048,8 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
}
}
- test("schema mismatch failure error message for parquet vectorized reader") {
+ test("schema mismatch failure error message for parquet vectorized reader",
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3311")) {
withTempPath { dir =>
val e = testSchemaMismatch(dir.getCanonicalPath, vectorizedReaderEnabled = true)
assert(e.getCause.isInstanceOf[SparkException])
@@ -1087,7 +1090,8 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
}
}
- test("SPARK-45604: schema mismatch failure error on timestamp_ntz to array<timestamp_ntz>") {
+ test("SPARK-45604: schema mismatch failure error on timestamp_ntz to array<timestamp_ntz>",
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3311")) {
import testImplicits._
withTempPath { dir =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
index b8f3ea3c6f3..bbd44221288 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.debug
import java.io.ByteArrayOutputStream
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.IgnoreComet
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
@@ -125,7 +126,8 @@ class DebuggingSuite extends DebuggingSuiteBase with DisableAdaptiveExecutionSui
| id LongType: {}""".stripMargin))
}
- test("SPARK-28537: DebugExec cannot debug columnar related queries") {
+ test("SPARK-28537: DebugExec cannot debug columnar related queries",
+ IgnoreComet("Comet does not use FileScan")) {
withTempPath { workDir =>
val workDirPath = workDir.getAbsolutePath
val input = spark.range(5).toDF("id")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 5cdbdc27b32..307fba16578 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -46,8 +46,10 @@ import org.apache.spark.sql.util.QueryExecutionListener
import org.apache.spark.util.{AccumulatorContext, JsonProtocol}
// Disable AQE because metric info is different with AQE on/off
+// This test suite runs tests against the metrics of physical operators.
+// Disabling it for Comet because the metrics are different with Comet enabled.
class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
- with DisableAdaptiveExecutionSuite {
+ with DisableAdaptiveExecutionSuite with IgnoreCometSuite {
import testImplicits._
/**
@@ -765,7 +767,8 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
}
}
- test("SPARK-26327: FileSourceScanExec metrics") {
+ test("SPARK-26327: FileSourceScanExec metrics",
+ IgnoreComet("Spark uses row-based Parquet reader while Comet is vectorized")) {
withTable("testDataForScan") {
spark.range(10).selectExpr("id", "id % 3 as p")
.write.partitionBy("p").saveAsTable("testDataForScan")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
index 0ab8691801d..7b81f3a8f6d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
@@ -17,7 +17,9 @@
package org.apache.spark.sql.execution.python
+import org.apache.spark.sql.IgnoreCometNativeDataFusion
import org.apache.spark.sql.catalyst.plans.logical.{ArrowEvalPython, BatchEvalPython, Limit, LocalLimit}
+import org.apache.spark.sql.comet._
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
@@ -93,7 +95,8 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession {
assert(arrowEvalNodes.size == 2)
}
- test("Python UDF should not break column pruning/filter pushdown -- Parquet V1") {
+ test("Python UDF should not break column pruning/filter pushdown -- Parquet V1",
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3312")) {
withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "parquet") {
withTempPath { f =>
spark.range(10).select($"id".as("a"), $"id".as("b"))
@@ -108,6 +111,7 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession {
val scanNodes = query.queryExecution.executedPlan.collect {
case scan: FileSourceScanExec => scan
+ case scan: CometScanExec => scan
}
assert(scanNodes.length == 1)
assert(scanNodes.head.output.map(_.name) == Seq("a"))
@@ -120,11 +124,16 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession {
val scanNodes = query.queryExecution.executedPlan.collect {
case scan: FileSourceScanExec => scan
+ case scan: CometScanExec => scan
}
assert(scanNodes.length == 1)
// $"a" is not null and $"a" > 1
- assert(scanNodes.head.dataFilters.length == 2)
- assert(scanNodes.head.dataFilters.flatMap(_.references.map(_.name)).distinct == Seq("a"))
+ val dataFilters = scanNodes.head match {
+ case scan: FileSourceScanExec => scan.dataFilters
+ case scan: CometScanExec => scan.dataFilters
+ }
+ assert(dataFilters.length == 2)
+ assert(dataFilters.flatMap(_.references.map(_.name)).distinct == Seq("a"))
}
}
}
@@ -145,6 +154,7 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession {
val scanNodes = query.queryExecution.executedPlan.collect {
case scan: BatchScanExec => scan
+ case scan: CometBatchScanExec => scan
}
assert(scanNodes.length == 1)
assert(scanNodes.head.output.map(_.name) == Seq("a"))
@@ -157,6 +167,7 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession {
val scanNodes = query.queryExecution.executedPlan.collect {
case scan: BatchScanExec => scan
+ case scan: CometBatchScanExec => scan
}
assert(scanNodes.length == 1)
// $"a" is not null and $"a" > 1
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala
index d083cac48ff..3c11bcde807 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala
@@ -37,8 +37,10 @@ import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException,
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.util.Utils
+// For some reason this suite is flaky w/ or w/o Comet when running in Github workflow.
+// Since it isn't related to Comet, we disable it for now.
class AsyncProgressTrackingMicroBatchExecutionSuite
- extends StreamTest with BeforeAndAfter with Matchers {
+ extends StreamTest with BeforeAndAfter with Matchers with IgnoreCometSuite {
import testImplicits._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 746f289c393..7a6a88a9fce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -19,16 +19,19 @@ package org.apache.spark.sql.sources
import scala.util.Random
+import org.apache.comet.CometConf
+
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.execution.{FileSourceScanExec, SortExec, SparkPlan}
+import org.apache.spark.sql.comet._
+import org.apache.spark.sql.execution.{ColumnarToRowExec, FileSourceScanExec, SortExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper}
import org.apache.spark.sql.execution.datasources.BucketingUtils
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -102,12 +105,22 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
}
}
- private def getFileScan(plan: SparkPlan): FileSourceScanExec = {
- val fileScan = collect(plan) { case f: FileSourceScanExec => f }
+ private def getFileScan(plan: SparkPlan): SparkPlan = {
+ val fileScan = collect(plan) {
+ case f: FileSourceScanExec => f
+ case f: CometScanExec => f
+ case f: CometNativeScanExec => f
+ }
assert(fileScan.nonEmpty, plan)
fileScan.head
}
+ private def getBucketScan(plan: SparkPlan): Boolean = getFileScan(plan) match {
+ case fs: FileSourceScanExec => fs.bucketedScan
+ case bs: CometScanExec => bs.bucketedScan
+ case ns: CometNativeScanExec => ns.bucketedScan
+ }
+
// To verify if the bucket pruning works, this function checks two conditions:
// 1) Check if the pruned buckets (before filtering) are empty.
// 2) Verify the final result is the same as the expected one
@@ -156,7 +169,8 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
val planWithoutBucketedScan = bucketedDataFrame.filter(filterCondition)
.queryExecution.executedPlan
val fileScan = getFileScan(planWithoutBucketedScan)
- assert(!fileScan.bucketedScan, s"except no bucketed scan but found\n$fileScan")
+ val bucketedScan = getBucketScan(planWithoutBucketedScan)
+ assert(!bucketedScan, s"except no bucketed scan but found\n$fileScan")
val bucketColumnType = bucketedDataFrame.schema.apply(bucketColumnIndex).dataType
val rowsWithInvalidBuckets = fileScan.execute().filter(row => {
@@ -452,28 +466,54 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
val joinOperator = if (joined.sqlContext.conf.adaptiveExecutionEnabled) {
val executedPlan =
joined.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
- assert(executedPlan.isInstanceOf[SortMergeJoinExec])
- executedPlan.asInstanceOf[SortMergeJoinExec]
+ executedPlan match {
+ case s: SortMergeJoinExec => s
+ case b: CometSortMergeJoinExec =>
+ b.originalPlan match {
+ case s: SortMergeJoinExec => s
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
} else {
val executedPlan = joined.queryExecution.executedPlan
- assert(executedPlan.isInstanceOf[SortMergeJoinExec])
- executedPlan.asInstanceOf[SortMergeJoinExec]
+ executedPlan match {
+ case s: SortMergeJoinExec => s
+ case ColumnarToRowExec(child) =>
+ child.asInstanceOf[CometSortMergeJoinExec].originalPlan match {
+ case s: SortMergeJoinExec => s
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
+ case CometColumnarToRowExec(child) =>
+ child.asInstanceOf[CometSortMergeJoinExec].originalPlan match {
+ case s: SortMergeJoinExec => s
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
+ case CometNativeColumnarToRowExec(child) =>
+ child.asInstanceOf[CometSortMergeJoinExec].originalPlan match {
+ case s: SortMergeJoinExec => s
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
}
// check existence of shuffle
assert(
- joinOperator.left.exists(_.isInstanceOf[ShuffleExchangeExec]) == shuffleLeft,
+ joinOperator.left.exists(op => op.isInstanceOf[ShuffleExchangeLike]) == shuffleLeft,
s"expected shuffle in plan to be $shuffleLeft but found\n${joinOperator.left}")
assert(
- joinOperator.right.exists(_.isInstanceOf[ShuffleExchangeExec]) == shuffleRight,
+ joinOperator.right.exists(op => op.isInstanceOf[ShuffleExchangeLike]) == shuffleRight,
s"expected shuffle in plan to be $shuffleRight but found\n${joinOperator.right}")
// check existence of sort
assert(
- joinOperator.left.exists(_.isInstanceOf[SortExec]) == sortLeft,
+ joinOperator.left.exists(op => op.isInstanceOf[SortExec] || op.isInstanceOf[CometExec] &&
+ op.asInstanceOf[CometExec].originalPlan.isInstanceOf[SortExec]) == sortLeft,
s"expected sort in the left child to be $sortLeft but found\n${joinOperator.left}")
assert(
- joinOperator.right.exists(_.isInstanceOf[SortExec]) == sortRight,
+ joinOperator.right.exists(op => op.isInstanceOf[SortExec] || op.isInstanceOf[CometExec] &&
+ op.asInstanceOf[CometExec].originalPlan.isInstanceOf[SortExec]) == sortRight,
s"expected sort in the right child to be $sortRight but found\n${joinOperator.right}")
// check the output partitioning
@@ -836,11 +876,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
val scanDF = spark.table("bucketed_table").select("j")
- assert(!getFileScan(scanDF.queryExecution.executedPlan).bucketedScan)
+ assert(!getBucketScan(scanDF.queryExecution.executedPlan))
checkAnswer(scanDF, df1.select("j"))
val aggDF = spark.table("bucketed_table").groupBy("j").agg(max("k"))
- assert(!getFileScan(aggDF.queryExecution.executedPlan).bucketedScan)
+ assert(!getBucketScan(aggDF.queryExecution.executedPlan))
checkAnswer(aggDF, df1.groupBy("j").agg(max("k")))
}
}
@@ -895,7 +935,10 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
}
test("SPARK-29655 Read bucketed tables obeys spark.sql.shuffle.partitions") {
+ // Range partitioning uses random samples, so per-partition comparisons do not always yield
+ // the same results. Disable Comet native range partitioning.
withSQLConf(
+ CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.key -> "false",
SQLConf.SHUFFLE_PARTITIONS.key -> "5",
SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "7") {
val bucketSpec = Some(BucketSpec(6, Seq("i", "j"), Nil))
@@ -914,7 +957,10 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
}
test("SPARK-32767 Bucket join should work if SHUFFLE_PARTITIONS larger than bucket number") {
+ // Range partitioning uses random samples, so per-partition comparisons do not always yield
+ // the same results. Disable Comet native range partitioning.
withSQLConf(
+ CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.key -> "false",
SQLConf.SHUFFLE_PARTITIONS.key -> "9",
SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10") {
@@ -944,7 +990,10 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
}
test("bucket coalescing eliminates shuffle") {
+ // Range partitioning uses random samples, so per-partition comparisons do not always yield
+ // the same results. Disable Comet native range partitioning.
withSQLConf(
+ CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.key -> "false",
SQLConf.COALESCE_BUCKETS_IN_JOIN_ENABLED.key -> "true",
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
// The side with bucketedTableTestSpec1 will be coalesced to have 4 output partitions.
@@ -1029,15 +1078,24 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
Seq(true, false).foreach { aqeEnabled =>
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled.toString) {
val plan = sql(query).queryExecution.executedPlan
- val shuffles = collect(plan) { case s: ShuffleExchangeExec => s }
+ val shuffles = collect(plan) { case s: ShuffleExchangeLike => s }
assert(shuffles.length == expectedNumShuffles)
val scans = collect(plan) {
case f: FileSourceScanExec if f.optionalNumCoalescedBuckets.isDefined => f
+ case b: CometScanExec if b.optionalNumCoalescedBuckets.isDefined => b
+ case b: CometNativeScanExec if b.optionalNumCoalescedBuckets.isDefined => b
}
if (expectedCoalescedNumBuckets.isDefined) {
assert(scans.length == 1)
- assert(scans.head.optionalNumCoalescedBuckets == expectedCoalescedNumBuckets)
+ scans.head match {
+ case f: FileSourceScanExec =>
+ assert(f.optionalNumCoalescedBuckets == expectedCoalescedNumBuckets)
+ case b: CometScanExec =>
+ assert(b.optionalNumCoalescedBuckets == expectedCoalescedNumBuckets)
+ case b: CometNativeScanExec =>
+ assert(b.optionalNumCoalescedBuckets == expectedCoalescedNumBuckets)
+ }
} else {
assert(scans.isEmpty)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
index 6f897a9c0b7..b0723634f68 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.sources
import java.io.File
import org.apache.spark.SparkException
+import org.apache.spark.sql.IgnoreCometSuite
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTableType}
import org.apache.spark.sql.catalyst.parser.ParseException
@@ -27,7 +28,10 @@ import org.apache.spark.sql.internal.SQLConf.BUCKETING_MAX_BUCKETS
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.Utils
-class CreateTableAsSelectSuite extends DataSourceTest with SharedSparkSession {
+// For some reason this suite is flaky w/ or w/o Comet when running in Github workflow.
+// Since it isn't related to Comet, we disable it for now.
+class CreateTableAsSelectSuite extends DataSourceTest with SharedSparkSession
+ with IgnoreCometSuite {
import testImplicits._
protected override lazy val sql = spark.sql _
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala
index d675503a8ba..f220892396e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.sources
import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec}
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.internal.SQLConf
@@ -68,7 +69,11 @@ abstract class DisableUnnecessaryBucketedScanSuite
def checkNumBucketedScan(query: String, expectedNumBucketedScan: Int): Unit = {
val plan = sql(query).queryExecution.executedPlan
- val bucketedScan = collect(plan) { case s: FileSourceScanExec if s.bucketedScan => s }
+ val bucketedScan = collect(plan) {
+ case s: FileSourceScanExec if s.bucketedScan => s
+ case s: CometScanExec if s.bucketedScan => s
+ case s: CometNativeScanExec if s.bucketedScan => s
+ }
assert(bucketedScan.length == expectedNumBucketedScan)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
index 7f6fa2a123e..c778b4e2c48 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
@@ -35,6 +35,7 @@ import org.apache.spark.paths.SparkPath
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.{AnalysisException, DataFrame}
import org.apache.spark.sql.catalyst.util.stringToFile
+import org.apache.spark.sql.comet.CometBatchScanExec
import org.apache.spark.sql.execution.DataSourceScanExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, FileScan, FileTable}
@@ -777,6 +778,8 @@ class FileStreamSinkV2Suite extends FileStreamSinkSuite {
val fileScan = df.queryExecution.executedPlan.collect {
case batch: BatchScanExec if batch.scan.isInstanceOf[FileScan] =>
batch.scan.asInstanceOf[FileScan]
+ case batch: CometBatchScanExec if batch.scan.isInstanceOf[FileScan] =>
+ batch.scan.asInstanceOf[FileScan]
}.headOption.getOrElse {
fail(s"No FileScan in query\n${df.queryExecution}")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index c97979a57a5..45a998db0e0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Range, RepartitionByExpressi
import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRelationV2}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.comet.CometLocalLimitExec
import org.apache.spark.sql.execution.{LocalLimitExec, SimpleMode, SparkPlan}
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.streaming._
@@ -1114,11 +1115,12 @@ class StreamSuite extends StreamTest {
val localLimits = execPlan.collect {
case l: LocalLimitExec => l
case l: StreamingLocalLimitExec => l
+ case l: CometLocalLimitExec => l
}
require(
localLimits.size == 1,
- s"Cant verify local limit optimization with this plan:\n$execPlan")
+ s"Cant verify local limit optimization ${localLimits.size} with this plan:\n$execPlan")
if (expectStreamingLimit) {
assert(
@@ -1126,7 +1128,8 @@ class StreamSuite extends StreamTest {
s"Local limit was not StreamingLocalLimitExec:\n$execPlan")
} else {
assert(
- localLimits.head.isInstanceOf[LocalLimitExec],
+ localLimits.head.isInstanceOf[LocalLimitExec] ||
+ localLimits.head.isInstanceOf[CometLocalLimitExec],
s"Local limit was not LocalLimitExec:\n$execPlan")
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala
index b4c4ec7acbf..20579284856 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala
@@ -23,6 +23,7 @@ import org.apache.commons.io.FileUtils
import org.scalatest.Assertions
import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
+import org.apache.spark.sql.comet.CometHashAggregateExec
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.streaming.{MemoryStream, StateStoreRestoreExec, StateStoreSaveExec}
import org.apache.spark.sql.functions.count
@@ -67,6 +68,7 @@ class StreamingAggregationDistributionSuite extends StreamTest
// verify aggregations in between, except partial aggregation
val allAggregateExecs = query.lastExecution.executedPlan.collect {
case a: BaseAggregateExec => a
+ case c: CometHashAggregateExec => c.originalPlan
}
val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter {
@@ -201,6 +203,7 @@ class StreamingAggregationDistributionSuite extends StreamTest
// verify aggregations in between, except partial aggregation
val allAggregateExecs = executedPlan.collect {
case a: BaseAggregateExec => a
+ case c: CometHashAggregateExec => c.originalPlan
}
val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
index aad91601758..201083bd621 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec, StreamingSymmetricHashJoinHelper}
import org.apache.spark.sql.execution.streaming.state.{RocksDBStateStoreProvider, StateStore, StateStoreProviderId}
import org.apache.spark.sql.functions._
@@ -619,14 +619,27 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite {
val numPartitions = spark.sqlContext.conf.getConf(SQLConf.SHUFFLE_PARTITIONS)
- assert(query.lastExecution.executedPlan.collect {
- case j @ StreamingSymmetricHashJoinExec(_, _, _, _, _, _, _, _, _,
- ShuffleExchangeExec(opA: HashPartitioning, _, _, _),
- ShuffleExchangeExec(opB: HashPartitioning, _, _, _))
- if partitionExpressionsColumns(opA.expressions) === Seq("a", "b")
- && partitionExpressionsColumns(opB.expressions) === Seq("a", "b")
- && opA.numPartitions == numPartitions && opB.numPartitions == numPartitions => j
- }.size == 1)
+ val join = query.lastExecution.executedPlan.collect {
+ case j: StreamingSymmetricHashJoinExec => j
+ }.head
+ val opA = join.left.collect {
+ case s: ShuffleExchangeLike
+ if s.outputPartitioning.isInstanceOf[HashPartitioning] &&
+ partitionExpressionsColumns(
+ s.outputPartitioning
+ .asInstanceOf[HashPartitioning].expressions) === Seq("a", "b") =>
+ s.outputPartitioning.asInstanceOf[HashPartitioning]
+ }.head
+ val opB = join.right.collect {
+ case s: ShuffleExchangeLike
+ if s.outputPartitioning.isInstanceOf[HashPartitioning] &&
+ partitionExpressionsColumns(
+ s.outputPartitioning
+ .asInstanceOf[HashPartitioning].expressions) === Seq("a", "b") =>
+ s.outputPartitioning
+ .asInstanceOf[HashPartitioning]
+ }.head
+ assert(opA.numPartitions == numPartitions && opB.numPartitions == numPartitions)
})
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index b5cf13a9c12..ac17603fb7f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -36,7 +36,7 @@ import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark.{SparkException, TestUtils}
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset, Row, SaveMode}
+import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset, IgnoreCometNativeDataFusion, Row, SaveMode}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, Rand, Randn, Shuffle, Uuid}
import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, CTERelationRef, LocalRelation}
@@ -660,7 +660,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
)
}
- test("SPARK-41198: input row calculation with CTE") {
+ test("SPARK-41198: input row calculation with CTE",
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3315")) {
withTable("parquet_tbl", "parquet_streaming_tbl") {
spark.range(0, 10).selectExpr("id AS col1", "id AS col2")
.write.format("parquet").saveAsTable("parquet_tbl")
@@ -712,7 +713,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
}
}
- test("SPARK-41199: input row calculation with mixed-up of DSv1 and DSv2 streaming sources") {
+ test("SPARK-41199: input row calculation with mixed-up of DSv1 and DSv2 streaming sources",
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3315")) {
withTable("parquet_streaming_tbl") {
val streamInput = MemoryStream[Int]
val streamDf = streamInput.toDF().selectExpr("value AS key", "value AS value_stream")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSelfUnionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSelfUnionSuite.scala
index 8f099c31e6b..ce4b7ad25b3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSelfUnionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSelfUnionSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.PatienceConfiguration.Timeout
-import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.{IgnoreCometNativeDataFusion, SaveMode}
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming.test.{InMemoryStreamTable, InMemoryStreamTableCatalog}
@@ -42,7 +42,8 @@ class StreamingSelfUnionSuite extends StreamTest with BeforeAndAfter {
sqlContext.streams.active.foreach(_.stop())
}
- test("self-union, DSv1, read via DataStreamReader API") {
+ test("self-union, DSv1, read via DataStreamReader API",
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3401")) {
withTempPath { dir =>
val dataLocation = dir.getAbsolutePath
spark.range(1, 4).write.format("parquet").save(dataLocation)
@@ -66,7 +67,8 @@ class StreamingSelfUnionSuite extends StreamTest with BeforeAndAfter {
}
}
- test("self-union, DSv1, read via table API") {
+ test("self-union, DSv1, read via table API",
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3401")) {
withTable("parquet_streaming_tbl") {
spark.sql("CREATE TABLE parquet_streaming_tbl (key integer) USING parquet")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
index abe606ad9c1..2d930b64cca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
@@ -22,7 +22,7 @@ import java.util
import org.scalatest.BeforeAndAfter
-import org.apache.spark.sql.{AnalysisException, Row, SaveMode}
+import org.apache.spark.sql.{AnalysisException, IgnoreComet, Row, SaveMode}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
@@ -327,7 +327,8 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter {
}
}
- test("explain with table on DSv1 data source") {
+ test("explain with table on DSv1 data source",
+ IgnoreComet("Comet explain output is different")) {
val tblSourceName = "tbl_src"
val tblTargetName = "tbl_target"
val tblSourceQualified = s"default.$tblSourceName"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index e937173a590..7d20538bc68 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -27,6 +27,7 @@ import scala.concurrent.duration._
import scala.language.implicitConversions
import scala.util.control.NonFatal
+import org.apache.comet.CometConf
import org.apache.hadoop.fs.Path
import org.scalactic.source.Position
import org.scalatest.{BeforeAndAfterAll, Suite, Tag}
@@ -41,6 +42,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.PlanTestBase
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.comet._
import org.apache.spark.sql.execution.FilterExec
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution
import org.apache.spark.sql.execution.datasources.DataSourceUtils
@@ -119,6 +121,34 @@ private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with
override protected def test(testName: String, testTags: Tag*)(testFun: => Any)
(implicit pos: Position): Unit = {
+ // Check Comet skip tags first, before DisableAdaptiveExecution handling
+ if (isCometEnabled && testTags.exists(_.isInstanceOf[IgnoreComet])) {
+ ignore(testName + " (disabled when Comet is on)", testTags: _*)(testFun)
+ return
+ }
+ if (isCometEnabled) {
+ val cometScanImpl = CometConf.COMET_NATIVE_SCAN_IMPL.get(conf)
+ val isNativeIcebergCompat = cometScanImpl == CometConf.SCAN_NATIVE_ICEBERG_COMPAT ||
+ cometScanImpl == CometConf.SCAN_AUTO
+ val isNativeDataFusion = cometScanImpl == CometConf.SCAN_NATIVE_DATAFUSION ||
+ cometScanImpl == CometConf.SCAN_AUTO
+ if (isNativeIcebergCompat &&
+ testTags.exists(_.isInstanceOf[IgnoreCometNativeIcebergCompat])) {
+ ignore(testName + " (disabled for NATIVE_ICEBERG_COMPAT)", testTags: _*)(testFun)
+ return
+ }
+ if (isNativeDataFusion &&
+ testTags.exists(_.isInstanceOf[IgnoreCometNativeDataFusion])) {
+ ignore(testName + " (disabled for NATIVE_DATAFUSION)", testTags: _*)(testFun)
+ return
+ }
+ if ((isNativeDataFusion || isNativeIcebergCompat) &&
+ testTags.exists(_.isInstanceOf[IgnoreCometNativeScan])) {
+ ignore(testName + " (disabled for NATIVE_DATAFUSION and NATIVE_ICEBERG_COMPAT)",
+ testTags: _*)(testFun)
+ return
+ }
+ }
if (testTags.exists(_.isInstanceOf[DisableAdaptiveExecution])) {
super.test(testName, testTags: _*) {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
@@ -242,6 +272,29 @@ private[sql] trait SQLTestUtilsBase
protected override def _sqlContext: SQLContext = self.spark.sqlContext
}
+ /**
+ * Whether Comet extension is enabled
+ */
+ protected def isCometEnabled: Boolean = SparkSession.isCometEnabled
+
+ /**
+ * Whether to enable ansi mode This is only effective when
+ * [[isCometEnabled]] returns true.
+ */
+ protected def enableCometAnsiMode: Boolean = {
+ val v = System.getenv("ENABLE_COMET_ANSI_MODE")
+ v != null && v.toBoolean
+ }
+
+ /**
+ * Whether Spark should only apply Comet scan optimization. This is only effective when
+ * [[isCometEnabled]] returns true.
+ */
+ protected def isCometScanOnly: Boolean = {
+ val v = System.getenv("ENABLE_COMET_SCAN_ONLY")
+ v != null && v.toBoolean
+ }
+
protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
SparkSession.setActiveSession(spark)
super.withSQLConf(pairs: _*)(f)
@@ -435,6 +488,8 @@ private[sql] trait SQLTestUtilsBase
val schema = df.schema
val withoutFilters = df.queryExecution.executedPlan.transform {
case FilterExec(_, child) => child
+ case CometFilterExec(_, _, _, _, child, _) => child
+ case CometProjectExec(_, _, _, _, CometFilterExec(_, _, _, _, child, _), _) => child
}
spark.internalCreateDataFrame(withoutFilters.execute(), schema)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
index ed2e309fa07..a5ea58146ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
@@ -74,6 +74,31 @@ trait SharedSparkSessionBase
// this rule may potentially block testing of other optimization rules such as
// ConstantPropagation etc.
.set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)
+ // Enable Comet if `ENABLE_COMET` environment variable is set
+ if (isCometEnabled) {
+ conf
+ .set("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions")
+ .set("spark.comet.enabled", "true")
+ .set("spark.comet.parquet.respectFilterPushdown", "true")
+
+ if (!isCometScanOnly) {
+ conf
+ .set("spark.comet.exec.enabled", "true")
+ .set("spark.shuffle.manager",
+ "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
+ .set("spark.comet.exec.shuffle.enabled", "true")
+ .set("spark.comet.memoryOverhead", "10g")
+ } else {
+ conf
+ .set("spark.comet.exec.enabled", "false")
+ .set("spark.comet.exec.shuffle.enabled", "false")
+ }
+
+ if (enableCometAnsiMode) {
+ conf
+ .set("spark.sql.ansi.enabled", "true")
+ }
+ }
conf.set(
StaticSQLConf.WAREHOUSE_PATH,
conf.get(StaticSQLConf.WAREHOUSE_PATH) + "/" + getClass.getCanonicalName)
diff --git a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala
index c63c748953f..7edca9c93a6 100644
--- a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala
@@ -45,7 +45,7 @@ class SqlResourceWithActualMetricsSuite
import testImplicits._
// Exclude nodes which may not have the metrics
- val excludedNodes = List("WholeStageCodegen", "Project", "SerializeFromObject")
+ val excludedNodes = List("WholeStageCodegen", "Project", "SerializeFromObject", "RowToColumnar")
implicit val formats = new DefaultFormats {
override def dateFormatter = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala
index 52abd248f3a..7a199931a08 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression}
+import org.apache.spark.sql.comet._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.hive.execution.HiveTableScanExec
@@ -35,6 +36,9 @@ abstract class DynamicPartitionPruningHiveScanSuiteBase
case s: FileSourceScanExec => s.partitionFilters.collect {
case d: DynamicPruningExpression => d.child
}
+ case s: CometScanExec => s.partitionFilters.collect {
+ case d: DynamicPruningExpression => d.child
+ }
case h: HiveTableScanExec => h.partitionPruningPred.collect {
case d: DynamicPruningExpression => d.child
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala
index de3b1ffccf0..2a76d127093 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala
@@ -23,14 +23,15 @@ import java.util.concurrent.{Executors, TimeUnit}
import org.scalatest.BeforeAndAfterEach
import org.apache.spark.metrics.source.HiveCatalogMetrics
-import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.{IgnoreCometSuite, QueryTest}
import org.apache.spark.sql.execution.datasources.FileStatusCache
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
class PartitionedTablePerfStatsSuite
- extends QueryTest with TestHiveSingleton with SQLTestUtils with BeforeAndAfterEach {
+ extends QueryTest with TestHiveSingleton with SQLTestUtils with BeforeAndAfterEach
+ with IgnoreCometSuite {
override def beforeEach(): Unit = {
super.beforeEach()
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index f3be79f9022..b4b1ea8dbc4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -34,7 +34,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
import org.apache.hadoop.io.{LongWritable, Writable}
import org.apache.spark.{SparkException, SparkFiles, TestUtils}
-import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.{AnalysisException, IgnoreCometNativeDataFusion, QueryTest, Row}
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.execution.WholeStageCodegenExec
@@ -448,7 +448,8 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
}
}
- test("SPARK-11522 select input_file_name from non-parquet table") {
+ test("SPARK-11522 select input_file_name from non-parquet table",
+ IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3312")) {
withTempDir { tempDir =>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 6160c3e5f6c..0956d7d9edc 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -24,6 +24,7 @@ import java.sql.{Date, Timestamp}
import java.util.{Locale, Set}
import com.google.common.io.Files
+import org.apache.comet.CometConf
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.{SparkException, TestUtils}
@@ -838,8 +839,13 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi
}
test("SPARK-2554 SumDistinct partial aggregation") {
- checkAnswer(sql("SELECT sum( distinct key) FROM src group by key order by key"),
- sql("SELECT distinct key FROM src order by key").collect().toSeq)
+ // Range partitioning uses random samples, so per-partition comparisons do not always yield
+ // the same results. Disable Comet native range partitioning.
+ withSQLConf(CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.key -> "false")
+ {
+ checkAnswer(sql("SELECT sum( distinct key) FROM src group by key order by key"),
+ sql("SELECT distinct key FROM src order by key").collect().toSeq)
+ }
}
test("SPARK-4963 DataFrame sample on mutable row return wrong result") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 1d646f40b3e..5babe505301 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -53,25 +53,54 @@ object TestHive
new SparkContext(
System.getProperty("spark.sql.test.master", "local[1]"),
"TestSQLContext",
- new SparkConf()
- .set("spark.sql.test", "")
- .set(SQLConf.CODEGEN_FALLBACK.key, "false")
- .set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString)
- .set(HiveUtils.HIVE_METASTORE_BARRIER_PREFIXES.key,
- "org.apache.spark.sql.hive.execution.PairSerDe")
- .set(WAREHOUSE_PATH.key, TestHiveContext.makeWarehouseDir().toURI.getPath)
- // SPARK-8910
- .set(UI_ENABLED, false)
- .set(config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true)
- // Hive changed the default of hive.metastore.disallow.incompatible.col.type.changes
- // from false to true. For details, see the JIRA HIVE-12320 and HIVE-17764.
- .set("spark.hadoop.hive.metastore.disallow.incompatible.col.type.changes", "false")
- // Disable ConvertToLocalRelation for better test coverage. Test cases built on
- // LocalRelation will exercise the optimization rules better by disabling it as
- // this rule may potentially block testing of other optimization rules such as
- // ConstantPropagation etc.
- .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)))
+ {
+ val conf = new SparkConf()
+ .set("spark.sql.test", "")
+ .set(SQLConf.CODEGEN_FALLBACK.key, "false")
+ .set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString)
+ .set(HiveUtils.HIVE_METASTORE_BARRIER_PREFIXES.key,
+ "org.apache.spark.sql.hive.execution.PairSerDe")
+ .set(WAREHOUSE_PATH.key, TestHiveContext.makeWarehouseDir().toURI.getPath)
+ // SPARK-8910
+ .set(UI_ENABLED, false)
+ .set(config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true)
+ // Hive changed the default of hive.metastore.disallow.incompatible.col.type.changes
+ // from false to true. For details, see the JIRA HIVE-12320 and HIVE-17764.
+ .set("spark.hadoop.hive.metastore.disallow.incompatible.col.type.changes", "false")
+ // Disable ConvertToLocalRelation for better test coverage. Test cases built on
+ // LocalRelation will exercise the optimization rules better by disabling it as
+ // this rule may potentially block testing of other optimization rules such as
+ // ConstantPropagation etc.
+ .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)
+
+ if (SparkSession.isCometEnabled) {
+ conf
+ .set("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions")
+ .set("spark.comet.enabled", "true")
+
+ val v = System.getenv("ENABLE_COMET_SCAN_ONLY")
+ if (v == null || !v.toBoolean) {
+ conf
+ .set("spark.comet.exec.enabled", "true")
+ .set("spark.shuffle.manager",
+ "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
+ .set("spark.comet.exec.shuffle.enabled", "true")
+ } else {
+ conf
+ .set("spark.comet.exec.enabled", "false")
+ .set("spark.comet.exec.shuffle.enabled", "false")
+ }
+
+ val a = System.getenv("ENABLE_COMET_ANSI_MODE")
+ if (a != null && a.toBoolean) {
+ conf
+ .set("spark.sql.ansi.enabled", "true")
+ }
+ }
+ conf
+ }
+ ))
case class TestHiveVersion(hiveClient: HiveClient)
extends TestHiveContext(TestHive.sparkContext, hiveClient)