| diff --git a/pom.xml b/pom.xml |
| index d3544881af1..9c16099090c 100644 |
| --- a/pom.xml |
| +++ b/pom.xml |
| @@ -148,6 +148,8 @@ |
| <chill.version>0.10.0</chill.version> |
| <ivy.version>2.5.1</ivy.version> |
| <oro.version>2.0.8</oro.version> |
| + <spark.version.short>3.4</spark.version.short> |
| + <comet.version>0.14.0-SNAPSHOT</comet.version> |
| <!-- |
| If you changes codahale.metrics.version, you also need to change |
| the link to metrics.dropwizard.io in docs/monitoring.md. |
| @@ -2784,6 +2786,25 @@ |
| <artifactId>arpack</artifactId> |
| <version>${netlib.ludovic.dev.version}</version> |
| </dependency> |
| + <dependency> |
| + <groupId>org.apache.datafusion</groupId> |
| + <artifactId>comet-spark-spark${spark.version.short}_${scala.binary.version}</artifactId> |
| + <version>${comet.version}</version> |
| + <exclusions> |
| + <exclusion> |
| + <groupId>org.apache.spark</groupId> |
| + <artifactId>spark-sql_${scala.binary.version}</artifactId> |
| + </exclusion> |
| + <exclusion> |
| + <groupId>org.apache.spark</groupId> |
| + <artifactId>spark-core_${scala.binary.version}</artifactId> |
| + </exclusion> |
| + <exclusion> |
| + <groupId>org.apache.spark</groupId> |
| + <artifactId>spark-catalyst_${scala.binary.version}</artifactId> |
| + </exclusion> |
| + </exclusions> |
| + </dependency> |
| </dependencies> |
| </dependencyManagement> |
| |
| diff --git a/sql/core/pom.xml b/sql/core/pom.xml |
| index b386d135da1..46449e3f3f1 100644 |
| --- a/sql/core/pom.xml |
| +++ b/sql/core/pom.xml |
| @@ -77,6 +77,10 @@ |
| <groupId>org.apache.spark</groupId> |
| <artifactId>spark-tags_${scala.binary.version}</artifactId> |
| </dependency> |
| + <dependency> |
| + <groupId>org.apache.datafusion</groupId> |
| + <artifactId>comet-spark-spark${spark.version.short}_${scala.binary.version}</artifactId> |
| + </dependency> |
| |
| <!-- |
| This spark-tags test-dep is needed even though it isn't used in this module, otherwise testing-cmds that exclude |
| diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala |
| index c595b50950b..3abb6cb9441 100644 |
| --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala |
| +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala |
| @@ -102,7 +102,7 @@ class SparkSession private( |
| sc: SparkContext, |
| initialSessionOptions: java.util.HashMap[String, String]) = { |
| this(sc, None, None, |
| - SparkSession.applyExtensions( |
| + SparkSession.applyExtensions(sc, |
| sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS).getOrElse(Seq.empty), |
| new SparkSessionExtensions), initialSessionOptions.asScala.toMap) |
| } |
| @@ -1028,7 +1028,7 @@ object SparkSession extends Logging { |
| } |
| |
| loadExtensions(extensions) |
| - applyExtensions( |
| + applyExtensions(sparkContext, |
| sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS).getOrElse(Seq.empty), |
| extensions) |
| |
| @@ -1282,14 +1282,24 @@ object SparkSession extends Logging { |
| } |
| } |
| |
| + private def loadCometExtension(sparkContext: SparkContext): Seq[String] = { |
| + if (sparkContext.getConf.getBoolean("spark.comet.enabled", isCometEnabled)) { |
| + Seq("org.apache.comet.CometSparkSessionExtensions") |
| + } else { |
| + Seq.empty |
| + } |
| + } |
| + |
| /** |
| * Initialize extensions for given extension classnames. The classes will be applied to the |
| * extensions passed into this function. |
| */ |
| private def applyExtensions( |
| + sparkContext: SparkContext, |
| extensionConfClassNames: Seq[String], |
| extensions: SparkSessionExtensions): SparkSessionExtensions = { |
| - extensionConfClassNames.foreach { extensionConfClassName => |
| + val extensionClassNames = extensionConfClassNames ++ loadCometExtension(sparkContext) |
| + extensionClassNames.foreach { extensionConfClassName => |
| try { |
| val extensionConfClass = Utils.classForName(extensionConfClassName) |
| val extensionConf = extensionConfClass.getConstructor().newInstance() |
| @@ -1323,4 +1333,12 @@ object SparkSession extends Logging { |
| } |
| } |
| } |
| + |
| + /** |
| + * Whether the Comet extension is enabled (true unless ENABLE_COMET is set to false) |
| + */ |
| + def isCometEnabled: Boolean = { |
| + val v = System.getenv("ENABLE_COMET") |
| + v == null || v.toBoolean |
| + } |
| } |
| diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala |
| index db587dd9868..aac7295a53d 100644 |
| --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala |
| +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala |
| @@ -18,6 +18,7 @@ |
| package org.apache.spark.sql.execution |
| |
| import org.apache.spark.annotation.DeveloperApi |
| +import org.apache.spark.sql.comet.CometScanExec |
| import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec} |
| import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec |
| import org.apache.spark.sql.execution.exchange.ReusedExchangeExec |
| @@ -67,6 +68,7 @@ private[execution] object SparkPlanInfo { |
| // dump the file scan metadata (e.g file path) to event log |
| val metadata = plan match { |
| case fileScan: FileSourceScanExec => fileScan.metadata |
| + case cometScan: CometScanExec => cometScan.metadata |
| case _ => Map[String, String]() |
| } |
| new SparkPlanInfo( |
| diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql b/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql |
| index 7aef901da4f..f3d6e18926d 100644 |
| --- a/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql |
| +++ b/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql |
| @@ -2,3 +2,4 @@ |
| |
| --SET spark.sql.adaptive.enabled=true |
| --SET spark.sql.maxMetadataStringLength = 500 |
| +--SET spark.comet.enabled = false |
| diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain-cbo.sql b/sql/core/src/test/resources/sql-tests/inputs/explain-cbo.sql |
| index eeb2180f7a5..afd1b5ec289 100644 |
| --- a/sql/core/src/test/resources/sql-tests/inputs/explain-cbo.sql |
| +++ b/sql/core/src/test/resources/sql-tests/inputs/explain-cbo.sql |
| @@ -1,5 +1,6 @@ |
| --SET spark.sql.cbo.enabled=true |
| --SET spark.sql.maxMetadataStringLength = 500 |
| +--SET spark.comet.enabled = false |
| |
| CREATE TABLE explain_temp1(a INT, b INT) USING PARQUET; |
| CREATE TABLE explain_temp2(c INT, d INT) USING PARQUET; |
| diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain.sql b/sql/core/src/test/resources/sql-tests/inputs/explain.sql |
| index 698ca009b4f..57d774a3617 100644 |
| --- a/sql/core/src/test/resources/sql-tests/inputs/explain.sql |
| +++ b/sql/core/src/test/resources/sql-tests/inputs/explain.sql |
| @@ -1,6 +1,7 @@ |
| --SET spark.sql.codegen.wholeStage = true |
| --SET spark.sql.adaptive.enabled = false |
| --SET spark.sql.maxMetadataStringLength = 500 |
| +--SET spark.comet.enabled = false |
| |
| -- Test tables |
| CREATE table explain_temp1 (key int, val int) USING PARQUET; |
| diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql |
| index 1152d77da0c..f77493f690b 100644 |
| --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql |
| +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql |
| @@ -7,6 +7,9 @@ |
| |
| -- avoid bit-exact output here because operations may not be bit-exact. |
| -- SET extra_float_digits = 0; |
| +-- Disable Comet exec due to floating point precision difference |
| +--SET spark.comet.exec.enabled = false |
| + |
| |
| -- Test aggregate operator with codegen on and off. |
| --CONFIG_DIM1 spark.sql.codegen.wholeStage=true |
| diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql |
| index 41fd4de2a09..44cd244d3b0 100644 |
| --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql |
| +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql |
| @@ -5,6 +5,9 @@ |
| -- AGGREGATES [Part 3] |
| -- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/aggregates.sql#L352-L605 |
| |
| +-- Disable Comet exec due to floating point precision difference |
| +--SET spark.comet.exec.enabled = false |
| + |
| -- Test aggregate operator with codegen on and off. |
| --CONFIG_DIM1 spark.sql.codegen.wholeStage=true |
| --CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY |
| diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql |
| index 3a409eea348..38fed024c98 100644 |
| --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql |
| +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql |
| @@ -69,6 +69,8 @@ SELECT '' AS one, i.* FROM INT4_TBL i WHERE (i.f1 % smallint('2')) = smallint('1 |
| -- any evens |
| SELECT '' AS three, i.* FROM INT4_TBL i WHERE (i.f1 % int('2')) = smallint('0'); |
| |
| +-- https://github.com/apache/datafusion-comet/issues/2215 |
| +--SET spark.comet.exec.enabled=false |
| -- [SPARK-28024] Incorrect value when out of range |
| SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i; |
| |
| diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql |
| index fac23b4a26f..2b73732c33f 100644 |
| --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql |
| +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql |
| @@ -1,6 +1,10 @@ |
| -- |
| -- Portions Copyright (c) 1996-2019, PostgreSQL Global Development Group |
| -- |
| + |
| +-- Disable Comet exec due to floating point precision difference |
| +--SET spark.comet.exec.enabled = false |
| + |
| -- |
| -- INT8 |
| -- Test int8 64-bit integers. |
| diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/select_having.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/select_having.sql |
| index 0efe0877e9b..423d3b3d76d 100644 |
| --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/select_having.sql |
| +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/select_having.sql |
| @@ -1,6 +1,10 @@ |
| -- |
| -- Portions Copyright (c) 1996-2019, PostgreSQL Global Development Group |
| -- |
| + |
| +-- Disable Comet exec due to floating point precision difference |
| +--SET spark.comet.exec.enabled = false |
| + |
| -- |
| -- SELECT_HAVING |
| -- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/select_having.sql |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala |
| index cf40e944c09..bdd5be4f462 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala |
| @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants |
| import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, RDDScanExec, SparkPlan} |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper |
| import org.apache.spark.sql.execution.columnar._ |
| -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec |
| +import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike |
| import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate |
| import org.apache.spark.sql.functions._ |
| import org.apache.spark.sql.internal.SQLConf |
| @@ -516,7 +516,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils |
| */ |
| private def verifyNumExchanges(df: DataFrame, expected: Int): Unit = { |
| assert( |
| - collect(df.queryExecution.executedPlan) { case e: ShuffleExchangeExec => e }.size == expected) |
| + collect(df.queryExecution.executedPlan) { |
| + case _: ShuffleExchangeLike => 1 }.size == expected) |
| } |
| |
| test("A cached table preserves the partitioning and ordering of its cached SparkPlan") { |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala |
| index 1cc09c3d7fc..f031fa45c33 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala |
| @@ -27,7 +27,7 @@ import org.apache.spark.SparkException |
| import org.apache.spark.sql.execution.WholeStageCodegenExec |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper |
| import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} |
| -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec |
| +import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike |
| import org.apache.spark.sql.expressions.Window |
| import org.apache.spark.sql.functions._ |
| import org.apache.spark.sql.internal.SQLConf |
| @@ -755,7 +755,7 @@ class DataFrameAggregateSuite extends QueryTest |
| assert(objHashAggPlans.nonEmpty) |
| |
| val exchangePlans = collect(aggPlan) { |
| - case shuffle: ShuffleExchangeExec => shuffle |
| + case shuffle: ShuffleExchangeLike => shuffle |
| } |
| assert(exchangePlans.length == 1) |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala |
| index 56e9520fdab..917932336df 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala |
| @@ -435,7 +435,9 @@ class DataFrameJoinSuite extends QueryTest |
| |
| withTempDatabase { dbName => |
| withTable(table1Name, table2Name) { |
| - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { |
| + withSQLConf( |
| + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", |
| + "spark.comet.enabled" -> "false") { |
| spark.range(50).write.saveAsTable(s"$dbName.$table1Name") |
| spark.range(100).write.saveAsTable(s"$dbName.$table2Name") |
| |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala |
| index a9f69ab28a1..760ea0e9565 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala |
| @@ -39,11 +39,12 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Attri |
| import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation |
| import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LocalRelation, LogicalPlan, OneRowRelation, Statistics} |
| import org.apache.spark.sql.catalyst.util.DateTimeUtils |
| +import org.apache.spark.sql.comet.CometBroadcastExchangeExec |
| import org.apache.spark.sql.connector.FakeV2Provider |
| import org.apache.spark.sql.execution.{FilterExec, LogicalRDD, QueryExecution, SortExec, WholeStageCodegenExec} |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper |
| import org.apache.spark.sql.execution.aggregate.HashAggregateExec |
| -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike} |
| +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeLike} |
| import org.apache.spark.sql.expressions.{Aggregator, Window} |
| import org.apache.spark.sql.functions._ |
| import org.apache.spark.sql.internal.SQLConf |
| @@ -1981,7 +1982,7 @@ class DataFrameSuite extends QueryTest |
| fail("Should not have back to back Aggregates") |
| } |
| atFirstAgg = true |
| - case e: ShuffleExchangeExec => atFirstAgg = false |
| + case e: ShuffleExchangeLike => atFirstAgg = false |
| case _ => |
| } |
| } |
| @@ -2305,7 +2306,7 @@ class DataFrameSuite extends QueryTest |
| checkAnswer(join, df) |
| assert( |
| collect(join.queryExecution.executedPlan) { |
| - case e: ShuffleExchangeExec => true }.size === 1) |
| + case _: ShuffleExchangeLike => true }.size === 1) |
| assert( |
| collect(join.queryExecution.executedPlan) { case e: ReusedExchangeExec => true }.size === 1) |
| val broadcasted = broadcast(join) |
| @@ -2313,10 +2314,12 @@ class DataFrameSuite extends QueryTest |
| checkAnswer(join2, df) |
| assert( |
| collect(join2.queryExecution.executedPlan) { |
| - case e: ShuffleExchangeExec => true }.size == 1) |
| + case _: ShuffleExchangeLike => true }.size == 1) |
| assert( |
| collect(join2.queryExecution.executedPlan) { |
| - case e: BroadcastExchangeExec => true }.size === 1) |
| + case e: BroadcastExchangeExec => true |
| + case _: CometBroadcastExchangeExec => true |
| + }.size === 1) |
| assert( |
| collect(join2.queryExecution.executedPlan) { case e: ReusedExchangeExec => true }.size == 4) |
| } |
| @@ -2876,7 +2879,7 @@ class DataFrameSuite extends QueryTest |
| |
| // Assert that no extra shuffle introduced by cogroup. |
| val exchanges = collect(df3.queryExecution.executedPlan) { |
| - case h: ShuffleExchangeExec => h |
| + case h: ShuffleExchangeLike => h |
| } |
| assert(exchanges.size == 2) |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala |
| index 433b4741979..07148eee480 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala |
| @@ -23,8 +23,9 @@ import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} |
| import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Lag, Literal, NonFoldableLiteral} |
| import org.apache.spark.sql.catalyst.optimizer.TransposeWindow |
| import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning |
| +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper |
| -import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, Exchange, ShuffleExchangeExec} |
| +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, Exchange, ShuffleExchangeExec, ShuffleExchangeLike} |
| import org.apache.spark.sql.execution.window.WindowExec |
| import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, Window} |
| import org.apache.spark.sql.functions._ |
| @@ -1186,10 +1187,12 @@ class DataFrameWindowFunctionsSuite extends QueryTest |
| } |
| |
| def isShuffleExecByRequirement( |
| - plan: ShuffleExchangeExec, |
| + plan: ShuffleExchangeLike, |
| desiredClusterColumns: Seq[String]): Boolean = plan match { |
| case ShuffleExchangeExec(op: HashPartitioning, _, ENSURE_REQUIREMENTS) => |
| partitionExpressionsColumns(op.expressions) === desiredClusterColumns |
| + case CometShuffleExchangeExec(op: HashPartitioning, _, _, ENSURE_REQUIREMENTS, _, _) => |
| + partitionExpressionsColumns(op.expressions) === desiredClusterColumns |
| case _ => false |
| } |
| |
| @@ -1212,7 +1215,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest |
| val shuffleByRequirement = windowed.queryExecution.executedPlan.exists { |
| case w: WindowExec => |
| w.child.exists { |
| - case s: ShuffleExchangeExec => isShuffleExecByRequirement(s, Seq("key1", "key2")) |
| + case s: ShuffleExchangeLike => isShuffleExecByRequirement(s, Seq("key1", "key2")) |
| case _ => false |
| } |
| case _ => false |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala |
| index daef11ae4d6..9f3cc9181f2 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala |
| @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} |
| import org.apache.spark.sql.catalyst.util.sideBySide |
| import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution} |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper |
| -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} |
| +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike} |
| import org.apache.spark.sql.execution.streaming.MemoryStream |
| import org.apache.spark.sql.expressions.UserDefinedFunction |
| import org.apache.spark.sql.functions._ |
| @@ -2254,7 +2254,7 @@ class DatasetSuite extends QueryTest |
| |
| // Assert that no extra shuffle introduced by cogroup. |
| val exchanges = collect(df3.queryExecution.executedPlan) { |
| - case h: ShuffleExchangeExec => h |
| + case h: ShuffleExchangeLike => h |
| } |
| assert(exchanges.size == 2) |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala |
| index f33432ddb6f..1925aac8d97 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala |
| @@ -22,6 +22,7 @@ import org.scalatest.GivenWhenThen |
| import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression} |
| import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._ |
| import org.apache.spark.sql.catalyst.plans.ExistenceJoin |
| +import org.apache.spark.sql.comet.CometScanExec |
| import org.apache.spark.sql.connector.catalog.{InMemoryTableCatalog, InMemoryTableWithV2FilterCatalog} |
| import org.apache.spark.sql.execution._ |
| import org.apache.spark.sql.execution.adaptive._ |
| @@ -262,6 +263,9 @@ abstract class DynamicPartitionPruningSuiteBase |
| case s: BatchScanExec => s.runtimeFilters.collect { |
| case d: DynamicPruningExpression => d.child |
| } |
| + case s: CometScanExec => s.partitionFilters.collect { |
| + case d: DynamicPruningExpression => d.child |
| + } |
| case _ => Nil |
| } |
| } |
| @@ -755,7 +759,8 @@ abstract class DynamicPartitionPruningSuiteBase |
| } |
| } |
| |
| - test("partition pruning in broadcast hash joins") { |
| + test("partition pruning in broadcast hash joins", |
| + IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #1737")) { |
| Given("disable broadcast pruning and disable subquery duplication") |
| withSQLConf( |
| SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true", |
| @@ -1027,7 +1032,8 @@ abstract class DynamicPartitionPruningSuiteBase |
| } |
| } |
| |
| - test("avoid reordering broadcast join keys to match input hash partitioning") { |
| + test("avoid reordering broadcast join keys to match input hash partitioning", |
| + IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/1839")) { |
| withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { |
| withTable("large", "dimTwo", "dimThree") { |
| @@ -1215,7 +1221,8 @@ abstract class DynamicPartitionPruningSuiteBase |
| } |
| |
| test("SPARK-32509: Unused Dynamic Pruning filter shouldn't affect " + |
| - "canonicalization and exchange reuse") { |
| + "canonicalization and exchange reuse", |
| + IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/1839")) { |
| withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { |
| withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { |
| val df = sql( |
| @@ -1423,7 +1430,8 @@ abstract class DynamicPartitionPruningSuiteBase |
| } |
| } |
| |
| - test("SPARK-34637: DPP side broadcast query stage is created firstly") { |
| + test("SPARK-34637: DPP side broadcast query stage is created firstly", |
| + IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/1839")) { |
| withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { |
| val df = sql( |
| """ WITH v as ( |
| @@ -1729,6 +1737,8 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat |
| case s: BatchScanExec => |
| // we use f1 col for v2 tables due to schema pruning |
| s.output.exists(_.exists(_.argString(maxFields = 100).contains("f1"))) |
| + case s: CometScanExec => |
| + s.output.exists(_.exists(_.argString(maxFields = 100).contains("fid"))) |
| case _ => false |
| } |
| assert(scanOption.isDefined) |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala |
| index a6b295578d6..91acca4306f 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala |
| @@ -463,7 +463,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite |
| } |
| } |
| |
| - test("Explain formatted output for scan operator for datasource V2") { |
| + test("Explain formatted output for scan operator for datasource V2", |
| + IgnoreComet("Comet explain output is different")) { |
| withTempDir { dir => |
| Seq("parquet", "orc", "csv", "json").foreach { fmt => |
| val basePath = dir.getCanonicalPath + "/" + fmt |
| @@ -541,7 +542,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite |
| } |
| } |
| |
| -class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuite { |
| +// Ignored when Comet is enabled. Comet changes expected query plans. |
| +class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuite |
| + with IgnoreCometSuite { |
| import testImplicits._ |
| |
| test("SPARK-35884: Explain Formatted") { |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala |
| index 2796b1cf154..52438178a0e 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala |
| @@ -33,6 +33,7 @@ import org.apache.spark.sql.TestingUDT.{IntervalUDT, NullData, NullUDT} |
| import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, Literal} |
| import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt} |
| import org.apache.spark.sql.catalyst.plans.logical.Filter |
| +import org.apache.spark.sql.comet.{CometBatchScanExec, CometNativeScanExec, CometScanExec, CometSortMergeJoinExec} |
| import org.apache.spark.sql.execution.{FileSourceScanLike, SimpleMode} |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper |
| import org.apache.spark.sql.execution.datasources.FilePartition |
| @@ -815,6 +816,7 @@ class FileBasedDataSourceSuite extends QueryTest |
| assert(bJoinExec.isEmpty) |
| val smJoinExec = collect(joinedDF.queryExecution.executedPlan) { |
| case smJoin: SortMergeJoinExec => smJoin |
| + case smJoin: CometSortMergeJoinExec => smJoin |
| } |
| assert(smJoinExec.nonEmpty) |
| } |
| @@ -875,6 +877,7 @@ class FileBasedDataSourceSuite extends QueryTest |
| |
| val fileScan = df.queryExecution.executedPlan collectFirst { |
| case BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _) => f |
| + case CometBatchScanExec(BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _), _, _) => f |
| } |
| assert(fileScan.nonEmpty) |
| assert(fileScan.get.partitionFilters.nonEmpty) |
| @@ -916,6 +919,7 @@ class FileBasedDataSourceSuite extends QueryTest |
| |
| val fileScan = df.queryExecution.executedPlan collectFirst { |
| case BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _) => f |
| + case CometBatchScanExec(BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _), _, _) => f |
| } |
| assert(fileScan.nonEmpty) |
| assert(fileScan.get.partitionFilters.isEmpty) |
| @@ -1100,6 +1104,9 @@ class FileBasedDataSourceSuite extends QueryTest |
| val filters = df.queryExecution.executedPlan.collect { |
| case f: FileSourceScanLike => f.dataFilters |
| case b: BatchScanExec => b.scan.asInstanceOf[FileScan].dataFilters |
| + case b: CometScanExec => b.dataFilters |
| + case b: CometNativeScanExec => b.dataFilters |
| + case b: CometBatchScanExec => b.scan.asInstanceOf[FileScan].dataFilters |
| }.flatten |
| assert(filters.contains(GreaterThan(scan.logicalPlan.output.head, Literal(5L)))) |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IgnoreComet.scala b/sql/core/src/test/scala/org/apache/spark/sql/IgnoreComet.scala |
| new file mode 100644 |
| index 00000000000..5691536c114 |
| --- /dev/null |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/IgnoreComet.scala |
| @@ -0,0 +1,45 @@ |
| +/* |
| + * Licensed to the Apache Software Foundation (ASF) under one or more |
| + * contributor license agreements. See the NOTICE file distributed with |
| + * this work for additional information regarding copyright ownership. |
| + * The ASF licenses this file to You under the Apache License, Version 2.0 |
| + * (the "License"); you may not use this file except in compliance with |
| + * the License. You may obtain a copy of the License at |
| + * |
| + * http://www.apache.org/licenses/LICENSE-2.0 |
| + * |
| + * Unless required by applicable law or agreed to in writing, software |
| + * distributed under the License is distributed on an "AS IS" BASIS, |
| + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| + * See the License for the specific language governing permissions and |
| + * limitations under the License. |
| + */ |
| + |
| +package org.apache.spark.sql |
| + |
| +import org.scalactic.source.Position |
| +import org.scalatest.Tag |
| + |
| +import org.apache.spark.sql.test.SQLTestUtils |
| + |
| +/** |
| + * Tests with this tag will be ignored when Comet is enabled (e.g., via `ENABLE_COMET`). |
| + */ |
| +case class IgnoreComet(reason: String) extends Tag("DisableComet") |
| +case class IgnoreCometNativeIcebergCompat(reason: String) extends Tag("DisableComet") |
| +case class IgnoreCometNativeDataFusion(reason: String) extends Tag("DisableComet") |
| +case class IgnoreCometNativeScan(reason: String) extends Tag("DisableComet") |
| + |
| +/** |
| + * Helper trait that disables Comet for all tests regardless of default config values. |
| + */ |
| +trait IgnoreCometSuite extends SQLTestUtils { |
| + override protected def test(testName: String, testTags: Tag*)(testFun: => Any) |
| + (implicit pos: Position): Unit = { |
| + if (isCometEnabled) { |
| + ignore(testName + " (disabled when Comet is on)", testTags: _*)(testFun) |
| + } else { |
| + super.test(testName, testTags: _*)(testFun) |
| + } |
| + } |
| +} |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala |
| index fda442eeef0..1b69e4f280e 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala |
| @@ -468,7 +468,8 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp |
| } |
| |
| test("Runtime bloom filter join: do not add bloom filter if dpp filter exists " + |
| - "on the same column") { |
| + "on the same column", |
| + IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) { |
| withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { |
| assertDidNotRewriteWithBloomFilter("select * from bf5part join bf2 on " + |
| @@ -477,7 +478,8 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp |
| } |
| |
| test("Runtime bloom filter join: add bloom filter if dpp filter exists on " + |
| - "a different column") { |
| + "a different column", |
| + IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) { |
| withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { |
| assertRewroteWithBloomFilter("select * from bf5part join bf2 on " + |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala |
| index 1792b4c32eb..1616e6f39bd 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala |
| @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide |
| import org.apache.spark.sql.catalyst.plans.PlanTest |
| import org.apache.spark.sql.catalyst.plans.logical._ |
| import org.apache.spark.sql.catalyst.rules.RuleExecutor |
| +import org.apache.spark.sql.comet.{CometHashJoinExec, CometSortMergeJoinExec} |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper |
| import org.apache.spark.sql.execution.joins._ |
| import org.apache.spark.sql.internal.SQLConf |
| @@ -362,6 +363,7 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP |
| val executedPlan = df.queryExecution.executedPlan |
| val shuffleHashJoins = collect(executedPlan) { |
| case s: ShuffledHashJoinExec => s |
| + case c: CometHashJoinExec => c.originalPlan.asInstanceOf[ShuffledHashJoinExec] |
| } |
| assert(shuffleHashJoins.size == 1) |
| assert(shuffleHashJoins.head.buildSide == buildSide) |
| @@ -371,6 +373,7 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP |
| val executedPlan = df.queryExecution.executedPlan |
| val shuffleMergeJoins = collect(executedPlan) { |
| case s: SortMergeJoinExec => s |
| + case c: CometSortMergeJoinExec => c |
| } |
| assert(shuffleMergeJoins.size == 1) |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala |
| index 7f062bfb899..0ed85486e80 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala |
| @@ -30,7 +30,8 @@ import org.apache.spark.sql.catalyst.TableIdentifier |
| import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation |
| import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrder} |
| import org.apache.spark.sql.catalyst.plans.logical.Filter |
| -import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec} |
| +import org.apache.spark.sql.comet._ |
| +import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, FilterExec, InputAdapter, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec} |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper |
| import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} |
| import org.apache.spark.sql.execution.joins._ |
| @@ -740,7 +741,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| } |
| } |
| |
| - test("test SortMergeJoin (with spill)") { |
| + test("test SortMergeJoin (with spill)", |
| + IgnoreComet("TODO: Comet SMJ doesn't support spill yet")) { |
| withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1", |
| SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "0", |
| SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD.key -> "1") { |
| @@ -866,10 +868,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| val physical = df.queryExecution.sparkPlan |
| val physicalJoins = physical.collect { |
| case j: SortMergeJoinExec => j |
| + case j: CometSortMergeJoinExec => j.originalPlan.asInstanceOf[SortMergeJoinExec] |
| } |
| val executed = df.queryExecution.executedPlan |
| val executedJoins = collect(executed) { |
| case j: SortMergeJoinExec => j |
| + case j: CometSortMergeJoinExec => j.originalPlan.asInstanceOf[SortMergeJoinExec] |
| } |
| // This only applies to the above tested queries, in which a child SortMergeJoin always |
| // contains the SortOrder required by its parent SortMergeJoin. Thus, SortExec should never |
| @@ -1115,9 +1119,11 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| val plan = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", joinType) |
| .groupBy($"k1").count() |
| .queryExecution.executedPlan |
| - assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size === 1) |
| + assert(collect(plan) { |
| + case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size === 1) |
| // No extra shuffle before aggregate |
| - assert(collect(plan) { case _: ShuffleExchangeExec => true }.size === 2) |
| + assert(collect(plan) { |
| + case _: ShuffleExchangeLike => true }.size === 2) |
| }) |
| } |
| |
| @@ -1134,10 +1140,11 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| .join(df4.hint("SHUFFLE_MERGE"), $"k1" === $"k4", joinType) |
| .queryExecution |
| .executedPlan |
| - assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 2) |
| + assert(collect(plan) { |
| + case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size === 2) |
| assert(collect(plan) { case _: BroadcastHashJoinExec => true }.size === 1) |
| // No extra sort before last sort merge join |
| - assert(collect(plan) { case _: SortExec => true }.size === 3) |
| + assert(collect(plan) { case _: SortExec | _: CometSortExec => true }.size === 3) |
| }) |
| |
| // Test shuffled hash join |
| @@ -1147,10 +1154,13 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| .join(df4.hint("SHUFFLE_MERGE"), $"k1" === $"k4", joinType) |
| .queryExecution |
| .executedPlan |
| - assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 2) |
| - assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size === 1) |
| + assert(collect(plan) { |
| + case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size === 2) |
| + assert(collect(plan) { |
| + case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size === 1) |
| // No extra sort before last sort merge join |
| - assert(collect(plan) { case _: SortExec => true }.size === 3) |
| + assert(collect(plan) { |
| + case _: SortExec | _: CometSortExec => true }.size === 3) |
| }) |
| } |
| |
| @@ -1241,12 +1251,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| inputDFs.foreach { case (df1, df2, joinExprs) => |
| val smjDF = df1.join(df2.hint("SHUFFLE_MERGE"), joinExprs, "full") |
| assert(collect(smjDF.queryExecution.executedPlan) { |
| - case _: SortMergeJoinExec => true }.size === 1) |
| + case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size === 1) |
| val smjResult = smjDF.collect() |
| |
| val shjDF = df1.join(df2.hint("SHUFFLE_HASH"), joinExprs, "full") |
| assert(collect(shjDF.queryExecution.executedPlan) { |
| - case _: ShuffledHashJoinExec => true }.size === 1) |
| + case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size === 1) |
| // Same result between shuffled hash join and sort merge join |
| checkAnswer(shjDF, smjResult) |
| } |
| @@ -1282,18 +1292,26 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| } |
| |
| // Test shuffled hash join |
| - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { |
| + withSQLConf("spark.comet.enabled" -> "true", |
| + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { |
| val shjCodegenDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", joinType) |
| assert(shjCodegenDF.queryExecution.executedPlan.collect { |
| case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true |
| case WholeStageCodegenExec(ProjectExec(_, _ : ShuffledHashJoinExec)) => true |
| + case WholeStageCodegenExec(ColumnarToRowExec(InputAdapter(_: CometHashJoinExec))) => |
| + true |
| + case WholeStageCodegenExec(ColumnarToRowExec( |
| + InputAdapter(CometProjectExec(_, _, _, _, _: CometHashJoinExec, _)))) => true |
| + case _: CometHashJoinExec => true |
| }.size === 1) |
| checkAnswer(shjCodegenDF, Seq.empty) |
| |
| withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { |
| val shjNonCodegenDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", joinType) |
| assert(shjNonCodegenDF.queryExecution.executedPlan.collect { |
| - case _: ShuffledHashJoinExec => true }.size === 1) |
| + case _: ShuffledHashJoinExec => true |
| + case _: CometHashJoinExec => true |
| + }.size === 1) |
| checkAnswer(shjNonCodegenDF, Seq.empty) |
| } |
| } |
| @@ -1341,7 +1359,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| val plan = sql(getAggQuery(selectExpr, joinType)).queryExecution.executedPlan |
| assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1) |
| // Have shuffle before aggregation |
| - assert(collect(plan) { case _: ShuffleExchangeExec => true }.size === 1) |
| + assert(collect(plan) { |
| + case _: ShuffleExchangeLike => true }.size === 1) |
| } |
| |
| def getJoinQuery(selectExpr: String, joinType: String): String = { |
| @@ -1370,9 +1389,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| } |
| val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan |
| assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1) |
| - assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 3) |
| + assert(collect(plan) { |
| + case _: SortMergeJoinExec => true |
| + case _: CometSortMergeJoinExec => true |
| + }.size === 3) |
| // No extra sort on left side before last sort merge join |
| - assert(collect(plan) { case _: SortExec => true }.size === 5) |
| + assert(collect(plan) { case _: SortExec | _: CometSortExec => true }.size === 5) |
| } |
| |
| // Test output ordering is not preserved |
| @@ -1381,9 +1403,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| val selectExpr = "/*+ BROADCAST(left_t) */ k1 as k0" |
| val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan |
| assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1) |
| - assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 3) |
| + assert(collect(plan) { |
| + case _: SortMergeJoinExec => true |
| + case _: CometSortMergeJoinExec => true |
| + }.size === 3) |
| // Have sort on left side before last sort merge join |
| - assert(collect(plan) { case _: SortExec => true }.size === 6) |
| + assert(collect(plan) { case _: SortExec | _: CometSortExec => true }.size === 6) |
| } |
| |
| // Test singe partition |
| @@ -1393,7 +1418,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| |FROM range(0, 10, 1, 1) t1 FULL OUTER JOIN range(0, 10, 1, 1) t2 |
| |""".stripMargin) |
| val plan = fullJoinDF.queryExecution.executedPlan |
| - assert(collect(plan) { case _: ShuffleExchangeExec => true}.size == 1) |
| + assert(collect(plan) { |
| + case _: ShuffleExchangeLike => true}.size == 1) |
| checkAnswer(fullJoinDF, Row(100)) |
| } |
| } |
| @@ -1438,6 +1464,9 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| Seq(semiJoinDF, antiJoinDF).foreach { df => |
| assert(collect(df.queryExecution.executedPlan) { |
| case j: ShuffledHashJoinExec if j.ignoreDuplicatedKey == ignoreDuplicatedKey => true |
| + case j: CometHashJoinExec |
| + if j.originalPlan.asInstanceOf[ShuffledHashJoinExec].ignoreDuplicatedKey == |
| + ignoreDuplicatedKey => true |
| }.size == 1) |
| } |
| } |
| @@ -1482,14 +1511,20 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| |
| test("SPARK-43113: Full outer join with duplicate stream-side references in condition (SMJ)") { |
| def check(plan: SparkPlan): Unit = { |
| - assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 1) |
| + assert(collect(plan) { |
| + case _: SortMergeJoinExec => true |
| + case _: CometSortMergeJoinExec => true |
| + }.size === 1) |
| } |
| dupStreamSideColTest("MERGE", check) |
| } |
| |
| test("SPARK-43113: Full outer join with duplicate stream-side references in condition (SHJ)") { |
| def check(plan: SparkPlan): Unit = { |
| - assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size === 1) |
| + assert(collect(plan) { |
| + case _: ShuffledHashJoinExec => true |
| + case _: CometHashJoinExec => true |
| + }.size === 1) |
| } |
| dupStreamSideColTest("SHUFFLE_HASH", check) |
| } |
| @@ -1605,7 +1640,8 @@ class ThreadLeakInSortMergeJoinSuite |
| sparkConf.set(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD, 20)) |
| } |
| |
| - test("SPARK-47146: thread leak when doing SortMergeJoin (with spill)") { |
| + test("SPARK-47146: thread leak when doing SortMergeJoin (with spill)", |
| + IgnoreComet("Comet SMJ doesn't spill yet")) { |
| |
| withSQLConf( |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") { |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala |
| index b5b34922694..a72403780c4 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala |
| @@ -69,7 +69,7 @@ import org.apache.spark.tags.ExtendedSQLTest |
| * }}} |
| */ |
| // scalastyle:on line.size.limit |
| -trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite { |
| +trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite with IgnoreCometSuite { |
| |
| protected val baseResourcePath = { |
| // use the same way as `SQLQueryTestSuite` to get the resource path |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala |
| index 525d97e4998..843f0472c23 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala |
| @@ -1508,7 +1508,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark |
| checkAnswer(sql("select -0.001"), Row(BigDecimal("-0.001"))) |
| } |
| |
| - test("external sorting updates peak execution memory") { |
| + test("external sorting updates peak execution memory", |
| + IgnoreComet("TODO: native CometSort does not update peak execution memory")) { |
| AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { |
| sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect() |
| } |
| @@ -4429,7 +4430,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark |
| } |
| |
| test("SPARK-39166: Query context of binary arithmetic should be serialized to executors" + |
| - " when WSCG is off") { |
| + " when WSCG is off", |
| + IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) { |
| withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", |
| SQLConf.ANSI_ENABLED.key -> "true") { |
| withTable("t") { |
| @@ -4450,7 +4452,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark |
| } |
| |
| test("SPARK-39175: Query context of Cast should be serialized to executors" + |
| - " when WSCG is off") { |
| + " when WSCG is off", |
| + IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) { |
| withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", |
| SQLConf.ANSI_ENABLED.key -> "true") { |
| withTable("t") { |
| @@ -4467,14 +4470,19 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark |
| val msg = intercept[SparkException] { |
| sql(query).collect() |
| }.getMessage |
| - assert(msg.contains(query)) |
| + if (!isCometEnabled) { |
| + // Comet's error message does not include the original SQL query |
| + // https://github.com/apache/datafusion-comet/issues/2215 |
| + assert(msg.contains(query)) |
| + } |
| } |
| } |
| } |
| } |
| |
| test("SPARK-39190,SPARK-39208,SPARK-39210: Query context of decimal overflow error should " + |
| - "be serialized to executors when WSCG is off") { |
| + "be serialized to executors when WSCG is off", |
| + IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) { |
| withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", |
| SQLConf.ANSI_ENABLED.key -> "true") { |
| withTable("t") { |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala |
| index 48ad10992c5..51d1ee65422 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala |
| @@ -221,6 +221,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper { |
| withSession(extensions) { session => |
| session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED, true) |
| session.conf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1") |
| + // https://github.com/apache/datafusion-comet/issues/1197 |
| + session.conf.set("spark.comet.enabled", false) |
| assert(session.sessionState.columnarRules.contains( |
| MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule()))) |
| import session.sqlContext.implicits._ |
| @@ -279,6 +281,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper { |
| } |
| withSession(extensions) { session => |
| session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED, enableAQE) |
| + // https://github.com/apache/datafusion-comet/issues/1197 |
| + session.conf.set("spark.comet.enabled", false) |
| assert(session.sessionState.columnarRules.contains( |
| MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule()))) |
| import session.sqlContext.implicits._ |
| @@ -317,6 +321,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper { |
| val session = SparkSession.builder() |
| .master("local[1]") |
| .config(COLUMN_BATCH_SIZE.key, 2) |
| + // https://github.com/apache/datafusion-comet/issues/1197 |
| + .config("spark.comet.enabled", false) |
| .withExtensions { extensions => |
| extensions.injectColumnar(session => |
| MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())) } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala |
| index 18123a4d6ec..fbe4c766eee 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala |
| @@ -17,6 +17,8 @@ |
| |
| package org.apache.spark.sql |
| |
| +import org.apache.comet.CometConf |
| + |
| import org.apache.spark.{SPARK_DOC_ROOT, SparkRuntimeException} |
| import org.apache.spark.sql.catalyst.expressions.Cast._ |
| import org.apache.spark.sql.catalyst.expressions.TryToNumber |
| @@ -133,29 +135,31 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { |
| } |
| |
| test("string regex_replace / regex_extract") { |
| - val df = Seq( |
| - ("100-200", "(\\d+)-(\\d+)", "300"), |
| - ("100-200", "(\\d+)-(\\d+)", "400"), |
| - ("100-200", "(\\d+)", "400")).toDF("a", "b", "c") |
| - |
| - checkAnswer( |
| - df.select( |
| - regexp_replace($"a", "(\\d+)", "num"), |
| - regexp_replace($"a", $"b", $"c"), |
| - regexp_extract($"a", "(\\d+)-(\\d+)", 1)), |
| - Row("num-num", "300", "100") :: Row("num-num", "400", "100") :: |
| - Row("num-num", "400-400", "100") :: Nil) |
| + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { |
| + val df = Seq( |
| + ("100-200", "(\\d+)-(\\d+)", "300"), |
| + ("100-200", "(\\d+)-(\\d+)", "400"), |
| + ("100-200", "(\\d+)", "400")).toDF("a", "b", "c") |
| |
| - // for testing the mutable state of the expression in code gen. |
| - // This is a hack way to enable the codegen, thus the codegen is enable by default, |
| - // it will still use the interpretProjection if projection followed by a LocalRelation, |
| - // hence we add a filter operator. |
| - // See the optimizer rule `ConvertToLocalRelation` |
| - checkAnswer( |
| - df.filter("isnotnull(a)").selectExpr( |
| - "regexp_replace(a, b, c)", |
| - "regexp_extract(a, b, 1)"), |
| - Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil) |
| + checkAnswer( |
| + df.select( |
| + regexp_replace($"a", "(\\d+)", "num"), |
| + regexp_replace($"a", $"b", $"c"), |
| + regexp_extract($"a", "(\\d+)-(\\d+)", 1)), |
| + Row("num-num", "300", "100") :: Row("num-num", "400", "100") :: |
| + Row("num-num", "400-400", "100") :: Nil) |
| + |
| + // for testing the mutable state of the expression in code gen. |
| + // This is a hacky way to enable codegen; thus, codegen is enabled by default, |
| + // but interpreted projection is still used when a projection follows a LocalRelation, |
| + // hence we add a filter operator. |
| + // See the optimizer rule `ConvertToLocalRelation` |
| + checkAnswer( |
| + df.filter("isnotnull(a)").selectExpr( |
| + "regexp_replace(a, b, c)", |
| + "regexp_extract(a, b, 1)"), |
| + Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil) |
| + } |
| } |
| |
| test("non-matching optional group") { |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala |
| index 75eabcb96f2..7c0bbd71551 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala |
| @@ -21,10 +21,11 @@ import scala.collection.mutable.ArrayBuffer |
| |
| import org.apache.spark.sql.catalyst.expressions.SubqueryExpression |
| import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Join, LogicalPlan, Project, Sort, Union} |
| +import org.apache.spark.sql.comet.CometScanExec |
| import org.apache.spark.sql.execution._ |
| import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution} |
| import org.apache.spark.sql.execution.datasources.FileScanRDD |
| -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec |
| +import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike |
| import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec} |
| import org.apache.spark.sql.internal.SQLConf |
| import org.apache.spark.sql.test.SharedSparkSession |
| @@ -1543,6 +1544,12 @@ class SubquerySuite extends QueryTest |
| fs.inputRDDs().forall( |
| _.asInstanceOf[FileScanRDD].filePartitions.forall( |
| _.files.forall(_.urlEncodedPath.contains("p=0")))) |
| + case WholeStageCodegenExec(ColumnarToRowExec(InputAdapter( |
| + fs @ CometScanExec(_, _, _, _, partitionFilters, _, _, _, _, _, _)))) => |
| + partitionFilters.exists(ExecSubqueryExpression.hasSubquery) && |
| + fs.inputRDDs().forall( |
| + _.asInstanceOf[FileScanRDD].filePartitions.forall( |
| + _.files.forall(_.urlEncodedPath.contains("p=0")))) |
| case _ => false |
| }) |
| } |
| @@ -2108,7 +2115,7 @@ class SubquerySuite extends QueryTest |
| |
| df.collect() |
| val exchanges = collect(df.queryExecution.executedPlan) { |
| - case s: ShuffleExchangeExec => s |
| + case s: ShuffleExchangeLike => s |
| } |
| assert(exchanges.size === 1) |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala |
| index 02990a7a40d..bddf5e1ccc2 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala |
| @@ -24,6 +24,7 @@ import test.org.apache.spark.sql.connector._ |
| |
| import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} |
| import org.apache.spark.sql.catalyst.InternalRow |
| +import org.apache.spark.sql.comet.CometSortExec |
| import org.apache.spark.sql.connector.catalog.{PartitionInternalRow, SupportsRead, Table, TableCapability, TableProvider} |
| import org.apache.spark.sql.connector.catalog.TableCapability._ |
| import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, NamedReference, NullOrdering, SortDirection, SortOrder, Transform} |
| @@ -33,7 +34,7 @@ import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, |
| import org.apache.spark.sql.execution.SortExec |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper |
| import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation} |
| -import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} |
| +import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec, ShuffleExchangeLike} |
| import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector |
| import org.apache.spark.sql.expressions.Window |
| import org.apache.spark.sql.functions._ |
| @@ -268,13 +269,13 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS |
| val groupByColJ = df.groupBy($"j").agg(sum($"i")) |
| checkAnswer(groupByColJ, Seq(Row(2, 8), Row(4, 2), Row(6, 5))) |
| assert(collectFirst(groupByColJ.queryExecution.executedPlan) { |
| - case e: ShuffleExchangeExec => e |
| + case e: ShuffleExchangeLike => e |
| }.isDefined) |
| |
| val groupByIPlusJ = df.groupBy($"i" + $"j").agg(count("*")) |
| checkAnswer(groupByIPlusJ, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1))) |
| assert(collectFirst(groupByIPlusJ.queryExecution.executedPlan) { |
| - case e: ShuffleExchangeExec => e |
| + case e: ShuffleExchangeLike => e |
| }.isDefined) |
| } |
| } |
| @@ -334,10 +335,11 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS |
| |
| val (shuffleExpected, sortExpected) = groupByExpects |
| assert(collectFirst(groupBy.queryExecution.executedPlan) { |
| - case e: ShuffleExchangeExec => e |
| + case e: ShuffleExchangeLike => e |
| }.isDefined === shuffleExpected) |
| assert(collectFirst(groupBy.queryExecution.executedPlan) { |
| case e: SortExec => e |
| + case c: CometSortExec => c |
| }.isDefined === sortExpected) |
| } |
| |
| @@ -352,10 +354,11 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS |
| |
| val (shuffleExpected, sortExpected) = windowFuncExpects |
| assert(collectFirst(windowPartByColIOrderByColJ.queryExecution.executedPlan) { |
| - case e: ShuffleExchangeExec => e |
| + case e: ShuffleExchangeLike => e |
| }.isDefined === shuffleExpected) |
| assert(collectFirst(windowPartByColIOrderByColJ.queryExecution.executedPlan) { |
| case e: SortExec => e |
| + case c: CometSortExec => c |
| }.isDefined === sortExpected) |
| } |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala |
| index cfc8b2cc845..c4be7eb3731 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala |
| @@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer |
| import org.apache.spark.SparkConf |
| import org.apache.spark.sql.{AnalysisException, QueryTest} |
| import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan |
| +import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec} |
| import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability} |
| import org.apache.spark.sql.connector.read.ScanBuilder |
| import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} |
| @@ -184,7 +185,11 @@ class FileDataSourceV2FallBackSuite extends QueryTest with SharedSparkSession { |
| val df = spark.read.format(format).load(path.getCanonicalPath) |
| checkAnswer(df, inputData.toDF()) |
| assert( |
| - df.queryExecution.executedPlan.exists(_.isInstanceOf[FileSourceScanExec])) |
| + df.queryExecution.executedPlan.exists { |
| + case _: FileSourceScanExec | _: CometScanExec | _: CometNativeScanExec => true |
| + case _ => false |
| + } |
| + ) |
| } |
| } finally { |
| spark.listenerManager.unregister(listener) |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala |
| index cf76f6ca32c..f454128af06 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala |
| @@ -22,6 +22,7 @@ import org.apache.spark.sql.{DataFrame, Row} |
| import org.apache.spark.sql.catalyst.InternalRow |
| import org.apache.spark.sql.catalyst.expressions.{Literal, TransformExpression} |
| import org.apache.spark.sql.catalyst.plans.physical |
| +import org.apache.spark.sql.comet.CometSortMergeJoinExec |
| import org.apache.spark.sql.connector.catalog.Identifier |
| import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog |
| import org.apache.spark.sql.connector.catalog.functions._ |
| @@ -31,7 +32,7 @@ import org.apache.spark.sql.connector.expressions.Expressions._ |
| import org.apache.spark.sql.execution.SparkPlan |
| import org.apache.spark.sql.execution.datasources.v2.BatchScanExec |
| import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation |
| -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec |
| +import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike |
| import org.apache.spark.sql.execution.joins.SortMergeJoinExec |
| import org.apache.spark.sql.internal.SQLConf |
| import org.apache.spark.sql.internal.SQLConf._ |
| @@ -279,13 +280,14 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { |
| Row("bbb", 20, 250.0), Row("bbb", 20, 350.0), Row("ccc", 30, 400.50))) |
| } |
| |
| - private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeExec] = { |
| + private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeLike] = { |
| // here we skip collecting shuffle operators that are not associated with SMJ |
| collect(plan) { |
| case s: SortMergeJoinExec => s |
| + case c: CometSortMergeJoinExec => c.originalPlan |
| }.flatMap(smj => |
| collect(smj) { |
| - case s: ShuffleExchangeExec => s |
| + case s: ShuffleExchangeLike => s |
| }) |
| } |
| |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala |
| index c0ec8a58bd5..4e8bc6ed3c5 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala |
| @@ -27,7 +27,7 @@ import org.apache.hadoop.fs.permission.FsPermission |
| import org.mockito.Mockito.{mock, spy, when} |
| |
| import org.apache.spark._ |
| -import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row, SaveMode} |
| +import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, IgnoreComet, QueryTest, Row, SaveMode} |
| import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._ |
| import org.apache.spark.sql.catalyst.util.BadRecordException |
| import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions} |
| @@ -248,7 +248,8 @@ class QueryExecutionErrorsSuite |
| } |
| |
| test("INCONSISTENT_BEHAVIOR_CROSS_VERSION: " + |
| - "compatibility with Spark 2.4/3.2 in reading/writing dates") { |
| + "compatibility with Spark 2.4/3.2 in reading/writing dates", |
| + IgnoreComet("Comet doesn't completely support datetime rebase mode yet")) { |
| |
| // Fail to read ancient datetime values. |
| withSQLConf(SQLConf.PARQUET_REBASE_MODE_IN_READ.key -> EXCEPTION.toString) { |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala |
| index 418ca3430bb..eb8267192f8 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala |
| @@ -23,7 +23,7 @@ import scala.util.Random |
| import org.apache.hadoop.fs.Path |
| |
| import org.apache.spark.SparkConf |
| -import org.apache.spark.sql.{DataFrame, QueryTest} |
| +import org.apache.spark.sql.{DataFrame, IgnoreComet, QueryTest} |
| import org.apache.spark.sql.execution.datasources.v2.BatchScanExec |
| import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan |
| import org.apache.spark.sql.internal.SQLConf |
| @@ -195,7 +195,7 @@ class DataSourceV2ScanExecRedactionSuite extends DataSourceScanRedactionTest { |
| } |
| } |
| |
| - test("FileScan description") { |
| + test("FileScan description", IgnoreComet("Comet doesn't use BatchScan")) { |
| Seq("json", "orc", "parquet").foreach { format => |
| withTempPath { path => |
| val dir = path.getCanonicalPath |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala |
| index 743ec41dbe7..9f30d6c8e04 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala |
| @@ -53,6 +53,10 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite with DisableAdaptiv |
| case ColumnarToRowExec(i: InputAdapter) => isScanPlanTree(i.child) |
| case p: ProjectExec => isScanPlanTree(p.child) |
| case f: FilterExec => isScanPlanTree(f.child) |
| + // Comet produces scan plan tree like: |
| + // ColumnarToRow |
| + // +- ReusedExchange |
| + case _: ReusedExchangeExec => false |
| case _: LeafExecNode => true |
| case _ => false |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala |
| index 4b3d3a4b805..56e1e0e6f16 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala |
| @@ -18,7 +18,7 @@ |
| package org.apache.spark.sql.execution |
| |
| import org.apache.spark.rdd.RDD |
| -import org.apache.spark.sql.{execution, DataFrame, Row} |
| +import org.apache.spark.sql.{execution, DataFrame, IgnoreCometSuite, Row} |
| import org.apache.spark.sql.catalyst.InternalRow |
| import org.apache.spark.sql.catalyst.expressions._ |
| import org.apache.spark.sql.catalyst.plans._ |
| @@ -35,7 +35,9 @@ import org.apache.spark.sql.internal.SQLConf |
| import org.apache.spark.sql.test.SharedSparkSession |
| import org.apache.spark.sql.types._ |
| |
| -class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { |
| +// Ignore this suite when Comet is enabled. This suite tests the Spark planner, and the Comet |
| +// planner produces too many differences. Simply ignoring this suite for now. |
| +class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper with IgnoreCometSuite { |
| import testImplicits._ |
| |
| setupTestData() |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala |
| index 9e9d717db3b..ec73082f458 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala |
| @@ -17,7 +17,10 @@ |
| |
| package org.apache.spark.sql.execution |
| |
| -import org.apache.spark.sql.{DataFrame, QueryTest, Row} |
| +import org.apache.comet.CometConf |
| + |
| +import org.apache.spark.sql.{DataFrame, IgnoreComet, QueryTest, Row} |
| +import org.apache.spark.sql.comet.CometProjectExec |
| import org.apache.spark.sql.connector.SimpleWritableDataSource |
| import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} |
| import org.apache.spark.sql.internal.SQLConf |
| @@ -34,7 +37,10 @@ abstract class RemoveRedundantProjectsSuiteBase |
| private def assertProjectExecCount(df: DataFrame, expected: Int): Unit = { |
| withClue(df.queryExecution) { |
| val plan = df.queryExecution.executedPlan |
| - val actual = collectWithSubqueries(plan) { case p: ProjectExec => p }.size |
| + val actual = collectWithSubqueries(plan) { |
| + case p: ProjectExec => p |
| + case p: CometProjectExec => p |
| + }.size |
| assert(actual == expected) |
| } |
| } |
| @@ -112,7 +118,8 @@ abstract class RemoveRedundantProjectsSuiteBase |
| assertProjectExec(query, 1, 3) |
| } |
| |
| - test("join with ordering requirement") { |
| + test("join with ordering requirement", |
| + IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) { |
| val query = "select * from (select key, a, c, b from testView) as t1 join " + |
| "(select key, a, b, c from testView) as t2 on t1.key = t2.key where t2.a > 50" |
| assertProjectExec(query, 2, 2) |
| @@ -134,12 +141,21 @@ abstract class RemoveRedundantProjectsSuiteBase |
| val df = data.selectExpr("a", "b", "key", "explode(array(key, a, b)) as d").filter("d > 0") |
| df.collect() |
| val plan = df.queryExecution.executedPlan |
| - val numProjects = collectWithSubqueries(plan) { case p: ProjectExec => p }.length |
| + val numProjects = collectWithSubqueries(plan) { |
| + case p: ProjectExec => p |
| + case p: CometProjectExec => p |
| + }.length |
| + |
| + // Comet-specific change to get original Spark plan before applying |
| + // a transformation to add a new ProjectExec |
| + var sparkPlan: SparkPlan = null |
| + withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { |
| + val df = data.selectExpr("a", "b", "key", "explode(array(key, a, b)) as d").filter("d > 0") |
| + df.collect() |
| + sparkPlan = df.queryExecution.executedPlan |
| + } |
| |
| - // Create a new plan that reverse the GenerateExec output and add a new ProjectExec between |
| - // GenerateExec and its child. This is to test if the ProjectExec is removed, the output of |
| - // the query will be incorrect. |
| - val newPlan = stripAQEPlan(plan) transform { |
| + val newPlan = stripAQEPlan(sparkPlan) transform { |
| case g @ GenerateExec(_, requiredChildOutput, _, _, child) => |
| g.copy(requiredChildOutput = requiredChildOutput.reverse, |
| child = ProjectExec(requiredChildOutput.reverse, child)) |
| @@ -151,6 +167,7 @@ abstract class RemoveRedundantProjectsSuiteBase |
| // The manually added ProjectExec node shouldn't be removed. |
| assert(collectWithSubqueries(newExecutedPlan) { |
| case p: ProjectExec => p |
| + case p: CometProjectExec => p |
| }.size == numProjects + 1) |
| |
| // Check the original plan's output and the new plan's output are the same. |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala |
| index 30ce940b032..0d3f6c6c934 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala |
| @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution |
| |
| import org.apache.spark.sql.{DataFrame, QueryTest} |
| import org.apache.spark.sql.catalyst.plans.physical.{RangePartitioning, UnknownPartitioning} |
| +import org.apache.spark.sql.comet.CometSortExec |
| import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} |
| import org.apache.spark.sql.execution.joins.ShuffledJoin |
| import org.apache.spark.sql.internal.SQLConf |
| @@ -33,7 +34,7 @@ abstract class RemoveRedundantSortsSuiteBase |
| |
| private def checkNumSorts(df: DataFrame, count: Int): Unit = { |
| val plan = df.queryExecution.executedPlan |
| - assert(collectWithSubqueries(plan) { case s: SortExec => s }.length == count) |
| + assert(collectWithSubqueries(plan) { case _: SortExec | _: CometSortExec => 1 }.length == count) |
| } |
| |
| private def checkSorts(query: String, enabledCount: Int, disabledCount: Int): Unit = { |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala |
| index 47679ed7865..9ffbaecb98e 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala |
| @@ -18,6 +18,7 @@ |
| package org.apache.spark.sql.execution |
| |
| import org.apache.spark.sql.{DataFrame, QueryTest} |
| +import org.apache.spark.sql.comet.CometHashAggregateExec |
| import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} |
| import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} |
| import org.apache.spark.sql.internal.SQLConf |
| @@ -31,7 +32,7 @@ abstract class ReplaceHashWithSortAggSuiteBase |
| private def checkNumAggs(df: DataFrame, hashAggCount: Int, sortAggCount: Int): Unit = { |
| val plan = df.queryExecution.executedPlan |
| assert(collectWithSubqueries(plan) { |
| - case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec) => s |
| +      case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec | _: CometHashAggregateExec) => s |
| }.length == hashAggCount) |
| assert(collectWithSubqueries(plan) { case s: SortAggregateExec => s }.length == sortAggCount) |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala |
| index eec396b2e39..bf3f1c769d6 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala |
| @@ -18,7 +18,7 @@ |
| package org.apache.spark.sql.execution |
| |
| import org.apache.spark.TestUtils.assertSpilled |
| -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} |
| +import org.apache.spark.sql.{AnalysisException, IgnoreComet, QueryTest, Row} |
| import org.apache.spark.sql.internal.SQLConf.{WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD, WINDOW_EXEC_BUFFER_SPILL_THRESHOLD} |
| import org.apache.spark.sql.test.SharedSparkSession |
| |
| @@ -470,7 +470,7 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSparkSession { |
| Row(1, 3, null) :: Row(2, null, 4) :: Nil) |
| } |
| |
| - test("test with low buffer spill threshold") { |
| + test("test with low buffer spill threshold", IgnoreComet("Comet does not support spilling")) { |
| val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") |
| nums.createOrReplaceTempView("nums") |
| |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala |
| index b14f4a405f6..90bed10eca9 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala |
| @@ -23,6 +23,7 @@ import org.apache.spark.sql.QueryTest |
| import org.apache.spark.sql.catalyst.InternalRow |
| import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} |
| import org.apache.spark.sql.catalyst.plans.logical.Deduplicate |
| +import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometNativeColumnarToRowExec} |
| import org.apache.spark.sql.execution.datasources.v2.BatchScanExec |
| import org.apache.spark.sql.internal.SQLConf |
| import org.apache.spark.sql.test.SharedSparkSession |
| @@ -131,7 +132,11 @@ class SparkPlanSuite extends QueryTest with SharedSparkSession { |
| spark.range(1).write.parquet(path.getAbsolutePath) |
| val df = spark.read.parquet(path.getAbsolutePath) |
| val columnarToRowExec = |
| - df.queryExecution.executedPlan.collectFirst { case p: ColumnarToRowExec => p }.get |
| + df.queryExecution.executedPlan.collectFirst { |
| + case p: ColumnarToRowExec => p |
| + case p: CometColumnarToRowExec => p |
| + case p: CometNativeColumnarToRowExec => p |
| + }.get |
| try { |
| spark.range(1).foreach { _ => |
| columnarToRowExec.canonicalized |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala |
| index ac710c32296..2854b433dd3 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala |
| @@ -17,7 +17,7 @@ |
| |
| package org.apache.spark.sql.execution |
| |
| -import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode} |
| +import org.apache.spark.sql.{Dataset, IgnoreCometSuite, QueryTest, Row, SaveMode} |
| import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAndComment, CodeGenerator} |
| import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite |
| import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} |
| @@ -29,7 +29,7 @@ import org.apache.spark.sql.test.SharedSparkSession |
| import org.apache.spark.sql.types.{IntegerType, StringType, StructType} |
| |
| // Disable AQE because the WholeStageCodegenExec is added when running QueryStageExec |
| -class WholeStageCodegenSuite extends QueryTest with SharedSparkSession |
| +class WholeStageCodegenSuite extends QueryTest with SharedSparkSession with IgnoreCometSuite |
| with DisableAdaptiveExecutionSuite { |
| |
| import testImplicits._ |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala |
| index 593bd7bb4ba..32af28b0238 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala |
| @@ -26,9 +26,11 @@ import org.scalatest.time.SpanSugar._ |
| |
| import org.apache.spark.SparkException |
| import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} |
| -import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy} |
| +import org.apache.spark.sql.{Dataset, IgnoreComet, QueryTest, Row, SparkSession, Strategy} |
| import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} |
| import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} |
| +import org.apache.spark.sql.comet._ |
| +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec |
| import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnionExec} |
| import org.apache.spark.sql.execution.aggregate.BaseAggregateExec |
| import org.apache.spark.sql.execution.command.DataWritingCommandExec |
| @@ -104,6 +106,7 @@ class AdaptiveQueryExecSuite |
| private def findTopLevelBroadcastHashJoin(plan: SparkPlan): Seq[BroadcastHashJoinExec] = { |
| collect(plan) { |
| case j: BroadcastHashJoinExec => j |
| + case j: CometBroadcastHashJoinExec => j.originalPlan.asInstanceOf[BroadcastHashJoinExec] |
| } |
| } |
| |
| @@ -116,30 +119,39 @@ class AdaptiveQueryExecSuite |
| private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = { |
| collect(plan) { |
| case j: SortMergeJoinExec => j |
| + case j: CometSortMergeJoinExec => |
| + assert(j.originalPlan.isInstanceOf[SortMergeJoinExec]) |
| + j.originalPlan.asInstanceOf[SortMergeJoinExec] |
| } |
| } |
| |
| private def findTopLevelShuffledHashJoin(plan: SparkPlan): Seq[ShuffledHashJoinExec] = { |
| collect(plan) { |
| case j: ShuffledHashJoinExec => j |
| + case j: CometHashJoinExec => j.originalPlan.asInstanceOf[ShuffledHashJoinExec] |
| } |
| } |
| |
| private def findTopLevelBaseJoin(plan: SparkPlan): Seq[BaseJoinExec] = { |
| collect(plan) { |
| case j: BaseJoinExec => j |
| + case c: CometHashJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec] |
| + case c: CometSortMergeJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec] |
| + case c: CometBroadcastHashJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec] |
| } |
| } |
| |
| private def findTopLevelSort(plan: SparkPlan): Seq[SortExec] = { |
| collect(plan) { |
| case s: SortExec => s |
| + case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec] |
| } |
| } |
| |
| private def findTopLevelAggregate(plan: SparkPlan): Seq[BaseAggregateExec] = { |
| collect(plan) { |
| case agg: BaseAggregateExec => agg |
| + case agg: CometHashAggregateExec => agg.originalPlan.asInstanceOf[BaseAggregateExec] |
| } |
| } |
| |
| @@ -176,6 +188,7 @@ class AdaptiveQueryExecSuite |
| val parts = rdd.partitions |
| assert(parts.forall(rdd.preferredLocations(_).nonEmpty)) |
| } |
| + |
| assert(numShuffles === (numLocalReads.length + numShufflesWithoutLocalRead)) |
| } |
| |
| @@ -184,7 +197,7 @@ class AdaptiveQueryExecSuite |
| val plan = df.queryExecution.executedPlan |
| assert(plan.isInstanceOf[AdaptiveSparkPlanExec]) |
| val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect { |
| - case s: ShuffleExchangeExec => s |
| + case s: ShuffleExchangeLike => s |
| } |
| assert(shuffle.size == 1) |
| assert(shuffle(0).outputPartitioning.numPartitions == numPartition) |
| @@ -200,7 +213,8 @@ class AdaptiveQueryExecSuite |
| assert(smj.size == 1) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 1) |
| - checkNumLocalShuffleReads(adaptivePlan) |
| + // Comet shuffle changes shuffle metrics |
| + // checkNumLocalShuffleReads(adaptivePlan) |
| } |
| } |
| |
| @@ -227,7 +241,8 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("Reuse the parallelism of coalesced shuffle in local shuffle read") { |
| + test("Reuse the parallelism of coalesced shuffle in local shuffle read", |
| + IgnoreComet("Comet shuffle changes shuffle partition size")) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", |
| @@ -259,7 +274,8 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("Reuse the default parallelism in local shuffle read") { |
| + test("Reuse the default parallelism in local shuffle read", |
| + IgnoreComet("Comet shuffle changes shuffle partition size")) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", |
| @@ -273,7 +289,8 @@ class AdaptiveQueryExecSuite |
| val localReads = collect(adaptivePlan) { |
| case read: AQEShuffleReadExec if read.isLocalRead => read |
| } |
| - assert(localReads.length == 2) |
| +    // This test is tagged IgnoreComet above, so its body only runs under plain Spark; keep the |
| +    assert(localReads.length == 2) |
| val localShuffleRDD0 = localReads(0).execute().asInstanceOf[ShuffledRowRDD] |
| val localShuffleRDD1 = localReads(1).execute().asInstanceOf[ShuffledRowRDD] |
| // the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2 |
| @@ -298,7 +315,9 @@ class AdaptiveQueryExecSuite |
| .groupBy($"a").count() |
| checkAnswer(testDf, Seq()) |
| val plan = testDf.queryExecution.executedPlan |
| - assert(find(plan)(_.isInstanceOf[SortMergeJoinExec]).isDefined) |
| + assert(find(plan) { case p => |
| + p.isInstanceOf[SortMergeJoinExec] || p.isInstanceOf[CometSortMergeJoinExec] |
| + }.isDefined) |
| val coalescedReads = collect(plan) { |
| case r: AQEShuffleReadExec => r |
| } |
| @@ -312,7 +331,9 @@ class AdaptiveQueryExecSuite |
| .groupBy($"a").count() |
| checkAnswer(testDf, Seq()) |
| val plan = testDf.queryExecution.executedPlan |
| - assert(find(plan)(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) |
| + assert(find(plan) { case p => |
| + p.isInstanceOf[BroadcastHashJoinExec] || p.isInstanceOf[CometBroadcastHashJoinExec] |
| + }.isDefined) |
| val coalescedReads = collect(plan) { |
| case r: AQEShuffleReadExec => r |
| } |
| @@ -322,7 +343,7 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("Scalar subquery") { |
| + test("Scalar subquery", IgnoreComet("Comet shuffle changes shuffle metrics")) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| @@ -337,7 +358,7 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("Scalar subquery in later stages") { |
| + test("Scalar subquery in later stages", IgnoreComet("Comet shuffle changes shuffle metrics")) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| @@ -353,7 +374,7 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("multiple joins") { |
| + test("multiple joins", IgnoreComet("Comet shuffle changes shuffle metrics")) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| @@ -398,7 +419,7 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("multiple joins with aggregate") { |
| + test("multiple joins with aggregate", IgnoreComet("Comet shuffle changes shuffle metrics")) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| @@ -443,7 +464,7 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("multiple joins with aggregate 2") { |
| + test("multiple joins with aggregate 2", IgnoreComet("Comet shuffle changes shuffle metrics")) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") { |
| @@ -489,7 +510,7 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("Exchange reuse") { |
| + test("Exchange reuse", IgnoreComet("Comet shuffle changes shuffle metrics")) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| @@ -508,7 +529,7 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("Exchange reuse with subqueries") { |
| + test("Exchange reuse with subqueries", IgnoreComet("Comet shuffle changes shuffle metrics")) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| @@ -539,7 +560,9 @@ class AdaptiveQueryExecSuite |
| assert(smj.size == 1) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 1) |
| - checkNumLocalShuffleReads(adaptivePlan) |
| + // Comet shuffle changes shuffle metrics, |
| + // so we can't check the number of local shuffle reads. |
| + // checkNumLocalShuffleReads(adaptivePlan) |
| // Even with local shuffle read, the query stage reuse can also work. |
| val ex = findReusedExchange(adaptivePlan) |
| assert(ex.nonEmpty) |
| @@ -560,7 +583,9 @@ class AdaptiveQueryExecSuite |
| assert(smj.size == 1) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 1) |
| - checkNumLocalShuffleReads(adaptivePlan) |
| + // Comet shuffle changes shuffle metrics, |
| + // so we can't check the number of local shuffle reads. |
| + // checkNumLocalShuffleReads(adaptivePlan) |
| // Even with local shuffle read, the query stage reuse can also work. |
| val ex = findReusedExchange(adaptivePlan) |
| assert(ex.isEmpty) |
| @@ -569,7 +594,8 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("Broadcast exchange reuse across subqueries") { |
| + test("Broadcast exchange reuse across subqueries", |
| + IgnoreComet("Comet shuffle changes shuffle metrics")) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000", |
| @@ -664,7 +690,8 @@ class AdaptiveQueryExecSuite |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 1) |
| // There is still a SMJ, and its two shuffles can't apply local read. |
| - checkNumLocalShuffleReads(adaptivePlan, 2) |
| + // Comet shuffle changes shuffle metrics |
| + // checkNumLocalShuffleReads(adaptivePlan, 2) |
| } |
| } |
| |
| @@ -786,7 +813,8 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("SPARK-29544: adaptive skew join with different join types") { |
| + test("SPARK-29544: adaptive skew join with different join types", |
| + IgnoreComet("Comet shuffle has different partition metrics")) { |
| Seq("SHUFFLE_MERGE", "SHUFFLE_HASH").foreach { joinHint => |
| def getJoinNode(plan: SparkPlan): Seq[ShuffledJoin] = if (joinHint == "SHUFFLE_MERGE") { |
| findTopLevelSortMergeJoin(plan) |
| @@ -1004,7 +1032,8 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("metrics of the shuffle read") { |
| + test("metrics of the shuffle read", |
| + IgnoreComet("Comet shuffle changes the metrics")) { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| val (_, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT key FROM testData GROUP BY key") |
| @@ -1599,7 +1628,7 @@ class AdaptiveQueryExecSuite |
| val (_, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT id FROM v1 GROUP BY id DISTRIBUTE BY id") |
| assert(collect(adaptivePlan) { |
| - case s: ShuffleExchangeExec => s |
| + case s: ShuffleExchangeLike => s |
| }.length == 1) |
| } |
| } |
| @@ -1679,7 +1708,8 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("SPARK-33551: Do not use AQE shuffle read for repartition") { |
| + test("SPARK-33551: Do not use AQE shuffle read for repartition", |
| + IgnoreComet("Comet shuffle changes partition size")) { |
| def hasRepartitionShuffle(plan: SparkPlan): Boolean = { |
| find(plan) { |
| case s: ShuffleExchangeLike => |
| @@ -1864,6 +1894,9 @@ class AdaptiveQueryExecSuite |
| def checkNoCoalescePartitions(ds: Dataset[Row], origin: ShuffleOrigin): Unit = { |
| assert(collect(ds.queryExecution.executedPlan) { |
| case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s |
| + case c: CometShuffleExchangeExec |
| + if c.originalPlan.shuffleOrigin == origin && |
| + c.originalPlan.numPartitions == 2 => c |
| }.size == 1) |
| ds.collect() |
| val plan = ds.queryExecution.executedPlan |
| @@ -1872,6 +1905,9 @@ class AdaptiveQueryExecSuite |
| }.isEmpty) |
| assert(collect(plan) { |
| case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s |
| + case c: CometShuffleExchangeExec |
| + if c.originalPlan.shuffleOrigin == origin && |
| + c.originalPlan.numPartitions == 2 => c |
| }.size == 1) |
| checkAnswer(ds, testData) |
| } |
| @@ -2028,7 +2064,8 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("SPARK-35264: Support AQE side shuffled hash join formula") { |
| + test("SPARK-35264: Support AQE side shuffled hash join formula", |
| + IgnoreComet("Comet shuffle changes the partition size")) { |
| withTempView("t1", "t2") { |
| def checkJoinStrategy(shouldShuffleHashJoin: Boolean): Unit = { |
| Seq("100", "100000").foreach { size => |
| @@ -2114,7 +2151,8 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("SPARK-35725: Support optimize skewed partitions in RebalancePartitions") { |
| + test("SPARK-35725: Support optimize skewed partitions in RebalancePartitions", |
| + IgnoreComet("Comet shuffle changes shuffle metrics")) { |
| withTempView("v") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| @@ -2213,7 +2251,7 @@ class AdaptiveQueryExecSuite |
| runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " + |
| s"JOIN skewData2 ON key1 = key2 GROUP BY key1") |
| val shuffles1 = collect(adaptive1) { |
| - case s: ShuffleExchangeExec => s |
| + case s: ShuffleExchangeLike => s |
| } |
| assert(shuffles1.size == 3) |
| // shuffles1.head is the top-level shuffle under the Aggregate operator |
| @@ -2226,7 +2264,7 @@ class AdaptiveQueryExecSuite |
| runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " + |
| s"JOIN skewData2 ON key1 = key2") |
| val shuffles2 = collect(adaptive2) { |
| - case s: ShuffleExchangeExec => s |
| + case s: ShuffleExchangeLike => s |
| } |
| if (hasRequiredDistribution) { |
| assert(shuffles2.size == 3) |
| @@ -2260,7 +2298,8 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("SPARK-35794: Allow custom plugin for cost evaluator") { |
| + test("SPARK-35794: Allow custom plugin for cost evaluator", |
| + IgnoreComet("Comet shuffle changes shuffle metrics")) { |
| CostEvaluator.instantiate( |
| classOf[SimpleShuffleSortCostEvaluator].getCanonicalName, spark.sparkContext.getConf) |
| intercept[IllegalArgumentException] { |
| @@ -2404,6 +2443,7 @@ class AdaptiveQueryExecSuite |
| val (_, adaptive) = runAdaptiveAndVerifyResult(query) |
| assert(adaptive.collect { |
| case sort: SortExec => sort |
| + case sort: CometSortExec => sort |
| }.size == 1) |
| val read = collect(adaptive) { |
| case read: AQEShuffleReadExec => read |
| @@ -2421,7 +2461,8 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("SPARK-37357: Add small partition factor for rebalance partitions") { |
| + test("SPARK-37357: Add small partition factor for rebalance partitions", |
| + IgnoreComet("Comet shuffle changes shuffle metrics")) { |
| withTempView("v") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true", |
| @@ -2533,7 +2574,7 @@ class AdaptiveQueryExecSuite |
| runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + |
| "JOIN skewData3 ON value2 = value3") |
| val shuffles1 = collect(adaptive1) { |
| - case s: ShuffleExchangeExec => s |
| + case s: ShuffleExchangeLike => s |
| } |
| assert(shuffles1.size == 4) |
| val smj1 = findTopLevelSortMergeJoin(adaptive1) |
| @@ -2544,7 +2585,7 @@ class AdaptiveQueryExecSuite |
| runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + |
| "JOIN skewData3 ON value1 = value3") |
| val shuffles2 = collect(adaptive2) { |
| - case s: ShuffleExchangeExec => s |
| + case s: ShuffleExchangeLike => s |
| } |
| assert(shuffles2.size == 4) |
| val smj2 = findTopLevelSortMergeJoin(adaptive2) |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala |
| index bd9c79e5b96..2ada8c28842 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala |
| @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.SchemaPruningTest |
| import org.apache.spark.sql.catalyst.expressions.Concat |
| import org.apache.spark.sql.catalyst.parser.CatalystSqlParser |
| import org.apache.spark.sql.catalyst.plans.logical.Expand |
| +import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec} |
| import org.apache.spark.sql.execution.FileSourceScanExec |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper |
| import org.apache.spark.sql.functions._ |
| @@ -867,6 +868,8 @@ abstract class SchemaPruningSuite |
| val fileSourceScanSchemata = |
| collect(df.queryExecution.executedPlan) { |
| case scan: FileSourceScanExec => scan.requiredSchema |
| + case scan: CometScanExec => scan.requiredSchema |
| + case scan: CometNativeScanExec => scan.requiredSchema |
| } |
| assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size, |
| s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " + |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala |
| index ce43edb79c1..8436cb727c6 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala |
| @@ -17,9 +17,10 @@ |
| |
| package org.apache.spark.sql.execution.datasources |
| |
| -import org.apache.spark.sql.{QueryTest, Row} |
| +import org.apache.spark.sql.{IgnoreComet, QueryTest, Row} |
| import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, NullsFirst, SortOrder} |
| import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sort} |
| +import org.apache.spark.sql.comet.CometSortExec |
| import org.apache.spark.sql.execution.{QueryExecution, SortExec} |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec |
| import org.apache.spark.sql.internal.SQLConf |
| @@ -225,6 +226,7 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write |
| // assert the outer most sort in the executed plan |
| assert(plan.collectFirst { |
| case s: SortExec => s |
| + case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec] |
| }.exists { |
| case SortExec(Seq( |
| SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _), |
| @@ -272,6 +274,7 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write |
| // assert the outer most sort in the executed plan |
| assert(plan.collectFirst { |
| case s: SortExec => s |
| + case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec] |
| }.exists { |
| case SortExec(Seq( |
| SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _), |
| @@ -305,7 +308,8 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write |
| } |
| } |
| |
| - test("v1 write with AQE changing SMJ to BHJ") { |
| + test("v1 write with AQE changing SMJ to BHJ", |
| + IgnoreComet("TODO: Comet SMJ to BHJ by AQE")) { |
| withPlannedWrite { enabled => |
| withTable("t") { |
| sql( |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala |
| index 1d2e467c94c..3ea82cd1a3f 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala |
| @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, GlobFilter, Path} |
| import org.mockito.Mockito.{mock, when} |
| |
| import org.apache.spark.SparkException |
| -import org.apache.spark.sql.{DataFrame, QueryTest, Row} |
| +import org.apache.spark.sql.{DataFrame, IgnoreCometSuite, QueryTest, Row} |
| import org.apache.spark.sql.catalyst.encoders.RowEncoder |
| import org.apache.spark.sql.execution.datasources.PartitionedFile |
| import org.apache.spark.sql.functions.col |
| @@ -38,7 +38,9 @@ import org.apache.spark.sql.test.SharedSparkSession |
| import org.apache.spark.sql.types._ |
| import org.apache.spark.util.Utils |
| |
| -class BinaryFileFormatSuite extends QueryTest with SharedSparkSession { |
| +// For some reason this suite is flaky w/ or w/o Comet when running in a GitHub workflow. |
| +// Since it isn't related to Comet, we disable it for now. |
| +class BinaryFileFormatSuite extends QueryTest with SharedSparkSession with IgnoreCometSuite { |
| import BinaryFileFormat._ |
| |
| private var testDir: String = _ |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala |
| index 07e2849ce6f..3e73645b638 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala |
| @@ -28,7 +28,7 @@ import org.apache.parquet.hadoop.ParquetOutputFormat |
| |
| import org.apache.spark.TestUtils |
| import org.apache.spark.memory.MemoryMode |
| -import org.apache.spark.sql.Row |
| +import org.apache.spark.sql.{IgnoreComet, Row} |
| import org.apache.spark.sql.catalyst.util.DateTimeUtils |
| import org.apache.spark.sql.internal.SQLConf |
| import org.apache.spark.sql.test.SharedSparkSession |
| @@ -201,7 +201,8 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSparkSess |
| } |
| } |
| |
| - test("parquet v2 pages - rle encoding for boolean value columns") { |
| + test("parquet v2 pages - rle encoding for boolean value columns", |
| + IgnoreComet("Comet doesn't support RLE encoding yet")) { |
| val extraOptions = Map[String, String]( |
| ParquetOutputFormat.WRITER_VERSION -> ParquetProperties.WriterVersion.PARQUET_2_0.toString |
| ) |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala |
| index 104b4e416cd..37ea65081e4 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala |
| @@ -1096,7 +1096,11 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared |
| // When a filter is pushed to Parquet, Parquet can apply it to every row. |
| // So, we can check the number of rows returned from the Parquet |
| // to make sure our filter pushdown work. |
| - assert(stripSparkFilter(df).count == 1) |
| + // Similar to Spark's vectorized reader, Comet doesn't do row-level filtering but relies |
| + // on Spark to apply the data filters after columnar batches are returned |
| + if (!isCometEnabled) { |
| + assert(stripSparkFilter(df).count == 1) |
| + } |
| } |
| } |
| } |
| @@ -1499,7 +1503,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared |
| } |
| } |
| |
| - test("Filters should be pushed down for vectorized Parquet reader at row group level") { |
| + test("Filters should be pushed down for vectorized Parquet reader at row group level", |
| + IgnoreCometNativeScan("Native scans do not support the tested accumulator")) { |
| import testImplicits._ |
| |
| withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true", |
| @@ -1581,7 +1586,11 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared |
| // than the total length but should not be a single record. |
| // Note that, if record level filtering is enabled, it should be a single record. |
| // If no filter is pushed down to Parquet, it should be the total length of data. |
| - assert(actual > 1 && actual < data.length) |
| + // Only run this assertion with Comet when it is in scan-only mode, since with native |
| + // execution `stripSparkFilter` can't remove the native filter |
| + if (!isCometEnabled || isCometScanOnly) { |
| + assert(actual > 1 && actual < data.length) |
| + } |
| } |
| } |
| } |
| @@ -1608,7 +1617,11 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared |
| // than the total length but should not be a single record. |
| // Note that, if record level filtering is enabled, it should be a single record. |
| // If no filter is pushed down to Parquet, it should be the total length of data. |
| - assert(actual > 1 && actual < data.length) |
| + // Only run this assertion with Comet when it is in scan-only mode, since with native |
| + // execution `stripSparkFilter` can't remove the native filter |
| + if (!isCometEnabled || isCometScanOnly) { |
| + assert(actual > 1 && actual < data.length) |
| + } |
| } |
| } |
| } |
| @@ -1744,7 +1757,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared |
| } |
| } |
| |
| - test("SPARK-17091: Convert IN predicate to Parquet filter push-down") { |
| + test("SPARK-17091: Convert IN predicate to Parquet filter push-down", |
| + IgnoreCometNativeScan("Comet has different push-down behavior")) { |
| val schema = StructType(Seq( |
| StructField("a", IntegerType, nullable = false) |
| )) |
| @@ -1985,7 +1999,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared |
| } |
| } |
| |
| - test("Support Parquet column index") { |
| + test("Support Parquet column index", |
| + IgnoreComet("Comet doesn't support Parquet column index yet")) { |
| // block 1: |
| // null count min max |
| // page-0 0 0 99 |
| @@ -2045,7 +2060,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared |
| } |
| } |
| |
| - test("SPARK-34562: Bloom filter push down") { |
| + test("SPARK-34562: Bloom filter push down", |
| + IgnoreCometNativeScan("Native scans do not support the tested accumulator")) { |
| withTempPath { dir => |
| val path = dir.getCanonicalPath |
| spark.range(100).selectExpr("id * 2 AS id") |
| @@ -2277,7 +2293,11 @@ class ParquetV1FilterSuite extends ParquetFilterSuite { |
| assert(pushedParquetFilters.exists(_.getClass === filterClass), |
| s"${pushedParquetFilters.map(_.getClass).toList} did not contain ${filterClass}.") |
| |
| - checker(stripSparkFilter(query), expected) |
| + // Similar to Spark's vectorized reader, Comet doesn't do row-level filtering but relies |
| + // on Spark to apply the data filters after columnar batches are returned |
| + if (!isCometEnabled) { |
| + checker(stripSparkFilter(query), expected) |
| + } |
| } else { |
| assert(selectedFilters.isEmpty, "There is filter pushed down") |
| } |
| @@ -2337,7 +2357,11 @@ class ParquetV2FilterSuite extends ParquetFilterSuite { |
| assert(pushedParquetFilters.exists(_.getClass === filterClass), |
| s"${pushedParquetFilters.map(_.getClass).toList} did not contain ${filterClass}.") |
| |
| - checker(stripSparkFilter(query), expected) |
| + // Similar to Spark's vectorized reader, Comet doesn't do row-level filtering but relies |
| + // on Spark to apply the data filters after columnar batches are returned |
| + if (!isCometEnabled) { |
| + checker(stripSparkFilter(query), expected) |
| + } |
| |
| case _ => |
| throw new AnalysisException("Can not match ParquetTable in the query.") |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala |
| index 8670d95c65e..b624c3811dd 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala |
| @@ -1335,7 +1335,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession |
| } |
| } |
| |
| - test("SPARK-40128 read DELTA_LENGTH_BYTE_ARRAY encoded strings") { |
| + test("SPARK-40128 read DELTA_LENGTH_BYTE_ARRAY encoded strings", |
| + IgnoreComet("Comet doesn't support DELTA encoding yet")) { |
| withAllParquetReaders { |
| checkAnswer( |
| // "fruit" column in this file is encoded using DELTA_LENGTH_BYTE_ARRAY. |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala |
| index 29cb224c878..44837aa953b 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala |
| @@ -978,7 +978,8 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS |
| } |
| } |
| |
| - test("SPARK-26677: negated null-safe equality comparison should not filter matched row groups") { |
| + test("SPARK-26677: negated null-safe equality comparison should not filter matched row groups", |
| + IgnoreCometNativeScan("Native scans had the filter pushed into DF operator, cannot strip")) { |
| withAllParquetReaders { |
| withTempPath { path => |
| // Repeated values for dictionary encoding. |
| @@ -1047,7 +1048,8 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS |
| checkAnswer(readParquet(schema, path), df) |
| } |
| |
| - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { |
| + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false", |
| + "spark.comet.enabled" -> "false") { |
| val schema1 = "a DECIMAL(3, 2), b DECIMAL(18, 3), c DECIMAL(37, 3)" |
| checkAnswer(readParquet(schema1, path), df) |
| val schema2 = "a DECIMAL(3, 0), b DECIMAL(18, 1), c DECIMAL(37, 1)" |
| @@ -1069,7 +1071,8 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS |
| val df = sql(s"SELECT 1 a, 123456 b, ${Int.MaxValue.toLong * 10} c, CAST('1.2' AS BINARY) d") |
| df.write.parquet(path.toString) |
| |
| - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { |
| + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false", |
| + "spark.comet.enabled" -> "false") { |
| checkAnswer(readParquet("a DECIMAL(3, 2)", path), sql("SELECT 1.00")) |
| checkAnswer(readParquet("b DECIMAL(3, 2)", path), Row(null)) |
| checkAnswer(readParquet("b DECIMAL(11, 1)", path), sql("SELECT 123456.0")) |
| @@ -1128,7 +1131,7 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS |
| .where(s"a < ${Long.MaxValue}") |
| .collect() |
| } |
| - assert(exception.getCause.getCause.isInstanceOf[SchemaColumnConvertNotSupportedException]) |
| + assert(exception.getMessage.contains("Column: [a], Expected: bigint, Found: INT32")) |
| } |
| } |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala |
| index 240bb4e6dcb..8287ffa03ca 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala |
| @@ -21,7 +21,7 @@ import java.nio.file.{Files, Paths, StandardCopyOption} |
| import java.sql.{Date, Timestamp} |
| |
| import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkUpgradeException} |
| -import org.apache.spark.sql.{QueryTest, Row, SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY} |
| +import org.apache.spark.sql.{IgnoreCometSuite, QueryTest, Row, SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY} |
| import org.apache.spark.sql.catalyst.util.DateTimeTestUtils |
| import org.apache.spark.sql.internal.SQLConf |
| import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, ParquetOutputTimestampType} |
| @@ -30,9 +30,11 @@ import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType.{INT96, |
| import org.apache.spark.sql.test.SharedSparkSession |
| import org.apache.spark.tags.SlowSQLTest |
| |
| +// Comet is disabled for this suite because it doesn't support datetime rebase mode |
| abstract class ParquetRebaseDatetimeSuite |
| extends QueryTest |
| with ParquetTest |
| + with IgnoreCometSuite |
| with SharedSparkSession { |
| |
| import testImplicits._ |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala |
| index 351c6d698fc..583d9225cca 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala |
| @@ -20,12 +20,14 @@ import java.io.File |
| |
| import scala.collection.JavaConverters._ |
| |
| +import org.apache.comet.CometConf |
| import org.apache.hadoop.fs.Path |
| import org.apache.parquet.column.ParquetProperties._ |
| import org.apache.parquet.hadoop.{ParquetFileReader, ParquetOutputFormat} |
| import org.apache.parquet.hadoop.ParquetWriter.DEFAULT_BLOCK_SIZE |
| |
| import org.apache.spark.sql.QueryTest |
| +import org.apache.spark.sql.comet.{CometBatchScanExec, CometScanExec} |
| import org.apache.spark.sql.execution.FileSourceScanExec |
| import org.apache.spark.sql.execution.datasources.FileFormat |
| import org.apache.spark.sql.execution.datasources.v2.BatchScanExec |
| @@ -172,6 +174,8 @@ class ParquetRowIndexSuite extends QueryTest with SharedSparkSession { |
| |
| private def testRowIndexGeneration(label: String, conf: RowIndexTestConf): Unit = { |
| test (s"$label - ${conf.desc}") { |
| + // native_datafusion Parquet scan does not support row index generation. |
| + assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION) |
| withSQLConf(conf.sqlConfs: _*) { |
| withTempPath { path => |
| val rowIndexColName = FileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME |
| @@ -230,6 +234,12 @@ class ParquetRowIndexSuite extends QueryTest with SharedSparkSession { |
| case f: FileSourceScanExec => |
| numPartitions += f.inputRDD.partitions.length |
| numOutputRows += f.metrics("numOutputRows").value |
| + case b: CometScanExec => |
| + numPartitions += b.inputRDD.partitions.length |
| + numOutputRows += b.metrics("numOutputRows").value |
| + case b: CometBatchScanExec => |
| + numPartitions += b.inputRDD.partitions.length |
| + numOutputRows += b.metrics("numOutputRows").value |
| case _ => |
| } |
| assert(numPartitions > 0) |
| @@ -291,6 +301,8 @@ class ParquetRowIndexSuite extends QueryTest with SharedSparkSession { |
| val conf = RowIndexTestConf(useDataSourceV2 = useDataSourceV2) |
| |
| test(s"invalid row index column type - ${conf.desc}") { |
| + // native_datafusion Parquet scan does not support row index generation. |
| + assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION) |
| withSQLConf(conf.sqlConfs: _*) { |
| withTempPath{ path => |
| val df = spark.range(0, 10, 1, 1).toDF("id") |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala |
| index 5c0b7def039..151184bc98c 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala |
| @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet |
| import org.apache.spark.SparkConf |
| import org.apache.spark.sql.DataFrame |
| import org.apache.spark.sql.catalyst.parser.CatalystSqlParser |
| +import org.apache.spark.sql.comet.CometBatchScanExec |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper |
| import org.apache.spark.sql.execution.datasources.SchemaPruningSuite |
| import org.apache.spark.sql.execution.datasources.v2.BatchScanExec |
| @@ -56,6 +57,7 @@ class ParquetV2SchemaPruningSuite extends ParquetSchemaPruningSuite { |
| val fileSourceScanSchemata = |
| collect(df.queryExecution.executedPlan) { |
| case scan: BatchScanExec => scan.scan.asInstanceOf[ParquetScan].readDataSchema |
| + case scan: CometBatchScanExec => scan.scan.asInstanceOf[ParquetScan].readDataSchema |
| } |
| assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size, |
| s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " + |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala |
| index bf5c51b89bb..ca22370ca3b 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala |
| @@ -27,6 +27,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName |
| import org.apache.parquet.schema.Type._ |
| |
| import org.apache.spark.SparkException |
| +import org.apache.spark.sql.IgnoreComet |
| import org.apache.spark.sql.catalyst.ScalaReflection |
| import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException |
| import org.apache.spark.sql.functions.desc |
| @@ -1016,7 +1017,8 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |
| e |
| } |
| |
| - test("schema mismatch failure error message for parquet reader") { |
| + test("schema mismatch failure error message for parquet reader", |
| + IgnoreComet("Comet doesn't work with vectorizedReaderEnabled = false")) { |
| withTempPath { dir => |
| val e = testSchemaMismatch(dir.getCanonicalPath, vectorizedReaderEnabled = false) |
| val expectedMessage = "Encountered error while reading file" |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala |
| index 3a0bd35cb70..b28f06a757f 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala |
| @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.debug |
| import java.io.ByteArrayOutputStream |
| |
| import org.apache.spark.rdd.RDD |
| +import org.apache.spark.sql.IgnoreComet |
| import org.apache.spark.sql.catalyst.InternalRow |
| import org.apache.spark.sql.catalyst.expressions.Attribute |
| import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext |
| @@ -124,7 +125,8 @@ class DebuggingSuite extends DebuggingSuiteBase with DisableAdaptiveExecutionSui |
| | id LongType: {}""".stripMargin)) |
| } |
| |
| - test("SPARK-28537: DebugExec cannot debug columnar related queries") { |
| + test("SPARK-28537: DebugExec cannot debug columnar related queries", |
| + IgnoreComet("Comet does not use FileScan")) { |
| withTempPath { workDir => |
| val workDirPath = workDir.getAbsolutePath |
| val input = spark.range(5).toDF("id") |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala |
| index 26e61c6b58d..cb09d7e116a 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala |
| @@ -45,8 +45,10 @@ import org.apache.spark.sql.util.QueryExecutionListener |
| import org.apache.spark.util.{AccumulatorContext, JsonProtocol} |
| |
| // Disable AQE because metric info is different with AQE on/off |
| +// This test suite runs tests against the metrics of physical operators. |
| +// Disabling it for Comet because the metrics are different with Comet enabled. |
| class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils |
| - with DisableAdaptiveExecutionSuite { |
| + with DisableAdaptiveExecutionSuite with IgnoreCometSuite { |
| import testImplicits._ |
| |
| /** |
| @@ -737,7 +739,8 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils |
| } |
| } |
| |
| - test("SPARK-26327: FileSourceScanExec metrics") { |
| + test("SPARK-26327: FileSourceScanExec metrics", |
| + IgnoreComet("Spark uses row-based Parquet reader while Comet is vectorized")) { |
| withTable("testDataForScan") { |
| spark.range(10).selectExpr("id", "id % 3 as p") |
| .write.partitionBy("p").saveAsTable("testDataForScan") |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala |
| index 0ab8691801d..d9125f658ad 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala |
| @@ -18,6 +18,7 @@ |
| package org.apache.spark.sql.execution.python |
| |
| import org.apache.spark.sql.catalyst.plans.logical.{ArrowEvalPython, BatchEvalPython, Limit, LocalLimit} |
| +import org.apache.spark.sql.comet._ |
| import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan, SparkPlanTest} |
| import org.apache.spark.sql.execution.datasources.v2.BatchScanExec |
| import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan |
| @@ -108,6 +109,7 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession { |
| |
| val scanNodes = query.queryExecution.executedPlan.collect { |
| case scan: FileSourceScanExec => scan |
| + case scan: CometScanExec => scan |
| } |
| assert(scanNodes.length == 1) |
| assert(scanNodes.head.output.map(_.name) == Seq("a")) |
| @@ -120,11 +122,16 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession { |
| |
| val scanNodes = query.queryExecution.executedPlan.collect { |
| case scan: FileSourceScanExec => scan |
| + case scan: CometScanExec => scan |
| } |
| assert(scanNodes.length == 1) |
| // $"a" is not null and $"a" > 1 |
| - assert(scanNodes.head.dataFilters.length == 2) |
| - assert(scanNodes.head.dataFilters.flatMap(_.references.map(_.name)).distinct == Seq("a")) |
| + val dataFilters = scanNodes.head match { |
| + case scan: FileSourceScanExec => scan.dataFilters |
| + case scan: CometScanExec => scan.dataFilters |
| + } |
| + assert(dataFilters.length == 2) |
| + assert(dataFilters.flatMap(_.references.map(_.name)).distinct == Seq("a")) |
| } |
| } |
| } |
| @@ -145,6 +152,7 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession { |
| |
| val scanNodes = query.queryExecution.executedPlan.collect { |
| case scan: BatchScanExec => scan |
| + case scan: CometBatchScanExec => scan |
| } |
| assert(scanNodes.length == 1) |
| assert(scanNodes.head.output.map(_.name) == Seq("a")) |
| @@ -157,6 +165,7 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession { |
| |
| val scanNodes = query.queryExecution.executedPlan.collect { |
| case scan: BatchScanExec => scan |
| + case scan: CometBatchScanExec => scan |
| } |
| assert(scanNodes.length == 1) |
| // $"a" is not null and $"a" > 1 |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala |
| index d083cac48ff..3c11bcde807 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala |
| @@ -37,8 +37,10 @@ import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, |
| import org.apache.spark.sql.streaming.util.StreamManualClock |
| import org.apache.spark.util.Utils |
| |
| +// For some reason this suite is flaky w/ or w/o Comet when running in a GitHub workflow. |
| +// Since it isn't related to Comet, we disable it for now. |
| class AsyncProgressTrackingMicroBatchExecutionSuite |
| - extends StreamTest with BeforeAndAfter with Matchers { |
| + extends StreamTest with BeforeAndAfter with Matchers with IgnoreCometSuite { |
| |
| import testImplicits._ |
| |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala |
| index 266bb343526..f8ad838e2b2 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala |
| @@ -19,15 +19,18 @@ package org.apache.spark.sql.sources |
| |
| import scala.util.Random |
| |
| +import org.apache.comet.CometConf |
| + |
| import org.apache.spark.sql._ |
| import org.apache.spark.sql.catalyst.catalog.BucketSpec |
| import org.apache.spark.sql.catalyst.expressions |
| import org.apache.spark.sql.catalyst.expressions._ |
| import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning |
| -import org.apache.spark.sql.execution.{FileSourceScanExec, SortExec, SparkPlan} |
| +import org.apache.spark.sql.comet._ |
| +import org.apache.spark.sql.execution.{ColumnarToRowExec, FileSourceScanExec, SortExec, SparkPlan} |
| import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, DisableAdaptiveExecution} |
| import org.apache.spark.sql.execution.datasources.BucketingUtils |
| -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec |
| +import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} |
| import org.apache.spark.sql.execution.joins.SortMergeJoinExec |
| import org.apache.spark.sql.functions._ |
| import org.apache.spark.sql.internal.SQLConf |
| @@ -101,12 +104,22 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti |
| } |
| } |
| |
| - private def getFileScan(plan: SparkPlan): FileSourceScanExec = { |
| - val fileScan = collect(plan) { case f: FileSourceScanExec => f } |
| + private def getFileScan(plan: SparkPlan): SparkPlan = { |
| + val fileScan = collect(plan) { |
| + case f: FileSourceScanExec => f |
| + case f: CometScanExec => f |
| + case f: CometNativeScanExec => f |
| + } |
| assert(fileScan.nonEmpty, plan) |
| fileScan.head |
| } |
| |
| + private def getBucketScan(plan: SparkPlan): Boolean = getFileScan(plan) match { |
| + case fs: FileSourceScanExec => fs.bucketedScan |
| + case bs: CometScanExec => bs.bucketedScan |
| + case ns: CometNativeScanExec => ns.bucketedScan |
| + } |
| + |
| // To verify if the bucket pruning works, this function checks two conditions: |
| // 1) Check if the pruned buckets (before filtering) are empty. |
| // 2) Verify the final result is the same as the expected one |
| @@ -155,7 +168,8 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti |
| val planWithoutBucketedScan = bucketedDataFrame.filter(filterCondition) |
| .queryExecution.executedPlan |
| val fileScan = getFileScan(planWithoutBucketedScan) |
| - assert(!fileScan.bucketedScan, s"except no bucketed scan but found\n$fileScan") |
| + val bucketedScan = getBucketScan(planWithoutBucketedScan) |
| + assert(!bucketedScan, s"except no bucketed scan but found\n$fileScan") |
| |
| val bucketColumnType = bucketedDataFrame.schema.apply(bucketColumnIndex).dataType |
| val rowsWithInvalidBuckets = fileScan.execute().filter(row => { |
| @@ -451,28 +465,54 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti |
| val joinOperator = if (joined.sqlContext.conf.adaptiveExecutionEnabled) { |
| val executedPlan = |
| joined.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan |
| - assert(executedPlan.isInstanceOf[SortMergeJoinExec]) |
| - executedPlan.asInstanceOf[SortMergeJoinExec] |
| + executedPlan match { |
| + case s: SortMergeJoinExec => s |
| + case b: CometSortMergeJoinExec => |
| + b.originalPlan match { |
| + case s: SortMergeJoinExec => s |
| + case o => fail(s"expected SortMergeJoinExec, but found\n$o") |
| + } |
| + case o => fail(s"expected SortMergeJoinExec, but found\n$o") |
| + } |
| } else { |
| val executedPlan = joined.queryExecution.executedPlan |
| - assert(executedPlan.isInstanceOf[SortMergeJoinExec]) |
| - executedPlan.asInstanceOf[SortMergeJoinExec] |
| + executedPlan match { |
| + case s: SortMergeJoinExec => s |
| + case ColumnarToRowExec(child) => |
| + child.asInstanceOf[CometSortMergeJoinExec].originalPlan match { |
| + case s: SortMergeJoinExec => s |
| + case o => fail(s"expected SortMergeJoinExec, but found\n$o") |
| + } |
| + case CometColumnarToRowExec(child) => |
| + child.asInstanceOf[CometSortMergeJoinExec].originalPlan match { |
| + case s: SortMergeJoinExec => s |
| + case o => fail(s"expected SortMergeJoinExec, but found\n$o") |
| + } |
| + case CometNativeColumnarToRowExec(child) => |
| + child.asInstanceOf[CometSortMergeJoinExec].originalPlan match { |
| + case s: SortMergeJoinExec => s |
| + case o => fail(s"expected SortMergeJoinExec, but found\n$o") |
| + } |
| + case o => fail(s"expected SortMergeJoinExec, but found\n$o") |
| + } |
| } |
| |
| // check existence of shuffle |
| assert( |
| - joinOperator.left.exists(_.isInstanceOf[ShuffleExchangeExec]) == shuffleLeft, |
| + joinOperator.left.exists(op => op.isInstanceOf[ShuffleExchangeLike]) == shuffleLeft, |
| s"expected shuffle in plan to be $shuffleLeft but found\n${joinOperator.left}") |
| assert( |
| - joinOperator.right.exists(_.isInstanceOf[ShuffleExchangeExec]) == shuffleRight, |
| + joinOperator.right.exists(op => op.isInstanceOf[ShuffleExchangeLike]) == shuffleRight, |
| s"expected shuffle in plan to be $shuffleRight but found\n${joinOperator.right}") |
| |
| // check existence of sort |
| assert( |
| - joinOperator.left.exists(_.isInstanceOf[SortExec]) == sortLeft, |
| + joinOperator.left.exists(op => op.isInstanceOf[SortExec] || op.isInstanceOf[CometExec] && |
| + op.asInstanceOf[CometExec].originalPlan.isInstanceOf[SortExec]) == sortLeft, |
| s"expected sort in the left child to be $sortLeft but found\n${joinOperator.left}") |
| assert( |
| - joinOperator.right.exists(_.isInstanceOf[SortExec]) == sortRight, |
| + joinOperator.right.exists(op => op.isInstanceOf[SortExec] || op.isInstanceOf[CometExec] && |
| + op.asInstanceOf[CometExec].originalPlan.isInstanceOf[SortExec]) == sortRight, |
| s"expected sort in the right child to be $sortRight but found\n${joinOperator.right}") |
| |
| // check the output partitioning |
| @@ -835,11 +875,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti |
| df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") |
| |
| val scanDF = spark.table("bucketed_table").select("j") |
| - assert(!getFileScan(scanDF.queryExecution.executedPlan).bucketedScan) |
| + assert(!getBucketScan(scanDF.queryExecution.executedPlan)) |
| checkAnswer(scanDF, df1.select("j")) |
| |
| val aggDF = spark.table("bucketed_table").groupBy("j").agg(max("k")) |
| - assert(!getFileScan(aggDF.queryExecution.executedPlan).bucketedScan) |
| + assert(!getBucketScan(aggDF.queryExecution.executedPlan)) |
| checkAnswer(aggDF, df1.groupBy("j").agg(max("k"))) |
| } |
| } |
| @@ -894,7 +934,10 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti |
| } |
| |
| test("SPARK-29655 Read bucketed tables obeys spark.sql.shuffle.partitions") { |
| + // Range partitioning uses random samples, so per-partition comparisons do not always yield |
| + // the same results. Disable Comet native range partitioning. |
| withSQLConf( |
| + CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.key -> "false", |
| SQLConf.SHUFFLE_PARTITIONS.key -> "5", |
| SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "7") { |
| val bucketSpec = Some(BucketSpec(6, Seq("i", "j"), Nil)) |
| @@ -913,7 +956,10 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti |
| } |
| |
| test("SPARK-32767 Bucket join should work if SHUFFLE_PARTITIONS larger than bucket number") { |
| + // Range partitioning uses random samples, so per-partition comparisons do not always yield |
| + // the same results. Disable Comet native range partitioning. |
| withSQLConf( |
| + CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.key -> "false", |
| SQLConf.SHUFFLE_PARTITIONS.key -> "9", |
| SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10") { |
| |
| @@ -943,7 +989,10 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti |
| } |
| |
| test("bucket coalescing eliminates shuffle") { |
| + // Range partitioning uses random samples, so per-partition comparisons do not always yield |
| + // the same results. Disable Comet native range partitioning. |
| withSQLConf( |
| + CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.key -> "false", |
| SQLConf.COALESCE_BUCKETS_IN_JOIN_ENABLED.key -> "true", |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { |
| // The side with bucketedTableTestSpec1 will be coalesced to have 4 output partitions. |
| @@ -1026,15 +1075,26 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti |
| expectedNumShuffles: Int, |
| expectedCoalescedNumBuckets: Option[Int]): Unit = { |
| val plan = sql(query).queryExecution.executedPlan |
| - val shuffles = plan.collect { case s: ShuffleExchangeExec => s } |
| + val shuffles = plan.collect { |
| + case s: ShuffleExchangeLike => s |
| + } |
| assert(shuffles.length == expectedNumShuffles) |
| |
| val scans = plan.collect { |
| case f: FileSourceScanExec if f.optionalNumCoalescedBuckets.isDefined => f |
| + case b: CometScanExec if b.optionalNumCoalescedBuckets.isDefined => b |
| + case b: CometNativeScanExec if b.optionalNumCoalescedBuckets.isDefined => b |
| } |
| if (expectedCoalescedNumBuckets.isDefined) { |
| assert(scans.length == 1) |
| - assert(scans.head.optionalNumCoalescedBuckets == expectedCoalescedNumBuckets) |
| + scans.head match { |
| + case f: FileSourceScanExec => |
| + assert(f.optionalNumCoalescedBuckets == expectedCoalescedNumBuckets) |
| + case b: CometScanExec => |
| + assert(b.optionalNumCoalescedBuckets == expectedCoalescedNumBuckets) |
| + case b: CometNativeScanExec => |
| + assert(b.optionalNumCoalescedBuckets == expectedCoalescedNumBuckets) |
| + } |
| } else { |
| assert(scans.isEmpty) |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala |
| index b5f6d2f9f68..277784a92af 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala |
| @@ -20,7 +20,7 @@ package org.apache.spark.sql.sources |
| import java.io.File |
| |
| import org.apache.spark.SparkException |
| -import org.apache.spark.sql.AnalysisException |
| +import org.apache.spark.sql.{AnalysisException, IgnoreCometSuite} |
| import org.apache.spark.sql.catalyst.TableIdentifier |
| import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTableType} |
| import org.apache.spark.sql.catalyst.parser.ParseException |
| @@ -28,7 +28,10 @@ import org.apache.spark.sql.internal.SQLConf.BUCKETING_MAX_BUCKETS |
| import org.apache.spark.sql.test.SharedSparkSession |
| import org.apache.spark.util.Utils |
| |
| -class CreateTableAsSelectSuite extends DataSourceTest with SharedSparkSession { |
| +// For some reason this suite is flaky w/ or w/o Comet when running in a GitHub workflow. |
| +// Since it isn't related to Comet, we disable it for now. |
| +class CreateTableAsSelectSuite extends DataSourceTest with SharedSparkSession |
| + with IgnoreCometSuite { |
| import testImplicits._ |
| |
| protected override lazy val sql = spark.sql _ |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala |
| index 1f55742cd67..f20129d9dd8 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala |
| @@ -20,6 +20,7 @@ package org.apache.spark.sql.sources |
| import org.apache.spark.sql.QueryTest |
| import org.apache.spark.sql.catalyst.expressions.AttributeReference |
| import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning |
| +import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec} |
| import org.apache.spark.sql.execution.FileSourceScanExec |
| import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} |
| import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec |
| @@ -71,7 +72,11 @@ abstract class DisableUnnecessaryBucketedScanSuite |
| |
| def checkNumBucketedScan(query: String, expectedNumBucketedScan: Int): Unit = { |
| val plan = sql(query).queryExecution.executedPlan |
| - val bucketedScan = collect(plan) { case s: FileSourceScanExec if s.bucketedScan => s } |
| + val bucketedScan = collect(plan) { |
| + case s: FileSourceScanExec if s.bucketedScan => s |
| + case s: CometScanExec if s.bucketedScan => s |
| + case s: CometNativeScanExec if s.bucketedScan => s |
| + } |
| assert(bucketedScan.length == expectedNumBucketedScan) |
| } |
| |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala |
| index 75f440caefc..36b1146bc3a 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala |
| @@ -34,6 +34,7 @@ import org.apache.spark.paths.SparkPath |
| import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} |
| import org.apache.spark.sql.{AnalysisException, DataFrame} |
| import org.apache.spark.sql.catalyst.util.stringToFile |
| +import org.apache.spark.sql.comet.CometBatchScanExec |
| import org.apache.spark.sql.execution.DataSourceScanExec |
| import org.apache.spark.sql.execution.datasources._ |
| import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, FileScan, FileTable} |
| @@ -748,6 +749,8 @@ class FileStreamSinkV2Suite extends FileStreamSinkSuite { |
| val fileScan = df.queryExecution.executedPlan.collect { |
| case batch: BatchScanExec if batch.scan.isInstanceOf[FileScan] => |
| batch.scan.asInstanceOf[FileScan] |
| + case batch: CometBatchScanExec if batch.scan.isInstanceOf[FileScan] => |
| + batch.scan.asInstanceOf[FileScan] |
| }.headOption.getOrElse { |
| fail(s"No FileScan in query\n${df.queryExecution}") |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala |
| index 6aa7d0945c7..ad26ad833e2 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala |
| @@ -46,6 +46,7 @@ case class RunningCount(count: Long) |
| |
| case class Result(key: Long, count: Int) |
| |
| +// TODO: fix Comet to enable this suite |
| @SlowSQLTest |
| class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { |
| |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala |
| index ef5b8a769fe..84fe1bfabc9 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala |
| @@ -37,6 +37,7 @@ import org.apache.spark.sql._ |
| import org.apache.spark.sql.catalyst.plans.logical.{Range, RepartitionByExpression} |
| import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRelationV2} |
| import org.apache.spark.sql.catalyst.util.DateTimeUtils |
| +import org.apache.spark.sql.comet.CometLocalLimitExec |
| import org.apache.spark.sql.execution.{LocalLimitExec, SimpleMode, SparkPlan} |
| import org.apache.spark.sql.execution.command.ExplainCommand |
| import org.apache.spark.sql.execution.streaming._ |
| @@ -1103,11 +1104,12 @@ class StreamSuite extends StreamTest { |
| val localLimits = execPlan.collect { |
| case l: LocalLimitExec => l |
| case l: StreamingLocalLimitExec => l |
| + case l: CometLocalLimitExec => l |
| } |
| |
| require( |
| localLimits.size == 1, |
| - s"Cant verify local limit optimization with this plan:\n$execPlan") |
| + s"Cant verify local limit optimization ${localLimits.size} with this plan:\n$execPlan") |
| |
| if (expectStreamingLimit) { |
| assert( |
| @@ -1115,7 +1117,8 @@ class StreamSuite extends StreamTest { |
| s"Local limit was not StreamingLocalLimitExec:\n$execPlan") |
| } else { |
| assert( |
| - localLimits.head.isInstanceOf[LocalLimitExec], |
| + localLimits.head.isInstanceOf[LocalLimitExec] || |
| + localLimits.head.isInstanceOf[CometLocalLimitExec], |
| s"Local limit was not LocalLimitExec:\n$execPlan") |
| } |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala |
| index b4c4ec7acbf..20579284856 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala |
| @@ -23,6 +23,7 @@ import org.apache.commons.io.FileUtils |
| import org.scalatest.Assertions |
| |
| import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution |
| +import org.apache.spark.sql.comet.CometHashAggregateExec |
| import org.apache.spark.sql.execution.aggregate.BaseAggregateExec |
| import org.apache.spark.sql.execution.streaming.{MemoryStream, StateStoreRestoreExec, StateStoreSaveExec} |
| import org.apache.spark.sql.functions.count |
| @@ -67,6 +68,7 @@ class StreamingAggregationDistributionSuite extends StreamTest |
| // verify aggregations in between, except partial aggregation |
| val allAggregateExecs = query.lastExecution.executedPlan.collect { |
| case a: BaseAggregateExec => a |
| + case c: CometHashAggregateExec => c.originalPlan |
| } |
| |
| val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter { |
| @@ -201,6 +203,7 @@ class StreamingAggregationDistributionSuite extends StreamTest |
| // verify aggregations in between, except partial aggregation |
| val allAggregateExecs = executedPlan.collect { |
| case a: BaseAggregateExec => a |
| + case c: CometHashAggregateExec => c.originalPlan |
| } |
| |
| val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter { |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala |
| index 4d92e270539..33f1c2eb75e 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala |
| @@ -31,7 +31,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation |
| import org.apache.spark.sql.{DataFrame, Row, SparkSession} |
| import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} |
| import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning |
| -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec |
| +import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike |
| import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec, StreamingSymmetricHashJoinHelper} |
| import org.apache.spark.sql.execution.streaming.state.{RocksDBStateStoreProvider, StateStore, StateStoreProviderId} |
| import org.apache.spark.sql.functions._ |
| @@ -619,14 +619,28 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { |
| |
| val numPartitions = spark.sqlContext.conf.getConf(SQLConf.SHUFFLE_PARTITIONS) |
| |
| - assert(query.lastExecution.executedPlan.collect { |
| - case j @ StreamingSymmetricHashJoinExec(_, _, _, _, _, _, _, _, _, |
| - ShuffleExchangeExec(opA: HashPartitioning, _, _), |
| - ShuffleExchangeExec(opB: HashPartitioning, _, _)) |
| - if partitionExpressionsColumns(opA.expressions) === Seq("a", "b") |
| - && partitionExpressionsColumns(opB.expressions) === Seq("a", "b") |
| - && opA.numPartitions == numPartitions && opB.numPartitions == numPartitions => j |
| - }.size == 1) |
| + val join = query.lastExecution.executedPlan.collect { |
| + case j: StreamingSymmetricHashJoinExec => j |
| + }.head |
| + val opA = join.left.collect { |
| + case s: ShuffleExchangeLike |
| + if s.outputPartitioning.isInstanceOf[HashPartitioning] && |
| + partitionExpressionsColumns( |
| + s.outputPartitioning |
| + .asInstanceOf[HashPartitioning].expressions) === Seq("a", "b") => |
| + s.outputPartitioning |
| + .asInstanceOf[HashPartitioning] |
| + }.head |
| + val opB = join.right.collect { |
| + case s: ShuffleExchangeLike |
| + if s.outputPartitioning.isInstanceOf[HashPartitioning] && |
| + partitionExpressionsColumns( |
| + s.outputPartitioning |
| + .asInstanceOf[HashPartitioning].expressions) === Seq("a", "b") => |
| + s.outputPartitioning |
| + .asInstanceOf[HashPartitioning] |
| + }.head |
| + assert(opA.numPartitions == numPartitions && opB.numPartitions == numPartitions) |
| }) |
| } |
| |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala |
| index abe606ad9c1..2d930b64cca 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala |
| @@ -22,7 +22,7 @@ import java.util |
| |
| import org.scalatest.BeforeAndAfter |
| |
| -import org.apache.spark.sql.{AnalysisException, Row, SaveMode} |
| +import org.apache.spark.sql.{AnalysisException, IgnoreComet, Row, SaveMode} |
| import org.apache.spark.sql.catalyst.TableIdentifier |
| import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException |
| import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} |
| @@ -327,7 +327,8 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter { |
| } |
| } |
| |
| - test("explain with table on DSv1 data source") { |
| + test("explain with table on DSv1 data source", |
| + IgnoreComet("Comet explain output is different")) { |
| val tblSourceName = "tbl_src" |
| val tblTargetName = "tbl_target" |
| val tblSourceQualified = s"default.$tblSourceName" |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala |
| index dd55fcfe42c..a1d390c93d0 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala |
| @@ -27,6 +27,7 @@ import scala.concurrent.duration._ |
| import scala.language.implicitConversions |
| import scala.util.control.NonFatal |
| |
| +import org.apache.comet.CometConf |
| import org.apache.hadoop.fs.Path |
| import org.scalactic.source.Position |
| import org.scalatest.{BeforeAndAfterAll, Suite, Tag} |
| @@ -41,6 +42,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTest |
| import org.apache.spark.sql.catalyst.plans.PlanTestBase |
| import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan |
| import org.apache.spark.sql.catalyst.util._ |
| +import org.apache.spark.sql.comet._ |
| import org.apache.spark.sql.execution.FilterExec |
| import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution |
| import org.apache.spark.sql.execution.datasources.DataSourceUtils |
| @@ -118,7 +120,7 @@ private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with |
| } |
| |
| override protected def test(testName: String, testTags: Tag*)(testFun: => Any) |
| - (implicit pos: Position): Unit = { |
| + (implicit pos: Position): Unit = { |
| if (testTags.exists(_.isInstanceOf[DisableAdaptiveExecution])) { |
| super.test(testName, testTags: _*) { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { |
| @@ -126,7 +128,28 @@ private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with |
| } |
| } |
| } else { |
| - super.test(testName, testTags: _*)(testFun) |
| + if (isCometEnabled && testTags.exists(_.isInstanceOf[IgnoreComet])) { |
| + ignore(testName + " (disabled when Comet is on)", testTags: _*)(testFun) |
| + } else { |
| + val cometScanImpl = CometConf.COMET_NATIVE_SCAN_IMPL.get(conf) |
| + val isNativeIcebergCompat = cometScanImpl == CometConf.SCAN_NATIVE_ICEBERG_COMPAT || |
| + cometScanImpl == CometConf.SCAN_AUTO |
| + val isNativeDataFusion = cometScanImpl == CometConf.SCAN_NATIVE_DATAFUSION || |
| + cometScanImpl == CometConf.SCAN_AUTO |
| + if (isCometEnabled && isNativeIcebergCompat && |
| + testTags.exists(_.isInstanceOf[IgnoreCometNativeIcebergCompat])) { |
| + ignore(testName + " (disabled for NATIVE_ICEBERG_COMPAT)", testTags: _*)(testFun) |
| + } else if (isCometEnabled && isNativeDataFusion && |
| + testTags.exists(_.isInstanceOf[IgnoreCometNativeDataFusion])) { |
| + ignore(testName + " (disabled for NATIVE_DATAFUSION)", testTags: _*)(testFun) |
| + } else if (isCometEnabled && (isNativeDataFusion || isNativeIcebergCompat) && |
| + testTags.exists(_.isInstanceOf[IgnoreCometNativeScan])) { |
| + ignore(testName + " (disabled for NATIVE_DATAFUSION and NATIVE_ICEBERG_COMPAT)", |
| + testTags: _*)(testFun) |
| + } else { |
| + super.test(testName, testTags: _*)(testFun) |
| + } |
| + } |
| } |
| } |
| |
| @@ -242,6 +265,29 @@ private[sql] trait SQLTestUtilsBase |
| protected override def _sqlContext: SQLContext = self.spark.sqlContext |
| } |
| |
| + /** |
| + * Whether Comet extension is enabled |
| + */ |
| + protected def isCometEnabled: Boolean = SparkSession.isCometEnabled |
| + |
| + /** |
| + * Whether to enable ANSI mode. This is only effective when |
| + * [[isCometEnabled]] returns true. |
| + */ |
| + protected def enableCometAnsiMode: Boolean = { |
| + val v = System.getenv("ENABLE_COMET_ANSI_MODE") |
| + v != null && v.toBoolean |
| + } |
| + |
| + /** |
| + * Whether Spark should only apply Comet scan optimization. This is only effective when |
| + * [[isCometEnabled]] returns true. |
| + */ |
| + protected def isCometScanOnly: Boolean = { |
| + val v = System.getenv("ENABLE_COMET_SCAN_ONLY") |
| + v != null && v.toBoolean |
| + } |
| + |
| protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { |
| SparkSession.setActiveSession(spark) |
| super.withSQLConf(pairs: _*)(f) |
| @@ -434,6 +480,8 @@ private[sql] trait SQLTestUtilsBase |
| val schema = df.schema |
| val withoutFilters = df.queryExecution.executedPlan.transform { |
| case FilterExec(_, child) => child |
| + case CometFilterExec(_, _, _, _, child, _) => child |
| + case CometProjectExec(_, _, _, _, CometFilterExec(_, _, _, _, child, _), _) => child |
| } |
| |
| spark.internalCreateDataFrame(withoutFilters.execute(), schema) |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala |
| index ed2e309fa07..a5ea58146ad 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala |
| @@ -74,6 +74,31 @@ trait SharedSparkSessionBase |
| // this rule may potentially block testing of other optimization rules such as |
| // ConstantPropagation etc. |
| .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName) |
| + // Enable Comet if `ENABLE_COMET` environment variable is set |
| + if (isCometEnabled) { |
| + conf |
| + .set("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions") |
| + .set("spark.comet.enabled", "true") |
| + .set("spark.comet.parquet.respectFilterPushdown", "true") |
| + |
| + if (!isCometScanOnly) { |
| + conf |
| + .set("spark.comet.exec.enabled", "true") |
| + .set("spark.shuffle.manager", |
| + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") |
| + .set("spark.comet.exec.shuffle.enabled", "true") |
| + .set("spark.comet.memoryOverhead", "10g") |
| + } else { |
| + conf |
| + .set("spark.comet.exec.enabled", "false") |
| + .set("spark.comet.exec.shuffle.enabled", "false") |
| + } |
| + |
| + if (enableCometAnsiMode) { |
| + conf |
| + .set("spark.sql.ansi.enabled", "true") |
| + } |
| + } |
| conf.set( |
| StaticSQLConf.WAREHOUSE_PATH, |
| conf.get(StaticSQLConf.WAREHOUSE_PATH) + "/" + getClass.getCanonicalName) |
| diff --git a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala |
| index 1510e8957f9..7618419d8ff 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala |
| @@ -43,7 +43,7 @@ class SqlResourceWithActualMetricsSuite |
| import testImplicits._ |
| |
| // Exclude nodes which may not have the metrics |
| - val excludedNodes = List("WholeStageCodegen", "Project", "SerializeFromObject") |
| + val excludedNodes = List("WholeStageCodegen", "Project", "SerializeFromObject", "RowToColumnar") |
| |
| implicit val formats = new DefaultFormats { |
| override def dateFormatter = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss") |
| diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala |
| index 52abd248f3a..7a199931a08 100644 |
| --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala |
| +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala |
| @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive |
| |
| import org.apache.spark.sql._ |
| import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression} |
| +import org.apache.spark.sql.comet._ |
| import org.apache.spark.sql.execution._ |
| import org.apache.spark.sql.execution.adaptive.{DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} |
| import org.apache.spark.sql.hive.execution.HiveTableScanExec |
| @@ -35,6 +36,9 @@ abstract class DynamicPartitionPruningHiveScanSuiteBase |
| case s: FileSourceScanExec => s.partitionFilters.collect { |
| case d: DynamicPruningExpression => d.child |
| } |
| + case s: CometScanExec => s.partitionFilters.collect { |
| + case d: DynamicPruningExpression => d.child |
| + } |
| case h: HiveTableScanExec => h.partitionPruningPred.collect { |
| case d: DynamicPruningExpression => d.child |
| } |
| diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala |
| index de3b1ffccf0..2a76d127093 100644 |
| --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala |
| +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala |
| @@ -23,14 +23,15 @@ import java.util.concurrent.{Executors, TimeUnit} |
| import org.scalatest.BeforeAndAfterEach |
| |
| import org.apache.spark.metrics.source.HiveCatalogMetrics |
| -import org.apache.spark.sql.QueryTest |
| +import org.apache.spark.sql.{IgnoreCometSuite, QueryTest} |
| import org.apache.spark.sql.execution.datasources.FileStatusCache |
| import org.apache.spark.sql.hive.test.TestHiveSingleton |
| import org.apache.spark.sql.internal.SQLConf |
| import org.apache.spark.sql.test.SQLTestUtils |
| |
| class PartitionedTablePerfStatsSuite |
| - extends QueryTest with TestHiveSingleton with SQLTestUtils with BeforeAndAfterEach { |
| + extends QueryTest with TestHiveSingleton with SQLTestUtils with BeforeAndAfterEach |
| + with IgnoreCometSuite { |
| |
| override def beforeEach(): Unit = { |
| super.beforeEach() |
| diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala |
| index a902cb3a69e..800a3acbe99 100644 |
| --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala |
| +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala |
| @@ -24,6 +24,7 @@ import java.sql.{Date, Timestamp} |
| import java.util.{Locale, Set} |
| |
| import com.google.common.io.Files |
| +import org.apache.comet.CometConf |
| import org.apache.hadoop.fs.{FileSystem, Path} |
| |
| import org.apache.spark.{SparkException, TestUtils} |
| @@ -838,8 +839,13 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi |
| } |
| |
| test("SPARK-2554 SumDistinct partial aggregation") { |
| - checkAnswer(sql("SELECT sum( distinct key) FROM src group by key order by key"), |
| - sql("SELECT distinct key FROM src order by key").collect().toSeq) |
| + // Range partitioning uses random samples, so per-partition comparisons do not always yield |
| + // the same results. Disable Comet native range partitioning. |
| + withSQLConf(CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.key -> "false") |
| + { |
| + checkAnswer(sql("SELECT sum( distinct key) FROM src group by key order by key"), |
| + sql("SELECT distinct key FROM src order by key").collect().toSeq) |
| + } |
| } |
| |
| test("SPARK-4963 DataFrame sample on mutable row return wrong result") { |
| diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala |
| index 07361cfdce9..97dab2a3506 100644 |
| --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala |
| +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala |
| @@ -55,25 +55,54 @@ object TestHive |
| new SparkContext( |
| System.getProperty("spark.sql.test.master", "local[1]"), |
| "TestSQLContext", |
| - new SparkConf() |
| - .set("spark.sql.test", "") |
| - .set(SQLConf.CODEGEN_FALLBACK.key, "false") |
| - .set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString) |
| - .set(HiveUtils.HIVE_METASTORE_BARRIER_PREFIXES.key, |
| - "org.apache.spark.sql.hive.execution.PairSerDe") |
| - .set(WAREHOUSE_PATH.key, TestHiveContext.makeWarehouseDir().toURI.getPath) |
| - // SPARK-8910 |
| - .set(UI_ENABLED, false) |
| - .set(config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true) |
| - // Hive changed the default of hive.metastore.disallow.incompatible.col.type.changes |
| - // from false to true. For details, see the JIRA HIVE-12320 and HIVE-17764. |
| - .set("spark.hadoop.hive.metastore.disallow.incompatible.col.type.changes", "false") |
| - // Disable ConvertToLocalRelation for better test coverage. Test cases built on |
| - // LocalRelation will exercise the optimization rules better by disabling it as |
| - // this rule may potentially block testing of other optimization rules such as |
| - // ConstantPropagation etc. |
| - .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName))) |
| + { |
| + val conf = new SparkConf() |
| + .set("spark.sql.test", "") |
| + .set(SQLConf.CODEGEN_FALLBACK.key, "false") |
| + .set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString) |
| + .set(HiveUtils.HIVE_METASTORE_BARRIER_PREFIXES.key, |
| + "org.apache.spark.sql.hive.execution.PairSerDe") |
| + .set(WAREHOUSE_PATH.key, TestHiveContext.makeWarehouseDir().toURI.getPath) |
| + // SPARK-8910 |
| + .set(UI_ENABLED, false) |
| + .set(config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true) |
| + // Hive changed the default of hive.metastore.disallow.incompatible.col.type.changes |
| + // from false to true. For details, see the JIRA HIVE-12320 and HIVE-17764. |
| + .set("spark.hadoop.hive.metastore.disallow.incompatible.col.type.changes", "false") |
| + // Disable ConvertToLocalRelation for better test coverage. Test cases built on |
| + // LocalRelation will exercise the optimization rules better by disabling it as |
| + // this rule may potentially block testing of other optimization rules such as |
| + // ConstantPropagation etc. |
| + .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName) |
| + |
| + if (SparkSession.isCometEnabled) { |
| + conf |
| + .set("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions") |
| + .set("spark.comet.enabled", "true") |
| + |
| + val v = System.getenv("ENABLE_COMET_SCAN_ONLY") |
| + if (v == null || !v.toBoolean) { |
| + conf |
| + .set("spark.comet.exec.enabled", "true") |
| + .set("spark.shuffle.manager", |
| + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") |
| + .set("spark.comet.exec.shuffle.enabled", "true") |
| + } else { |
| + conf |
| + .set("spark.comet.exec.enabled", "false") |
| + .set("spark.comet.exec.shuffle.enabled", "false") |
| + } |
| + |
| + val a = System.getenv("ENABLE_COMET_ANSI_MODE") |
| + if (a != null && a.toBoolean) { |
| + conf |
| + .set("spark.sql.ansi.enabled", "true") |
| + } |
| + } |
| |
| + conf |
| + } |
| + )) |
| |
| case class TestHiveVersion(hiveClient: HiveClient) |
| extends TestHiveContext(TestHive.sparkContext, hiveClient) |