| diff --git a/pom.xml b/pom.xml |
| index edd2ad57880..77a975ea48f 100644 |
| --- a/pom.xml |
| +++ b/pom.xml |
| @@ -152,6 +152,8 @@ |
| --> |
| <ivy.version>2.5.1</ivy.version> |
| <oro.version>2.0.8</oro.version> |
| + <spark.version.short>3.5</spark.version.short> |
| + <comet.version>0.14.0-SNAPSHOT</comet.version> |
| <!-- |
| If you changes codahale.metrics.version, you also need to change |
| the link to metrics.dropwizard.io in docs/monitoring.md. |
| @@ -2840,6 +2842,25 @@ |
| <artifactId>okio</artifactId> |
| <version>${okio.version}</version> |
| </dependency> |
| + <dependency> |
| + <groupId>org.apache.datafusion</groupId> |
| + <artifactId>comet-spark-spark${spark.version.short}_${scala.binary.version}</artifactId> |
| + <version>${comet.version}</version> |
| + <exclusions> |
| + <exclusion> |
| + <groupId>org.apache.spark</groupId> |
| + <artifactId>spark-sql_${scala.binary.version}</artifactId> |
| + </exclusion> |
| + <exclusion> |
| + <groupId>org.apache.spark</groupId> |
| + <artifactId>spark-core_${scala.binary.version}</artifactId> |
| + </exclusion> |
| + <exclusion> |
| + <groupId>org.apache.spark</groupId> |
| + <artifactId>spark-catalyst_${scala.binary.version}</artifactId> |
| + </exclusion> |
| + </exclusions> |
| + </dependency> |
| </dependencies> |
| </dependencyManagement> |
| |
| diff --git a/sql/core/pom.xml b/sql/core/pom.xml |
| index bc00c448b80..82068d7a2eb 100644 |
| --- a/sql/core/pom.xml |
| +++ b/sql/core/pom.xml |
| @@ -77,6 +77,10 @@ |
| <groupId>org.apache.spark</groupId> |
| <artifactId>spark-tags_${scala.binary.version}</artifactId> |
| </dependency> |
| + <dependency> |
| + <groupId>org.apache.datafusion</groupId> |
| + <artifactId>comet-spark-spark${spark.version.short}_${scala.binary.version}</artifactId> |
| + </dependency> |
| |
| <!-- |
| This spark-tags test-dep is needed even though it isn't used in this module, otherwise testing-cmds that exclude |
| diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala |
| index 27ae10b3d59..78e69902dfd 100644 |
| --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala |
| +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala |
| @@ -1353,6 +1353,14 @@ object SparkSession extends Logging { |
| } |
| } |
| |
| + private def loadCometExtension(sparkContext: SparkContext): Seq[String] = { |
| + if (sparkContext.getConf.getBoolean("spark.comet.enabled", isCometEnabled)) { |
| + Seq("org.apache.comet.CometSparkSessionExtensions") |
| + } else { |
| + Seq.empty |
| + } |
| + } |
| + |
| /** |
| * Initialize extensions specified in [[StaticSQLConf]]. The classes will be applied to the |
| * extensions passed into this function. |
| @@ -1362,6 +1370,7 @@ object SparkSession extends Logging { |
| extensions: SparkSessionExtensions): SparkSessionExtensions = { |
| val extensionConfClassNames = sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS) |
| .getOrElse(Seq.empty) |
| + val extensionClassNames = extensionConfClassNames ++ loadCometExtension(sparkContext) |
| -    extensionConfClassNames.foreach { extensionConfClassName => |
| +    extensionClassNames.foreach { extensionConfClassName => |
| try { |
| val extensionConfClass = Utils.classForName(extensionConfClassName) |
| @@ -1396,4 +1405,12 @@ object SparkSession extends Logging { |
| } |
| } |
| } |
| + |
| + /** |
| + * Whether Comet extension is enabled |
| + */ |
| + def isCometEnabled: Boolean = { |
| + val v = System.getenv("ENABLE_COMET") |
| + v == null || v.toBoolean |
| + } |
| } |
| diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala |
| index db587dd9868..aac7295a53d 100644 |
| --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala |
| +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala |
| @@ -18,6 +18,7 @@ |
| package org.apache.spark.sql.execution |
| |
| import org.apache.spark.annotation.DeveloperApi |
| +import org.apache.spark.sql.comet.CometScanExec |
| import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec} |
| import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec |
| import org.apache.spark.sql.execution.exchange.ReusedExchangeExec |
| @@ -67,6 +68,7 @@ private[execution] object SparkPlanInfo { |
| // dump the file scan metadata (e.g file path) to event log |
| val metadata = plan match { |
| case fileScan: FileSourceScanExec => fileScan.metadata |
| + case cometScan: CometScanExec => cometScan.metadata |
| case _ => Map[String, String]() |
| } |
| new SparkPlanInfo( |
| diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql b/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql |
| index 7aef901da4f..f3d6e18926d 100644 |
| --- a/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql |
| +++ b/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql |
| @@ -2,3 +2,4 @@ |
| |
| --SET spark.sql.adaptive.enabled=true |
| --SET spark.sql.maxMetadataStringLength = 500 |
| +--SET spark.comet.enabled = false |
| diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain-cbo.sql b/sql/core/src/test/resources/sql-tests/inputs/explain-cbo.sql |
| index eeb2180f7a5..afd1b5ec289 100644 |
| --- a/sql/core/src/test/resources/sql-tests/inputs/explain-cbo.sql |
| +++ b/sql/core/src/test/resources/sql-tests/inputs/explain-cbo.sql |
| @@ -1,5 +1,6 @@ |
| --SET spark.sql.cbo.enabled=true |
| --SET spark.sql.maxMetadataStringLength = 500 |
| +--SET spark.comet.enabled = false |
| |
| CREATE TABLE explain_temp1(a INT, b INT) USING PARQUET; |
| CREATE TABLE explain_temp2(c INT, d INT) USING PARQUET; |
| diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain.sql b/sql/core/src/test/resources/sql-tests/inputs/explain.sql |
| index 698ca009b4f..57d774a3617 100644 |
| --- a/sql/core/src/test/resources/sql-tests/inputs/explain.sql |
| +++ b/sql/core/src/test/resources/sql-tests/inputs/explain.sql |
| @@ -1,6 +1,7 @@ |
| --SET spark.sql.codegen.wholeStage = true |
| --SET spark.sql.adaptive.enabled = false |
| --SET spark.sql.maxMetadataStringLength = 500 |
| +--SET spark.comet.enabled = false |
| |
| -- Test tables |
| CREATE table explain_temp1 (key int, val int) USING PARQUET; |
| diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql |
| index 1152d77da0c..f77493f690b 100644 |
| --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql |
| +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql |
| @@ -7,6 +7,9 @@ |
| |
| -- avoid bit-exact output here because operations may not be bit-exact. |
| -- SET extra_float_digits = 0; |
| +-- Disable Comet exec due to floating point precision difference |
| +--SET spark.comet.exec.enabled = false |
| + |
| |
| -- Test aggregate operator with codegen on and off. |
| --CONFIG_DIM1 spark.sql.codegen.wholeStage=true |
| diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql |
| index 41fd4de2a09..44cd244d3b0 100644 |
| --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql |
| +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql |
| @@ -5,6 +5,9 @@ |
| -- AGGREGATES [Part 3] |
| -- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/aggregates.sql#L352-L605 |
| |
| +-- Disable Comet exec due to floating point precision difference |
| +--SET spark.comet.exec.enabled = false |
| + |
| -- Test aggregate operator with codegen on and off. |
| --CONFIG_DIM1 spark.sql.codegen.wholeStage=true |
| --CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY |
| diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql |
| index 3a409eea348..38fed024c98 100644 |
| --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql |
| +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql |
| @@ -69,6 +69,8 @@ SELECT '' AS one, i.* FROM INT4_TBL i WHERE (i.f1 % smallint('2')) = smallint('1 |
| -- any evens |
| SELECT '' AS three, i.* FROM INT4_TBL i WHERE (i.f1 % int('2')) = smallint('0'); |
| |
| +-- https://github.com/apache/datafusion-comet/issues/2215 |
| +--SET spark.comet.exec.enabled=false |
| -- [SPARK-28024] Incorrect value when out of range |
| SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i; |
| |
| diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql |
| index fac23b4a26f..2b73732c33f 100644 |
| --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql |
| +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql |
| @@ -1,6 +1,10 @@ |
| -- |
| -- Portions Copyright (c) 1996-2019, PostgreSQL Global Development Group |
| -- |
| + |
| +-- Disable Comet exec due to floating point precision difference |
| +--SET spark.comet.exec.enabled = false |
| + |
| -- |
| -- INT8 |
| -- Test int8 64-bit integers. |
| diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/select_having.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/select_having.sql |
| index 0efe0877e9b..423d3b3d76d 100644 |
| --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/select_having.sql |
| +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/select_having.sql |
| @@ -1,6 +1,10 @@ |
| -- |
| -- Portions Copyright (c) 1996-2019, PostgreSQL Global Development Group |
| -- |
| + |
| +-- Disable Comet exec due to floating point precision difference |
| +--SET spark.comet.exec.enabled = false |
| + |
| -- |
| -- SELECT_HAVING |
| -- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/select_having.sql |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala |
| index e5494726695..00937f025c2 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala |
| @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants |
| import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, RDDScanExec, SparkPlan} |
| import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEPropagateEmptyRelation} |
| import org.apache.spark.sql.execution.columnar._ |
| -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec |
| +import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike |
| import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate |
| import org.apache.spark.sql.functions._ |
| import org.apache.spark.sql.internal.SQLConf |
| @@ -519,7 +519,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils |
| df.collect() |
| } |
| assert( |
| - collect(df.queryExecution.executedPlan) { case e: ShuffleExchangeExec => e }.size == expected) |
| + collect(df.queryExecution.executedPlan) { |
| + case _: ShuffleExchangeLike => 1 }.size == expected) |
| } |
| |
| test("A cached table preserves the partitioning and ordering of its cached SparkPlan") { |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala |
| index 9e8d77c53f3..855e3ada7d1 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala |
| @@ -790,7 +790,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { |
| } |
| } |
| |
| - test("input_file_name, input_file_block_start, input_file_block_length - FileScanRDD") { |
| + test("input_file_name, input_file_block_start, input_file_block_length - FileScanRDD", |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3312")) { |
| withTempPath { dir => |
| val data = sparkContext.parallelize(0 to 10).toDF("id") |
| data.write.parquet(dir.getCanonicalPath) |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala |
| index 6f3090d8908..c08a60fb0c2 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala |
| @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Expand |
| import org.apache.spark.sql.execution.WholeStageCodegenExec |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper |
| import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} |
| -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec |
| +import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike |
| import org.apache.spark.sql.expressions.Window |
| import org.apache.spark.sql.functions._ |
| import org.apache.spark.sql.internal.SQLConf |
| @@ -793,7 +793,7 @@ class DataFrameAggregateSuite extends QueryTest |
| assert(objHashAggPlans.nonEmpty) |
| |
| val exchangePlans = collect(aggPlan) { |
| - case shuffle: ShuffleExchangeExec => shuffle |
| + case shuffle: ShuffleExchangeLike => shuffle |
| } |
| assert(exchangePlans.length == 1) |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala |
| index 56e9520fdab..917932336df 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala |
| @@ -435,7 +435,9 @@ class DataFrameJoinSuite extends QueryTest |
| |
| withTempDatabase { dbName => |
| withTable(table1Name, table2Name) { |
| - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { |
| + withSQLConf( |
| + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", |
| + "spark.comet.enabled" -> "false") { |
| spark.range(50).write.saveAsTable(s"$dbName.$table1Name") |
| spark.range(100).write.saveAsTable(s"$dbName.$table2Name") |
| |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala |
| index 7ee18df3756..d09f70e5d99 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala |
| @@ -40,11 +40,12 @@ import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation |
| import org.apache.spark.sql.catalyst.parser.ParseException |
| import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LocalRelation, LogicalPlan, OneRowRelation, Statistics} |
| import org.apache.spark.sql.catalyst.util.DateTimeUtils |
| +import org.apache.spark.sql.comet.CometBroadcastExchangeExec |
| import org.apache.spark.sql.connector.FakeV2Provider |
| import org.apache.spark.sql.execution.{FilterExec, LogicalRDD, QueryExecution, SortExec, WholeStageCodegenExec} |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper |
| import org.apache.spark.sql.execution.aggregate.HashAggregateExec |
| -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike} |
| +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeLike} |
| import org.apache.spark.sql.expressions.{Aggregator, Window} |
| import org.apache.spark.sql.functions._ |
| import org.apache.spark.sql.internal.SQLConf |
| @@ -2006,7 +2007,7 @@ class DataFrameSuite extends QueryTest |
| fail("Should not have back to back Aggregates") |
| } |
| atFirstAgg = true |
| - case e: ShuffleExchangeExec => atFirstAgg = false |
| + case e: ShuffleExchangeLike => atFirstAgg = false |
| case _ => |
| } |
| } |
| @@ -2330,7 +2331,7 @@ class DataFrameSuite extends QueryTest |
| checkAnswer(join, df) |
| assert( |
| collect(join.queryExecution.executedPlan) { |
| - case e: ShuffleExchangeExec => true }.size === 1) |
| + case _: ShuffleExchangeLike => true }.size === 1) |
| assert( |
| collect(join.queryExecution.executedPlan) { case e: ReusedExchangeExec => true }.size === 1) |
| val broadcasted = broadcast(join) |
| @@ -2338,10 +2339,12 @@ class DataFrameSuite extends QueryTest |
| checkAnswer(join2, df) |
| assert( |
| collect(join2.queryExecution.executedPlan) { |
| - case e: ShuffleExchangeExec => true }.size == 1) |
| + case _: ShuffleExchangeLike => true }.size == 1) |
| assert( |
| collect(join2.queryExecution.executedPlan) { |
| - case e: BroadcastExchangeExec => true }.size === 1) |
| + case e: BroadcastExchangeExec => true |
| + case _: CometBroadcastExchangeExec => true |
| + }.size === 1) |
| assert( |
| collect(join2.queryExecution.executedPlan) { case e: ReusedExchangeExec => true }.size == 4) |
| } |
| @@ -2901,7 +2904,7 @@ class DataFrameSuite extends QueryTest |
| |
| // Assert that no extra shuffle introduced by cogroup. |
| val exchanges = collect(df3.queryExecution.executedPlan) { |
| - case h: ShuffleExchangeExec => h |
| + case h: ShuffleExchangeLike => h |
| } |
| assert(exchanges.size == 2) |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala |
| index a1d5d579338..c201d39cc78 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala |
| @@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression |
| import org.apache.spark.sql.catalyst.optimizer.TransposeWindow |
| import org.apache.spark.sql.catalyst.plans.logical.{Window => LogicalWindow} |
| import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning |
| +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper |
| -import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, Exchange, ShuffleExchangeExec} |
| +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, Exchange, ShuffleExchangeExec, ShuffleExchangeLike} |
| import org.apache.spark.sql.execution.window.WindowExec |
| import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, Window} |
| import org.apache.spark.sql.functions._ |
| @@ -1187,10 +1188,12 @@ class DataFrameWindowFunctionsSuite extends QueryTest |
| } |
| |
| def isShuffleExecByRequirement( |
| - plan: ShuffleExchangeExec, |
| + plan: ShuffleExchangeLike, |
| desiredClusterColumns: Seq[String]): Boolean = plan match { |
| case ShuffleExchangeExec(op: HashPartitioning, _, ENSURE_REQUIREMENTS, _) => |
| partitionExpressionsColumns(op.expressions) === desiredClusterColumns |
| + case CometShuffleExchangeExec(op: HashPartitioning, _, _, ENSURE_REQUIREMENTS, _, _) => |
| + partitionExpressionsColumns(op.expressions) === desiredClusterColumns |
| case _ => false |
| } |
| |
| @@ -1213,7 +1216,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest |
| val shuffleByRequirement = windowed.queryExecution.executedPlan.exists { |
| case w: WindowExec => |
| w.child.exists { |
| - case s: ShuffleExchangeExec => isShuffleExecByRequirement(s, Seq("key1", "key2")) |
| + case s: ShuffleExchangeLike => isShuffleExecByRequirement(s, Seq("key1", "key2")) |
| case _ => false |
| } |
| case _ => false |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala |
| index c4fb4fa943c..a04b23870a8 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala |
| @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} |
| import org.apache.spark.sql.catalyst.util.sideBySide |
| import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution} |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper |
| -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} |
| +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike} |
| import org.apache.spark.sql.execution.streaming.MemoryStream |
| import org.apache.spark.sql.expressions.UserDefinedFunction |
| import org.apache.spark.sql.functions._ |
| @@ -2288,7 +2288,7 @@ class DatasetSuite extends QueryTest |
| |
| // Assert that no extra shuffle introduced by cogroup. |
| val exchanges = collect(df3.queryExecution.executedPlan) { |
| - case h: ShuffleExchangeExec => h |
| + case h: ShuffleExchangeLike => h |
| } |
| assert(exchanges.size == 2) |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala |
| index f33432ddb6f..42eb9fd1cb7 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala |
| @@ -22,6 +22,7 @@ import org.scalatest.GivenWhenThen |
| import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression} |
| import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._ |
| import org.apache.spark.sql.catalyst.plans.ExistenceJoin |
| +import org.apache.spark.sql.comet.CometScanExec |
| import org.apache.spark.sql.connector.catalog.{InMemoryTableCatalog, InMemoryTableWithV2FilterCatalog} |
| import org.apache.spark.sql.execution._ |
| import org.apache.spark.sql.execution.adaptive._ |
| @@ -262,6 +263,9 @@ abstract class DynamicPartitionPruningSuiteBase |
| case s: BatchScanExec => s.runtimeFilters.collect { |
| case d: DynamicPruningExpression => d.child |
| } |
| + case s: CometScanExec => s.partitionFilters.collect { |
| + case d: DynamicPruningExpression => d.child |
| + } |
| case _ => Nil |
| } |
| } |
| @@ -1027,7 +1031,8 @@ abstract class DynamicPartitionPruningSuiteBase |
| } |
| } |
| |
| - test("avoid reordering broadcast join keys to match input hash partitioning") { |
| + test("avoid reordering broadcast join keys to match input hash partitioning", |
| + IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/1839")) { |
| withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { |
| withTable("large", "dimTwo", "dimThree") { |
| @@ -1215,7 +1220,8 @@ abstract class DynamicPartitionPruningSuiteBase |
| } |
| |
| test("SPARK-32509: Unused Dynamic Pruning filter shouldn't affect " + |
| - "canonicalization and exchange reuse") { |
| + "canonicalization and exchange reuse", |
| + IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/1839")) { |
| withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { |
| withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { |
| val df = sql( |
| @@ -1423,7 +1429,8 @@ abstract class DynamicPartitionPruningSuiteBase |
| } |
| } |
| |
| - test("SPARK-34637: DPP side broadcast query stage is created firstly") { |
| + test("SPARK-34637: DPP side broadcast query stage is created firstly", |
| + IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/1839")) { |
| withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { |
| val df = sql( |
| """ WITH v as ( |
| @@ -1698,7 +1705,8 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat |
| * Check the static scan metrics with and without DPP |
| */ |
| test("static scan metrics", |
| - DisableAdaptiveExecution("DPP in AQE must reuse broadcast")) { |
| + DisableAdaptiveExecution("DPP in AQE must reuse broadcast"), |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3313")) { |
| withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", |
| SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", |
| SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { |
| @@ -1729,6 +1737,8 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat |
| case s: BatchScanExec => |
| // we use f1 col for v2 tables due to schema pruning |
| s.output.exists(_.exists(_.argString(maxFields = 100).contains("f1"))) |
| + case s: CometScanExec => |
| + s.output.exists(_.exists(_.argString(maxFields = 100).contains("fid"))) |
| case _ => false |
| } |
| assert(scanOption.isDefined) |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala |
| index a206e97c353..79813d8e259 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala |
| @@ -280,7 +280,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite |
| } |
| } |
| |
| - test("explain formatted - check presence of subquery in case of DPP") { |
| + test("explain formatted - check presence of subquery in case of DPP", |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3313")) { |
| withTable("df1", "df2") { |
| withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", |
| SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", |
| @@ -467,7 +468,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite |
| } |
| } |
| |
| - test("Explain formatted output for scan operator for datasource V2") { |
| + test("Explain formatted output for scan operator for datasource V2", |
| + IgnoreComet("Comet explain output is different")) { |
| withTempDir { dir => |
| Seq("parquet", "orc", "csv", "json").foreach { fmt => |
| val basePath = dir.getCanonicalPath + "/" + fmt |
| @@ -545,7 +547,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite |
| } |
| } |
| |
| -class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuite { |
| +// Ignored when Comet is enabled. Comet changes expected query plans. |
| +class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuite |
| + with IgnoreCometSuite { |
| import testImplicits._ |
| |
| test("SPARK-35884: Explain Formatted") { |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala |
| index 93275487f29..510e3087e0f 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala |
| @@ -23,6 +23,7 @@ import java.nio.file.{Files, StandardOpenOption} |
| |
| import scala.collection.mutable |
| |
| +import org.apache.comet.CometConf |
| import org.apache.hadoop.conf.Configuration |
| import org.apache.hadoop.fs.{LocalFileSystem, Path} |
| |
| @@ -33,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterTha |
| import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt} |
| import org.apache.spark.sql.catalyst.plans.logical.Filter |
| import org.apache.spark.sql.catalyst.types.DataTypeUtils |
| +import org.apache.spark.sql.comet.{CometBatchScanExec, CometNativeScanExec, CometScanExec, CometSortMergeJoinExec} |
| import org.apache.spark.sql.execution.{FileSourceScanLike, SimpleMode} |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper |
| import org.apache.spark.sql.execution.datasources.FilePartition |
| @@ -250,6 +252,8 @@ class FileBasedDataSourceSuite extends QueryTest |
| case "" => "_LEGACY_ERROR_TEMP_2062" |
| case _ => "_LEGACY_ERROR_TEMP_2055" |
| } |
| + // native_datafusion Parquet scan cannot throw a SparkFileNotFoundException |
| + assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION) |
| checkErrorMatchPVals( |
| exception = intercept[SparkException] { |
| testIgnoreMissingFiles(options) |
| @@ -639,7 +643,8 @@ class FileBasedDataSourceSuite extends QueryTest |
| } |
| |
| Seq("parquet", "orc").foreach { format => |
| - test(s"Spark native readers should respect spark.sql.caseSensitive - ${format}") { |
| + test(s"Spark native readers should respect spark.sql.caseSensitive - ${format}", |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3311")) { |
| withTempDir { dir => |
| val tableName = s"spark_25132_${format}_native" |
| val tableDir = dir.getCanonicalPath + s"/$tableName" |
| @@ -955,6 +960,7 @@ class FileBasedDataSourceSuite extends QueryTest |
| assert(bJoinExec.isEmpty) |
| val smJoinExec = collect(joinedDF.queryExecution.executedPlan) { |
| case smJoin: SortMergeJoinExec => smJoin |
| + case smJoin: CometSortMergeJoinExec => smJoin |
| } |
| assert(smJoinExec.nonEmpty) |
| } |
| @@ -1015,6 +1021,7 @@ class FileBasedDataSourceSuite extends QueryTest |
| |
| val fileScan = df.queryExecution.executedPlan collectFirst { |
| case BatchScanExec(_, f: FileScan, _, _, _, _) => f |
| + case CometBatchScanExec(BatchScanExec(_, f: FileScan, _, _, _, _), _, _) => f |
| } |
| assert(fileScan.nonEmpty) |
| assert(fileScan.get.partitionFilters.nonEmpty) |
| @@ -1056,6 +1063,7 @@ class FileBasedDataSourceSuite extends QueryTest |
| |
| val fileScan = df.queryExecution.executedPlan collectFirst { |
| case BatchScanExec(_, f: FileScan, _, _, _, _) => f |
| + case CometBatchScanExec(BatchScanExec(_, f: FileScan, _, _, _, _), _, _) => f |
| } |
| assert(fileScan.nonEmpty) |
| assert(fileScan.get.partitionFilters.isEmpty) |
| @@ -1240,6 +1248,9 @@ class FileBasedDataSourceSuite extends QueryTest |
| val filters = df.queryExecution.executedPlan.collect { |
| case f: FileSourceScanLike => f.dataFilters |
| case b: BatchScanExec => b.scan.asInstanceOf[FileScan].dataFilters |
| + case b: CometScanExec => b.dataFilters |
| + case b: CometNativeScanExec => b.dataFilters |
| + case b: CometBatchScanExec => b.scan.asInstanceOf[FileScan].dataFilters |
| }.flatten |
| assert(filters.contains(GreaterThan(scan.logicalPlan.output.head, Literal(5L)))) |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IgnoreComet.scala b/sql/core/src/test/scala/org/apache/spark/sql/IgnoreComet.scala |
| new file mode 100644 |
| index 00000000000..1ee842b6f62 |
| --- /dev/null |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/IgnoreComet.scala |
| @@ -0,0 +1,45 @@ |
| +/* |
| + * Licensed to the Apache Software Foundation (ASF) under one or more |
| + * contributor license agreements. See the NOTICE file distributed with |
| + * this work for additional information regarding copyright ownership. |
| + * The ASF licenses this file to You under the Apache License, Version 2.0 |
| + * (the "License"); you may not use this file except in compliance with |
| + * the License. You may obtain a copy of the License at |
| + * |
| + * http://www.apache.org/licenses/LICENSE-2.0 |
| + * |
| + * Unless required by applicable law or agreed to in writing, software |
| + * distributed under the License is distributed on an "AS IS" BASIS, |
| + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| + * See the License for the specific language governing permissions and |
| + * limitations under the License. |
| + */ |
| + |
| +package org.apache.spark.sql |
| + |
| +import org.scalactic.source.Position |
| +import org.scalatest.Tag |
| + |
| +import org.apache.spark.sql.test.SQLTestUtils |
| + |
| +/** |
| + * Tests with this tag will be ignored when Comet is enabled (e.g., via `ENABLE_COMET`). |
| + */ |
| +case class IgnoreComet(reason: String) extends Tag("DisableComet") |
| +case class IgnoreCometNativeIcebergCompat(reason: String) extends Tag("DisableComet") |
| +case class IgnoreCometNativeDataFusion(reason: String) extends Tag("DisableComet") |
| +case class IgnoreCometNativeScan(reason: String) extends Tag("DisableComet") |
| + |
| +/** |
| + * Helper trait that disables Comet for all tests regardless of default config values. |
| + */ |
| +trait IgnoreCometSuite extends SQLTestUtils { |
| + override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit |
| + pos: Position): Unit = { |
| + if (isCometEnabled) { |
| + ignore(testName + " (disabled when Comet is on)", testTags: _*)(testFun) |
| + } else { |
| + super.test(testName, testTags: _*)(testFun) |
| + } |
| + } |
| +} |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala |
| index 7af826583bd..3c3def1eb67 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala |
| @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide |
| import org.apache.spark.sql.catalyst.plans.PlanTest |
| import org.apache.spark.sql.catalyst.plans.logical._ |
| import org.apache.spark.sql.catalyst.rules.RuleExecutor |
| +import org.apache.spark.sql.comet.{CometHashJoinExec, CometSortMergeJoinExec} |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper |
| import org.apache.spark.sql.execution.joins._ |
| import org.apache.spark.sql.internal.SQLConf |
| @@ -362,6 +363,7 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP |
| val executedPlan = df.queryExecution.executedPlan |
| val shuffleHashJoins = collect(executedPlan) { |
| case s: ShuffledHashJoinExec => s |
| + case c: CometHashJoinExec => c.originalPlan.asInstanceOf[ShuffledHashJoinExec] |
| } |
| assert(shuffleHashJoins.size == 1) |
| assert(shuffleHashJoins.head.buildSide == buildSide) |
| @@ -371,6 +373,7 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP |
| val executedPlan = df.queryExecution.executedPlan |
| val shuffleMergeJoins = collect(executedPlan) { |
| case s: SortMergeJoinExec => s |
| + case c: CometSortMergeJoinExec => c |
| } |
| assert(shuffleMergeJoins.size == 1) |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala |
| index 44c8cb92fc3..f098beeca26 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala |
| @@ -31,7 +31,8 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation |
| import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrder} |
| import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} |
| import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, Join, JoinHint, NO_BROADCAST_AND_REPLICATION} |
| -import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec} |
| +import org.apache.spark.sql.comet._ |
| +import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, FilterExec, InputAdapter, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec} |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper |
| import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} |
| import org.apache.spark.sql.execution.joins._ |
| @@ -802,7 +803,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| } |
| } |
| |
| - test("test SortMergeJoin (with spill)") { |
| + test("test SortMergeJoin (with spill)", |
| + IgnoreComet("TODO: Comet SMJ doesn't support spill yet")) { |
| withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1", |
| SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "0", |
| SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD.key -> "1") { |
| @@ -928,10 +930,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| val physical = df.queryExecution.sparkPlan |
| val physicalJoins = physical.collect { |
| case j: SortMergeJoinExec => j |
| + case j: CometSortMergeJoinExec => j.originalPlan.asInstanceOf[SortMergeJoinExec] |
| } |
| val executed = df.queryExecution.executedPlan |
| val executedJoins = collect(executed) { |
| case j: SortMergeJoinExec => j |
| + case j: CometSortMergeJoinExec => j.originalPlan.asInstanceOf[SortMergeJoinExec] |
| } |
| // This only applies to the above tested queries, in which a child SortMergeJoin always |
| // contains the SortOrder required by its parent SortMergeJoin. Thus, SortExec should never |
| @@ -1177,9 +1181,11 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| val plan = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", joinType) |
| .groupBy($"k1").count() |
| .queryExecution.executedPlan |
| - assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size === 1) |
| + assert(collect(plan) { |
| + case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size === 1) |
| // No extra shuffle before aggregate |
| - assert(collect(plan) { case _: ShuffleExchangeExec => true }.size === 2) |
| + assert(collect(plan) { |
| + case _: ShuffleExchangeLike => true }.size === 2) |
| }) |
| } |
| |
| @@ -1196,10 +1202,11 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| .join(df4.hint("SHUFFLE_MERGE"), $"k1" === $"k4", joinType) |
| .queryExecution |
| .executedPlan |
| - assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 2) |
| + assert(collect(plan) { |
| + case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size === 2) |
| assert(collect(plan) { case _: BroadcastHashJoinExec => true }.size === 1) |
| // No extra sort before last sort merge join |
| - assert(collect(plan) { case _: SortExec => true }.size === 3) |
| + assert(collect(plan) { case _: SortExec | _: CometSortExec => true }.size === 3) |
| }) |
| |
| // Test shuffled hash join |
| @@ -1209,10 +1216,13 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| .join(df4.hint("SHUFFLE_MERGE"), $"k1" === $"k4", joinType) |
| .queryExecution |
| .executedPlan |
| - assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 2) |
| - assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size === 1) |
| + assert(collect(plan) { |
| + case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size === 2) |
| + assert(collect(plan) { |
| + case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size === 1) |
| // No extra sort before last sort merge join |
| - assert(collect(plan) { case _: SortExec => true }.size === 3) |
| + assert(collect(plan) { |
| + case _: SortExec | _: CometSortExec => true }.size === 3) |
| }) |
| } |
| |
| @@ -1303,12 +1313,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| inputDFs.foreach { case (df1, df2, joinExprs) => |
| val smjDF = df1.join(df2.hint("SHUFFLE_MERGE"), joinExprs, "full") |
| assert(collect(smjDF.queryExecution.executedPlan) { |
| - case _: SortMergeJoinExec => true }.size === 1) |
| + case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size === 1) |
| val smjResult = smjDF.collect() |
| |
| val shjDF = df1.join(df2.hint("SHUFFLE_HASH"), joinExprs, "full") |
| assert(collect(shjDF.queryExecution.executedPlan) { |
| - case _: ShuffledHashJoinExec => true }.size === 1) |
| + case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size === 1) |
| // Same result between shuffled hash join and sort merge join |
| checkAnswer(shjDF, smjResult) |
| } |
| @@ -1367,12 +1377,14 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| val smjDF = df1.hint("SHUFFLE_MERGE").join(df2, joinExprs, "leftouter") |
| assert(collect(smjDF.queryExecution.executedPlan) { |
| case _: SortMergeJoinExec => true |
| + case _: CometSortMergeJoinExec => true |
| }.size === 1) |
| val smjResult = smjDF.collect() |
| |
| val shjDF = df1.hint("SHUFFLE_HASH").join(df2, joinExprs, "leftouter") |
| assert(collect(shjDF.queryExecution.executedPlan) { |
| case _: ShuffledHashJoinExec => true |
| + case _: CometHashJoinExec => true |
| }.size === 1) |
| // Same result between shuffled hash join and sort merge join |
| checkAnswer(shjDF, smjResult) |
| @@ -1383,12 +1395,14 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| val smjDF = df2.join(df1.hint("SHUFFLE_MERGE"), joinExprs, "rightouter") |
| assert(collect(smjDF.queryExecution.executedPlan) { |
| case _: SortMergeJoinExec => true |
| + case _: CometSortMergeJoinExec => true |
| }.size === 1) |
| val smjResult = smjDF.collect() |
| |
| val shjDF = df2.join(df1.hint("SHUFFLE_HASH"), joinExprs, "rightouter") |
| assert(collect(shjDF.queryExecution.executedPlan) { |
| case _: ShuffledHashJoinExec => true |
| + case _: CometHashJoinExec => true |
| }.size === 1) |
| // Same result between shuffled hash join and sort merge join |
| checkAnswer(shjDF, smjResult) |
| @@ -1432,13 +1446,20 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| assert(shjCodegenDF.queryExecution.executedPlan.collect { |
| case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true |
| case WholeStageCodegenExec(ProjectExec(_, _ : ShuffledHashJoinExec)) => true |
| + case WholeStageCodegenExec(ColumnarToRowExec(InputAdapter(_: CometHashJoinExec))) => |
| + true |
| + case WholeStageCodegenExec(ColumnarToRowExec( |
| + InputAdapter(CometProjectExec(_, _, _, _, _: CometHashJoinExec, _)))) => true |
| + case _: CometHashJoinExec => true |
| }.size === 1) |
| checkAnswer(shjCodegenDF, Seq.empty) |
| |
| withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { |
| val shjNonCodegenDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", joinType) |
| assert(shjNonCodegenDF.queryExecution.executedPlan.collect { |
| - case _: ShuffledHashJoinExec => true }.size === 1) |
| + case _: ShuffledHashJoinExec => true |
| + case _: CometHashJoinExec => true |
| + }.size === 1) |
| checkAnswer(shjNonCodegenDF, Seq.empty) |
| } |
| } |
| @@ -1486,7 +1507,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| val plan = sql(getAggQuery(selectExpr, joinType)).queryExecution.executedPlan |
| assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1) |
| // Have shuffle before aggregation |
| - assert(collect(plan) { case _: ShuffleExchangeExec => true }.size === 1) |
| + assert(collect(plan) { |
| + case _: ShuffleExchangeLike => true }.size === 1) |
| } |
| |
| def getJoinQuery(selectExpr: String, joinType: String): String = { |
| @@ -1515,9 +1537,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| } |
| val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan |
| assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1) |
| - assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 3) |
| + assert(collect(plan) { |
| + case _: SortMergeJoinExec => true |
| + case _: CometSortMergeJoinExec => true |
| + }.size === 3) |
| // No extra sort on left side before last sort merge join |
| - assert(collect(plan) { case _: SortExec => true }.size === 5) |
| + assert(collect(plan) { case _: SortExec | _: CometSortExec => true }.size === 5) |
| } |
| |
| // Test output ordering is not preserved |
| @@ -1526,9 +1551,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| val selectExpr = "/*+ BROADCAST(left_t) */ k1 as k0" |
| val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan |
| assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1) |
| - assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 3) |
| + assert(collect(plan) { |
| + case _: SortMergeJoinExec => true |
| + case _: CometSortMergeJoinExec => true |
| + }.size === 3) |
| // Have sort on left side before last sort merge join |
| - assert(collect(plan) { case _: SortExec => true }.size === 6) |
| + assert(collect(plan) { case _: SortExec | _: CometSortExec => true }.size === 6) |
| } |
| |
| // Test singe partition |
| @@ -1538,7 +1566,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| |FROM range(0, 10, 1, 1) t1 FULL OUTER JOIN range(0, 10, 1, 1) t2 |
| |""".stripMargin) |
| val plan = fullJoinDF.queryExecution.executedPlan |
| - assert(collect(plan) { case _: ShuffleExchangeExec => true}.size == 1) |
| + assert(collect(plan) { |
| + case _: ShuffleExchangeLike => true}.size == 1) |
| checkAnswer(fullJoinDF, Row(100)) |
| } |
| } |
| @@ -1611,6 +1640,9 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| Seq(semiJoinDF, antiJoinDF).foreach { df => |
| assert(collect(df.queryExecution.executedPlan) { |
| case j: ShuffledHashJoinExec if j.ignoreDuplicatedKey == ignoreDuplicatedKey => true |
| + case j: CometHashJoinExec |
| + if j.originalPlan.asInstanceOf[ShuffledHashJoinExec].ignoreDuplicatedKey == |
| + ignoreDuplicatedKey => true |
| }.size == 1) |
| } |
| } |
| @@ -1655,14 +1687,20 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |
| |
| test("SPARK-43113: Full outer join with duplicate stream-side references in condition (SMJ)") { |
| def check(plan: SparkPlan): Unit = { |
| - assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 1) |
| + assert(collect(plan) { |
| + case _: SortMergeJoinExec => true |
| + case _: CometSortMergeJoinExec => true |
| + }.size === 1) |
| } |
| dupStreamSideColTest("MERGE", check) |
| } |
| |
| test("SPARK-43113: Full outer join with duplicate stream-side references in condition (SHJ)") { |
| def check(plan: SparkPlan): Unit = { |
| - assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size === 1) |
| + assert(collect(plan) { |
| + case _: ShuffledHashJoinExec => true |
| + case _: CometHashJoinExec => true |
| + }.size === 1) |
| } |
| dupStreamSideColTest("SHUFFLE_HASH", check) |
| } |
| @@ -1798,7 +1836,8 @@ class ThreadLeakInSortMergeJoinSuite |
| sparkConf.set(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD, 20)) |
| } |
| |
| - test("SPARK-47146: thread leak when doing SortMergeJoin (with spill)") { |
| + test("SPARK-47146: thread leak when doing SortMergeJoin (with spill)", |
| + IgnoreComet("Comet does not support spilling")) { |
| |
| withSQLConf( |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") { |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala |
| index c26757c9cff..d55775f09d7 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala |
| @@ -69,7 +69,7 @@ import org.apache.spark.tags.ExtendedSQLTest |
| * }}} |
| */ |
| // scalastyle:on line.size.limit |
| -trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite { |
| +trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite with IgnoreCometSuite { |
| |
| protected val baseResourcePath = { |
| // use the same way as `SQLQueryTestSuite` to get the resource path |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala |
| index 3cf2bfd17ab..49728c35c42 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala |
| @@ -1521,7 +1521,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark |
| checkAnswer(sql("select -0.001"), Row(BigDecimal("-0.001"))) |
| } |
| |
| - test("external sorting updates peak execution memory") { |
| + test("external sorting updates peak execution memory", |
| + IgnoreComet("TODO: native CometSort does not update peak execution memory")) { |
| AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { |
| sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect() |
| } |
| @@ -4459,7 +4460,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark |
| } |
| |
| test("SPARK-39166: Query context of binary arithmetic should be serialized to executors" + |
| - " when WSCG is off") { |
| + " when WSCG is off", |
| + IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) { |
| withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", |
| SQLConf.ANSI_ENABLED.key -> "true") { |
| withTable("t") { |
| @@ -4480,7 +4482,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark |
| } |
| |
| test("SPARK-39175: Query context of Cast should be serialized to executors" + |
| - " when WSCG is off") { |
| + " when WSCG is off", |
| + IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) { |
| withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", |
| SQLConf.ANSI_ENABLED.key -> "true") { |
| withTable("t") { |
| @@ -4497,14 +4500,19 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark |
| val msg = intercept[SparkException] { |
| sql(query).collect() |
| }.getMessage |
| - assert(msg.contains(query)) |
| + if (!isCometEnabled) { |
| + // Comet's error message does not include the original SQL query |
| + // https://github.com/apache/datafusion-comet/issues/2215 |
| + assert(msg.contains(query)) |
| + } |
| } |
| } |
| } |
| } |
| |
| test("SPARK-39190,SPARK-39208,SPARK-39210: Query context of decimal overflow error should " + |
| - "be serialized to executors when WSCG is off") { |
| + "be serialized to executors when WSCG is off", |
| + IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) { |
| withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", |
| SQLConf.ANSI_ENABLED.key -> "true") { |
| withTable("t") { |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala |
| index fa1a64460fc..1d2e215d6a3 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala |
| @@ -17,6 +17,8 @@ |
| |
| package org.apache.spark.sql |
| |
| +import org.apache.comet.CometConf |
| + |
| import org.apache.spark.{SPARK_DOC_ROOT, SparkIllegalArgumentException, SparkRuntimeException} |
| import org.apache.spark.sql.catalyst.expressions.Cast._ |
| import org.apache.spark.sql.execution.FormattedMode |
| @@ -178,29 +180,31 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { |
| } |
| |
| test("string regex_replace / regex_extract") { |
| - val df = Seq( |
| - ("100-200", "(\\d+)-(\\d+)", "300"), |
| - ("100-200", "(\\d+)-(\\d+)", "400"), |
| - ("100-200", "(\\d+)", "400")).toDF("a", "b", "c") |
| + withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") { |
| + val df = Seq( |
| + ("100-200", "(\\d+)-(\\d+)", "300"), |
| + ("100-200", "(\\d+)-(\\d+)", "400"), |
| + ("100-200", "(\\d+)", "400")).toDF("a", "b", "c") |
| |
| - checkAnswer( |
| - df.select( |
| - regexp_replace($"a", "(\\d+)", "num"), |
| - regexp_replace($"a", $"b", $"c"), |
| - regexp_extract($"a", "(\\d+)-(\\d+)", 1)), |
| - Row("num-num", "300", "100") :: Row("num-num", "400", "100") :: |
| - Row("num-num", "400-400", "100") :: Nil) |
| - |
| - // for testing the mutable state of the expression in code gen. |
| - // This is a hack way to enable the codegen, thus the codegen is enable by default, |
| - // it will still use the interpretProjection if projection followed by a LocalRelation, |
| - // hence we add a filter operator. |
| - // See the optimizer rule `ConvertToLocalRelation` |
| - checkAnswer( |
| - df.filter("isnotnull(a)").selectExpr( |
| - "regexp_replace(a, b, c)", |
| - "regexp_extract(a, b, 1)"), |
| - Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil) |
| + checkAnswer( |
| + df.select( |
| + regexp_replace($"a", "(\\d+)", "num"), |
| + regexp_replace($"a", $"b", $"c"), |
| + regexp_extract($"a", "(\\d+)-(\\d+)", 1)), |
| + Row("num-num", "300", "100") :: Row("num-num", "400", "100") :: |
| + Row("num-num", "400-400", "100") :: Nil) |
| + |
| + // for testing the mutable state of the expression in code gen. |
| + // This is a hack way to enable the codegen, thus the codegen is enable by default, |
| + // it will still use the interpretProjection if projection followed by a LocalRelation, |
| + // hence we add a filter operator. |
| + // See the optimizer rule `ConvertToLocalRelation` |
| + checkAnswer( |
| + df.filter("isnotnull(a)").selectExpr( |
| + "regexp_replace(a, b, c)", |
| + "regexp_extract(a, b, 1)"), |
| + Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil) |
| + } |
| } |
| |
| test("non-matching optional group") { |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala |
| index 04702201f82..5ee11f83ecf 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala |
| @@ -22,10 +22,11 @@ import scala.collection.mutable.ArrayBuffer |
| import org.apache.spark.sql.catalyst.expressions.SubqueryExpression |
| import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} |
| import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Join, LogicalPlan, Project, Sort, Union} |
| +import org.apache.spark.sql.comet.CometScanExec |
| import org.apache.spark.sql.execution._ |
| import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution} |
| import org.apache.spark.sql.execution.datasources.FileScanRDD |
| -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec |
| +import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike |
| import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec} |
| import org.apache.spark.sql.internal.SQLConf |
| import org.apache.spark.sql.test.SharedSparkSession |
| @@ -1599,6 +1600,12 @@ class SubquerySuite extends QueryTest |
| fs.inputRDDs().forall( |
| _.asInstanceOf[FileScanRDD].filePartitions.forall( |
| _.files.forall(_.urlEncodedPath.contains("p=0")))) |
| + case WholeStageCodegenExec(ColumnarToRowExec(InputAdapter( |
| + fs @ CometScanExec(_, _, _, _, partitionFilters, _, _, _, _, _, _)))) => |
| + partitionFilters.exists(ExecSubqueryExpression.hasSubquery) && |
| + fs.inputRDDs().forall( |
| + _.asInstanceOf[FileScanRDD].filePartitions.forall( |
| + _.files.forall(_.urlEncodedPath.contains("p=0")))) |
| case _ => false |
| }) |
| } |
| @@ -2164,7 +2171,7 @@ class SubquerySuite extends QueryTest |
| |
| df.collect() |
| val exchanges = collect(df.queryExecution.executedPlan) { |
| - case s: ShuffleExchangeExec => s |
| + case s: ShuffleExchangeLike => s |
| } |
| assert(exchanges.size === 1) |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala |
| index 9f8e979e3fb..3bc9dab8023 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala |
| @@ -87,7 +87,8 @@ class UDFSuite extends QueryTest with SharedSparkSession { |
| spark.catalog.dropTempView("tmp_table") |
| } |
| |
| - test("SPARK-8005 input_file_name") { |
| + test("SPARK-8005 input_file_name", |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3312")) { |
| withTempPath { dir => |
| val data = sparkContext.parallelize(0 to 10, 2).toDF("id") |
| data.write.parquet(dir.getCanonicalPath) |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala |
| index d269290e616..13726a31e07 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala |
| @@ -24,6 +24,7 @@ import test.org.apache.spark.sql.connector._ |
| |
| import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} |
| import org.apache.spark.sql.catalyst.InternalRow |
| +import org.apache.spark.sql.comet.CometSortExec |
| import org.apache.spark.sql.connector.catalog.{PartitionInternalRow, SupportsRead, Table, TableCapability, TableProvider} |
| import org.apache.spark.sql.connector.catalog.TableCapability._ |
| import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, NamedReference, NullOrdering, SortDirection, SortOrder, Transform} |
| @@ -34,7 +35,7 @@ import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, |
| import org.apache.spark.sql.execution.SortExec |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper |
| import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation} |
| -import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} |
| +import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec, ShuffleExchangeLike} |
| import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector |
| import org.apache.spark.sql.expressions.Window |
| import org.apache.spark.sql.functions._ |
| @@ -269,13 +270,13 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS |
| val groupByColJ = df.groupBy($"j").agg(sum($"i")) |
| checkAnswer(groupByColJ, Seq(Row(2, 8), Row(4, 2), Row(6, 5))) |
| assert(collectFirst(groupByColJ.queryExecution.executedPlan) { |
| - case e: ShuffleExchangeExec => e |
| + case e: ShuffleExchangeLike => e |
| }.isDefined) |
| |
| val groupByIPlusJ = df.groupBy($"i" + $"j").agg(count("*")) |
| checkAnswer(groupByIPlusJ, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1))) |
| assert(collectFirst(groupByIPlusJ.queryExecution.executedPlan) { |
| - case e: ShuffleExchangeExec => e |
| + case e: ShuffleExchangeLike => e |
| }.isDefined) |
| } |
| } |
| @@ -335,10 +336,11 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS |
| |
| val (shuffleExpected, sortExpected) = groupByExpects |
| assert(collectFirst(groupBy.queryExecution.executedPlan) { |
| - case e: ShuffleExchangeExec => e |
| + case e: ShuffleExchangeLike => e |
| }.isDefined === shuffleExpected) |
| assert(collectFirst(groupBy.queryExecution.executedPlan) { |
| case e: SortExec => e |
| + case c: CometSortExec => c |
| }.isDefined === sortExpected) |
| } |
| |
| @@ -353,10 +355,11 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS |
| |
| val (shuffleExpected, sortExpected) = windowFuncExpects |
| assert(collectFirst(windowPartByColIOrderByColJ.queryExecution.executedPlan) { |
| - case e: ShuffleExchangeExec => e |
| + case e: ShuffleExchangeLike => e |
| }.isDefined === shuffleExpected) |
| assert(collectFirst(windowPartByColIOrderByColJ.queryExecution.executedPlan) { |
| case e: SortExec => e |
| + case c: CometSortExec => c |
| }.isDefined === sortExpected) |
| } |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala |
| index cfc8b2cc845..b7c234e1437 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala |
| @@ -19,8 +19,9 @@ package org.apache.spark.sql.connector |
| import scala.collection.mutable.ArrayBuffer |
| |
| import org.apache.spark.SparkConf |
| -import org.apache.spark.sql.{AnalysisException, QueryTest} |
| +import org.apache.spark.sql.{AnalysisException, IgnoreCometNativeDataFusion, QueryTest} |
| import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan |
| +import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec} |
| import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability} |
| import org.apache.spark.sql.connector.read.ScanBuilder |
| import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} |
| @@ -152,7 +153,8 @@ class FileDataSourceV2FallBackSuite extends QueryTest with SharedSparkSession { |
| } |
| } |
| |
| - test("Fallback Parquet V2 to V1") { |
| + test("Fallback Parquet V2 to V1", |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3315")) { |
| Seq("parquet", classOf[ParquetDataSourceV2].getCanonicalName).foreach { format => |
| withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> format) { |
| val commands = ArrayBuffer.empty[(String, LogicalPlan)] |
| @@ -184,7 +186,11 @@ class FileDataSourceV2FallBackSuite extends QueryTest with SharedSparkSession { |
| val df = spark.read.format(format).load(path.getCanonicalPath) |
| checkAnswer(df, inputData.toDF()) |
| assert( |
| - df.queryExecution.executedPlan.exists(_.isInstanceOf[FileSourceScanExec])) |
| + df.queryExecution.executedPlan.exists { |
| + case _: FileSourceScanExec | _: CometScanExec | _: CometNativeScanExec => true |
| + case _ => false |
| + } |
| + ) |
| } |
| } finally { |
| spark.listenerManager.unregister(listener) |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala |
| index 71e030f535e..d5ae6cbf3d5 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala |
| @@ -22,6 +22,7 @@ import org.apache.spark.sql.{DataFrame, Row} |
| import org.apache.spark.sql.catalyst.InternalRow |
| import org.apache.spark.sql.catalyst.expressions.{Literal, TransformExpression} |
| import org.apache.spark.sql.catalyst.plans.physical |
| +import org.apache.spark.sql.comet.CometSortMergeJoinExec |
| import org.apache.spark.sql.connector.catalog.Identifier |
| import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog |
| import org.apache.spark.sql.connector.catalog.functions._ |
| @@ -31,7 +32,7 @@ import org.apache.spark.sql.connector.expressions.Expressions._ |
| import org.apache.spark.sql.execution.SparkPlan |
| import org.apache.spark.sql.execution.datasources.v2.BatchScanExec |
| import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation |
| -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec |
| +import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike |
| import org.apache.spark.sql.execution.joins.SortMergeJoinExec |
| import org.apache.spark.sql.internal.SQLConf |
| import org.apache.spark.sql.internal.SQLConf._ |
| @@ -282,13 +283,14 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { |
| Row("bbb", 20, 250.0), Row("bbb", 20, 350.0), Row("ccc", 30, 400.50))) |
| } |
| |
| - private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeExec] = { |
| + private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeLike] = { |
| // here we skip collecting shuffle operators that are not associated with SMJ |
| collect(plan) { |
| case s: SortMergeJoinExec => s |
| + case c: CometSortMergeJoinExec => c.originalPlan |
| }.flatMap(smj => |
| collect(smj) { |
| - case s: ShuffleExchangeExec => s |
| + case s: ShuffleExchangeLike => s |
| }) |
| } |
| |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala |
| index 12007cd94cd..07020f201fb 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala |
| @@ -21,7 +21,7 @@ package org.apache.spark.sql.connector |
| import java.sql.Date |
| import java.util.Collections |
| |
| -import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, Row} |
| +import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, IgnoreCometSuite, Row} |
| import org.apache.spark.sql.catalyst.expressions.{ApplyFunctionExpression, Cast, Literal} |
| import org.apache.spark.sql.catalyst.expressions.objects.Invoke |
| import org.apache.spark.sql.catalyst.plans.physical |
| @@ -45,7 +45,8 @@ import org.apache.spark.sql.util.QueryExecutionListener |
| import org.apache.spark.tags.SlowSQLTest |
| |
| @SlowSQLTest |
| -class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase { |
| +class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase |
| + with IgnoreCometSuite { |
| import testImplicits._ |
| |
| before { |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala |
| index ae1c0a86a14..1d3b914fd64 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala |
| @@ -27,7 +27,7 @@ import org.apache.hadoop.fs.permission.FsPermission |
| import org.mockito.Mockito.{mock, spy, when} |
| |
| import org.apache.spark._ |
| -import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row, SaveMode} |
| +import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, IgnoreComet, QueryTest, Row, SaveMode} |
| import org.apache.spark.sql.catalyst.FunctionIdentifier |
| import org.apache.spark.sql.catalyst.analysis.{NamedParameter, UnresolvedGenerator} |
| import org.apache.spark.sql.catalyst.expressions.{Grouping, Literal, RowNumber} |
| @@ -256,7 +256,8 @@ class QueryExecutionErrorsSuite |
| } |
| |
| test("INCONSISTENT_BEHAVIOR_CROSS_VERSION: " + |
| - "compatibility with Spark 2.4/3.2 in reading/writing dates") { |
| + "compatibility with Spark 2.4/3.2 in reading/writing dates", |
| + IgnoreComet("Comet doesn't completely support datetime rebase mode yet")) { |
| |
| // Fail to read ancient datetime values. |
| withSQLConf(SQLConf.PARQUET_REBASE_MODE_IN_READ.key -> EXCEPTION.toString) { |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala |
| index 418ca3430bb..eb8267192f8 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala |
| @@ -23,7 +23,7 @@ import scala.util.Random |
| import org.apache.hadoop.fs.Path |
| |
| import org.apache.spark.SparkConf |
| -import org.apache.spark.sql.{DataFrame, QueryTest} |
| +import org.apache.spark.sql.{DataFrame, IgnoreComet, QueryTest} |
| import org.apache.spark.sql.execution.datasources.v2.BatchScanExec |
| import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan |
| import org.apache.spark.sql.internal.SQLConf |
| @@ -195,7 +195,7 @@ class DataSourceV2ScanExecRedactionSuite extends DataSourceScanRedactionTest { |
| } |
| } |
| |
| - test("FileScan description") { |
| + test("FileScan description", IgnoreComet("Comet doesn't use BatchScan")) { |
| Seq("json", "orc", "parquet").foreach { format => |
| withTempPath { path => |
| val dir = path.getCanonicalPath |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala |
| index 743ec41dbe7..9f30d6c8e04 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala |
| @@ -53,6 +53,10 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite with DisableAdaptiv |
| case ColumnarToRowExec(i: InputAdapter) => isScanPlanTree(i.child) |
| case p: ProjectExec => isScanPlanTree(p.child) |
| case f: FilterExec => isScanPlanTree(f.child) |
| + // Comet produces scan plan tree like: |
| + // ColumnarToRow |
| + // +- ReusedExchange |
| + case _: ReusedExchangeExec => false |
| case _: LeafExecNode => true |
| case _ => false |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala |
| index de24b8c82b0..1f835481290 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala |
| @@ -18,7 +18,7 @@ |
| package org.apache.spark.sql.execution |
| |
| import org.apache.spark.rdd.RDD |
| -import org.apache.spark.sql.{execution, DataFrame, Row} |
| +import org.apache.spark.sql.{execution, DataFrame, IgnoreCometSuite, Row} |
| import org.apache.spark.sql.catalyst.InternalRow |
| import org.apache.spark.sql.catalyst.expressions._ |
| import org.apache.spark.sql.catalyst.plans._ |
| @@ -35,7 +35,9 @@ import org.apache.spark.sql.internal.SQLConf |
| import org.apache.spark.sql.test.SharedSparkSession |
| import org.apache.spark.sql.types._ |
| |
| -class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { |
| +// Ignore this suite when Comet is enabled. This suite tests the Spark planner and Comet planner |
| +// comes out with too many differences. Simply ignoring this suite for now. |
| +class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper with IgnoreCometSuite { |
| import testImplicits._ |
| |
| setupTestData() |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala |
| index 9e9d717db3b..73de2b84938 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala |
| @@ -17,7 +17,10 @@ |
| |
| package org.apache.spark.sql.execution |
| |
| +import org.apache.comet.CometConf |
| + |
| import org.apache.spark.sql.{DataFrame, QueryTest, Row} |
| +import org.apache.spark.sql.comet.CometProjectExec |
| import org.apache.spark.sql.connector.SimpleWritableDataSource |
| import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} |
| import org.apache.spark.sql.internal.SQLConf |
| @@ -34,7 +37,10 @@ abstract class RemoveRedundantProjectsSuiteBase |
| private def assertProjectExecCount(df: DataFrame, expected: Int): Unit = { |
| withClue(df.queryExecution) { |
| val plan = df.queryExecution.executedPlan |
| - val actual = collectWithSubqueries(plan) { case p: ProjectExec => p }.size |
| + val actual = collectWithSubqueries(plan) { |
| + case p: ProjectExec => p |
| + case p: CometProjectExec => p |
| + }.size |
| assert(actual == expected) |
| } |
| } |
| @@ -134,12 +140,26 @@ abstract class RemoveRedundantProjectsSuiteBase |
| val df = data.selectExpr("a", "b", "key", "explode(array(key, a, b)) as d").filter("d > 0") |
| df.collect() |
| val plan = df.queryExecution.executedPlan |
| - val numProjects = collectWithSubqueries(plan) { case p: ProjectExec => p }.length |
| + |
| + val numProjects = collectWithSubqueries(plan) { |
| + case p: ProjectExec => p |
| + case p: CometProjectExec => p |
| + }.length |
| |
| // Create a new plan that reverse the GenerateExec output and add a new ProjectExec between |
| // GenerateExec and its child. This is to test if the ProjectExec is removed, the output of |
| // the query will be incorrect. |
| - val newPlan = stripAQEPlan(plan) transform { |
| + |
| + // Comet-specific change to get original Spark plan before applying |
| + // a transformation to add a new ProjectExec |
| + var sparkPlan: SparkPlan = null |
| + withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { |
| + val df = data.selectExpr("a", "b", "key", "explode(array(key, a, b)) as d").filter("d > 0") |
| + df.collect() |
| + sparkPlan = df.queryExecution.executedPlan |
| + } |
| + |
| + val newPlan = stripAQEPlan(sparkPlan) transform { |
| case g @ GenerateExec(_, requiredChildOutput, _, _, child) => |
| g.copy(requiredChildOutput = requiredChildOutput.reverse, |
| child = ProjectExec(requiredChildOutput.reverse, child)) |
| @@ -151,6 +171,7 @@ abstract class RemoveRedundantProjectsSuiteBase |
| // The manually added ProjectExec node shouldn't be removed. |
| assert(collectWithSubqueries(newExecutedPlan) { |
| case p: ProjectExec => p |
| + case p: CometProjectExec => p |
| }.size == numProjects + 1) |
| |
| // Check the original plan's output and the new plan's output are the same. |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala |
| index 005e764cc30..92ec088efab 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala |
| @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution |
| |
| import org.apache.spark.sql.{DataFrame, QueryTest} |
| import org.apache.spark.sql.catalyst.plans.physical.{RangePartitioning, UnknownPartitioning} |
| +import org.apache.spark.sql.comet.CometSortExec |
| import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} |
| import org.apache.spark.sql.execution.joins.ShuffledJoin |
| import org.apache.spark.sql.internal.SQLConf |
| @@ -33,7 +34,7 @@ abstract class RemoveRedundantSortsSuiteBase |
| |
| private def checkNumSorts(df: DataFrame, count: Int): Unit = { |
| val plan = df.queryExecution.executedPlan |
| - assert(collectWithSubqueries(plan) { case s: SortExec => s }.length == count) |
| + assert(collectWithSubqueries(plan) { case _: SortExec | _: CometSortExec => 1 }.length == count) |
| } |
| |
| private def checkSorts(query: String, enabledCount: Int, disabledCount: Int): Unit = { |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala |
| index 47679ed7865..9ffbaecb98e 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala |
| @@ -18,6 +18,7 @@ |
| package org.apache.spark.sql.execution |
| |
| import org.apache.spark.sql.{DataFrame, QueryTest} |
| +import org.apache.spark.sql.comet.CometHashAggregateExec |
| import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} |
| import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} |
| import org.apache.spark.sql.internal.SQLConf |
| @@ -31,7 +32,7 @@ abstract class ReplaceHashWithSortAggSuiteBase |
| private def checkNumAggs(df: DataFrame, hashAggCount: Int, sortAggCount: Int): Unit = { |
| val plan = df.queryExecution.executedPlan |
| assert(collectWithSubqueries(plan) { |
| - case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec) => s |
| +      case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec | _: CometHashAggregateExec) => s |
| }.length == hashAggCount) |
| assert(collectWithSubqueries(plan) { case s: SortAggregateExec => s }.length == sortAggCount) |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala |
| index a1147c16cc8..c7a29496328 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala |
| @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution |
| |
| import org.apache.spark.{SparkArithmeticException, SparkException, SparkFileNotFoundException} |
| import org.apache.spark.sql._ |
| +import org.apache.spark.sql.IgnoreCometNativeDataFusion |
| import org.apache.spark.sql.catalyst.TableIdentifier |
| import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Divide} |
| import org.apache.spark.sql.catalyst.parser.ParseException |
| @@ -968,7 +969,8 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { |
| } |
| } |
| |
| - test("alter temporary view should follow current storeAnalyzedPlanForView config") { |
| + test("alter temporary view should follow current storeAnalyzedPlanForView config", |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3314")) { |
| withTable("t") { |
| Seq(2, 3, 1).toDF("c1").write.format("parquet").saveAsTable("t") |
| withView("v1") { |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala |
| index eec396b2e39..bf3f1c769d6 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala |
| @@ -18,7 +18,7 @@ |
| package org.apache.spark.sql.execution |
| |
| import org.apache.spark.TestUtils.assertSpilled |
| -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} |
| +import org.apache.spark.sql.{AnalysisException, IgnoreComet, QueryTest, Row} |
| import org.apache.spark.sql.internal.SQLConf.{WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD, WINDOW_EXEC_BUFFER_SPILL_THRESHOLD} |
| import org.apache.spark.sql.test.SharedSparkSession |
| |
| @@ -470,7 +470,7 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSparkSession { |
| Row(1, 3, null) :: Row(2, null, 4) :: Nil) |
| } |
| |
| - test("test with low buffer spill threshold") { |
| + test("test with low buffer spill threshold", IgnoreComet("Comet does not support spilling")) { |
| val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") |
| nums.createOrReplaceTempView("nums") |
| |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala |
| index b14f4a405f6..90bed10eca9 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala |
| @@ -23,6 +23,7 @@ import org.apache.spark.sql.QueryTest |
| import org.apache.spark.sql.catalyst.InternalRow |
| import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} |
| import org.apache.spark.sql.catalyst.plans.logical.Deduplicate |
| +import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometNativeColumnarToRowExec} |
| import org.apache.spark.sql.execution.datasources.v2.BatchScanExec |
| import org.apache.spark.sql.internal.SQLConf |
| import org.apache.spark.sql.test.SharedSparkSession |
| @@ -131,7 +132,11 @@ class SparkPlanSuite extends QueryTest with SharedSparkSession { |
| spark.range(1).write.parquet(path.getAbsolutePath) |
| val df = spark.read.parquet(path.getAbsolutePath) |
| val columnarToRowExec = |
| - df.queryExecution.executedPlan.collectFirst { case p: ColumnarToRowExec => p }.get |
| + df.queryExecution.executedPlan.collectFirst { |
| + case p: ColumnarToRowExec => p |
| + case p: CometColumnarToRowExec => p |
| + case p: CometNativeColumnarToRowExec => p |
| + }.get |
| try { |
| spark.range(1).foreach { _ => |
| columnarToRowExec.canonicalized |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala |
| index 5a413c77754..207b66e1d7b 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala |
| @@ -17,7 +17,7 @@ |
| |
| package org.apache.spark.sql.execution |
| |
| -import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode} |
| +import org.apache.spark.sql.{Dataset, IgnoreCometSuite, QueryTest, Row, SaveMode} |
| import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode |
| import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAndComment, CodeGenerator} |
| import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite |
| @@ -30,7 +30,7 @@ import org.apache.spark.sql.test.SharedSparkSession |
| import org.apache.spark.sql.types.{IntegerType, StringType, StructType} |
| |
| // Disable AQE because the WholeStageCodegenExec is added when running QueryStageExec |
| -class WholeStageCodegenSuite extends QueryTest with SharedSparkSession |
| +class WholeStageCodegenSuite extends QueryTest with SharedSparkSession with IgnoreCometSuite |
| with DisableAdaptiveExecutionSuite { |
| |
| import testImplicits._ |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala |
| index 2f8e401e743..a4f94417dcc 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala |
| @@ -27,9 +27,11 @@ import org.scalatest.time.SpanSugar._ |
| import org.apache.spark.SparkException |
| import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} |
| import org.apache.spark.shuffle.sort.SortShuffleManager |
| -import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy} |
| +import org.apache.spark.sql.{Dataset, IgnoreComet, QueryTest, Row, SparkSession, Strategy} |
| import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} |
| import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} |
| +import org.apache.spark.sql.comet._ |
| +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec |
| import org.apache.spark.sql.execution._ |
| import org.apache.spark.sql.execution.aggregate.BaseAggregateExec |
| import org.apache.spark.sql.execution.columnar.{InMemoryTableScanExec, InMemoryTableScanLike} |
| @@ -117,6 +119,7 @@ class AdaptiveQueryExecSuite |
| private def findTopLevelBroadcastHashJoin(plan: SparkPlan): Seq[BroadcastHashJoinExec] = { |
| collect(plan) { |
| case j: BroadcastHashJoinExec => j |
| + case j: CometBroadcastHashJoinExec => j.originalPlan.asInstanceOf[BroadcastHashJoinExec] |
| } |
| } |
| |
| @@ -129,36 +132,46 @@ class AdaptiveQueryExecSuite |
| private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = { |
| collect(plan) { |
| case j: SortMergeJoinExec => j |
| + case j: CometSortMergeJoinExec => |
| + assert(j.originalPlan.isInstanceOf[SortMergeJoinExec]) |
| + j.originalPlan.asInstanceOf[SortMergeJoinExec] |
| } |
| } |
| |
| private def findTopLevelShuffledHashJoin(plan: SparkPlan): Seq[ShuffledHashJoinExec] = { |
| collect(plan) { |
| case j: ShuffledHashJoinExec => j |
| + case j: CometHashJoinExec => j.originalPlan.asInstanceOf[ShuffledHashJoinExec] |
| } |
| } |
| |
| private def findTopLevelBaseJoin(plan: SparkPlan): Seq[BaseJoinExec] = { |
| collect(plan) { |
| case j: BaseJoinExec => j |
| + case c: CometHashJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec] |
| + case c: CometSortMergeJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec] |
| + case c: CometBroadcastHashJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec] |
| } |
| } |
| |
| private def findTopLevelSort(plan: SparkPlan): Seq[SortExec] = { |
| collect(plan) { |
| case s: SortExec => s |
| + case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec] |
| } |
| } |
| |
| private def findTopLevelAggregate(plan: SparkPlan): Seq[BaseAggregateExec] = { |
| collect(plan) { |
| case agg: BaseAggregateExec => agg |
| + case agg: CometHashAggregateExec => agg.originalPlan.asInstanceOf[BaseAggregateExec] |
| } |
| } |
| |
| private def findTopLevelLimit(plan: SparkPlan): Seq[CollectLimitExec] = { |
| collect(plan) { |
| case l: CollectLimitExec => l |
| + case l: CometCollectLimitExec => l.originalPlan.asInstanceOf[CollectLimitExec] |
| } |
| } |
| |
| @@ -202,6 +215,7 @@ class AdaptiveQueryExecSuite |
| val parts = rdd.partitions |
| assert(parts.forall(rdd.preferredLocations(_).nonEmpty)) |
| } |
| + |
| assert(numShuffles === (numLocalReads.length + numShufflesWithoutLocalRead)) |
| } |
| |
| @@ -210,7 +224,7 @@ class AdaptiveQueryExecSuite |
| val plan = df.queryExecution.executedPlan |
| assert(plan.isInstanceOf[AdaptiveSparkPlanExec]) |
| val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect { |
| - case s: ShuffleExchangeExec => s |
| + case s: ShuffleExchangeLike => s |
| } |
| assert(shuffle.size == 1) |
| assert(shuffle(0).outputPartitioning.numPartitions == numPartition) |
| @@ -226,7 +240,8 @@ class AdaptiveQueryExecSuite |
| assert(smj.size == 1) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 1) |
| - checkNumLocalShuffleReads(adaptivePlan) |
| + // Comet shuffle changes shuffle metrics |
| + // checkNumLocalShuffleReads(adaptivePlan) |
| } |
| } |
| |
| @@ -253,7 +268,8 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("Reuse the parallelism of coalesced shuffle in local shuffle read") { |
| + test("Reuse the parallelism of coalesced shuffle in local shuffle read", |
| + IgnoreComet("Comet shuffle changes shuffle partition size")) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", |
| @@ -285,7 +301,8 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("Reuse the default parallelism in local shuffle read") { |
| + test("Reuse the default parallelism in local shuffle read", |
| + IgnoreComet("Comet shuffle changes shuffle partition size")) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", |
| @@ -299,7 +316,8 @@ class AdaptiveQueryExecSuite |
| val localReads = collect(adaptivePlan) { |
| case read: AQEShuffleReadExec if read.isLocalRead => read |
| } |
| - assert(localReads.length == 2) |
| + // Tagged IgnoreComet above, so this runs only without Comet; keep Spark's expectation |
| + assert(localReads.length == 2) |
| val localShuffleRDD0 = localReads(0).execute().asInstanceOf[ShuffledRowRDD] |
| val localShuffleRDD1 = localReads(1).execute().asInstanceOf[ShuffledRowRDD] |
| // the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2 |
| @@ -324,7 +342,9 @@ class AdaptiveQueryExecSuite |
| .groupBy($"a").count() |
| checkAnswer(testDf, Seq()) |
| val plan = testDf.queryExecution.executedPlan |
| - assert(find(plan)(_.isInstanceOf[SortMergeJoinExec]).isDefined) |
| + assert(find(plan) { case p => |
| + p.isInstanceOf[SortMergeJoinExec] || p.isInstanceOf[CometSortMergeJoinExec] |
| + }.isDefined) |
| val coalescedReads = collect(plan) { |
| case r: AQEShuffleReadExec => r |
| } |
| @@ -338,7 +358,9 @@ class AdaptiveQueryExecSuite |
| .groupBy($"a").count() |
| checkAnswer(testDf, Seq()) |
| val plan = testDf.queryExecution.executedPlan |
| - assert(find(plan)(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) |
| + assert(find(plan) { case p => |
| + p.isInstanceOf[BroadcastHashJoinExec] || p.isInstanceOf[CometBroadcastHashJoinExec] |
| + }.isDefined) |
| val coalescedReads = collect(plan) { |
| case r: AQEShuffleReadExec => r |
| } |
| @@ -348,7 +370,7 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("Scalar subquery") { |
| + test("Scalar subquery", IgnoreComet("Comet shuffle changes shuffle metrics")) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| @@ -363,7 +385,7 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("Scalar subquery in later stages") { |
| + test("Scalar subquery in later stages", IgnoreComet("Comet shuffle changes shuffle metrics")) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| @@ -379,7 +401,7 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("multiple joins") { |
| + test("multiple joins", IgnoreComet("Comet shuffle changes shuffle metrics")) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| @@ -424,7 +446,7 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("multiple joins with aggregate") { |
| + test("multiple joins with aggregate", IgnoreComet("Comet shuffle changes shuffle metrics")) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| @@ -469,7 +491,7 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("multiple joins with aggregate 2") { |
| + test("multiple joins with aggregate 2", IgnoreComet("Comet shuffle changes shuffle metrics")) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") { |
| @@ -515,7 +537,7 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("Exchange reuse") { |
| + test("Exchange reuse", IgnoreComet("Comet shuffle changes shuffle metrics")) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| @@ -534,7 +556,7 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("Exchange reuse with subqueries") { |
| + test("Exchange reuse with subqueries", IgnoreComet("Comet shuffle changes shuffle metrics")) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| @@ -565,7 +587,9 @@ class AdaptiveQueryExecSuite |
| assert(smj.size == 1) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 1) |
| - checkNumLocalShuffleReads(adaptivePlan) |
| + // Comet shuffle changes shuffle metrics, |
| + // so we can't check the number of local shuffle reads. |
| + // checkNumLocalShuffleReads(adaptivePlan) |
| // Even with local shuffle read, the query stage reuse can also work. |
| val ex = findReusedExchange(adaptivePlan) |
| assert(ex.nonEmpty) |
| @@ -586,7 +610,9 @@ class AdaptiveQueryExecSuite |
| assert(smj.size == 1) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 1) |
| - checkNumLocalShuffleReads(adaptivePlan) |
| + // Comet shuffle changes shuffle metrics, |
| + // so we can't check the number of local shuffle reads. |
| + // checkNumLocalShuffleReads(adaptivePlan) |
| // Even with local shuffle read, the query stage reuse can also work. |
| val ex = findReusedExchange(adaptivePlan) |
| assert(ex.isEmpty) |
| @@ -595,7 +621,8 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("Broadcast exchange reuse across subqueries") { |
| + test("Broadcast exchange reuse across subqueries", |
| + IgnoreComet("Comet shuffle changes shuffle metrics")) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000", |
| @@ -690,7 +717,8 @@ class AdaptiveQueryExecSuite |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 1) |
| // There is still a SMJ, and its two shuffles can't apply local read. |
| - checkNumLocalShuffleReads(adaptivePlan, 2) |
| + // Comet shuffle changes shuffle metrics |
| + // checkNumLocalShuffleReads(adaptivePlan, 2) |
| } |
| } |
| |
| @@ -812,7 +840,8 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("SPARK-29544: adaptive skew join with different join types") { |
| + test("SPARK-29544: adaptive skew join with different join types", |
| + IgnoreComet("Comet shuffle has different partition metrics")) { |
| Seq("SHUFFLE_MERGE", "SHUFFLE_HASH").foreach { joinHint => |
| def getJoinNode(plan: SparkPlan): Seq[ShuffledJoin] = if (joinHint == "SHUFFLE_MERGE") { |
| findTopLevelSortMergeJoin(plan) |
| @@ -1030,7 +1059,8 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("metrics of the shuffle read") { |
| + test("metrics of the shuffle read", |
| + IgnoreComet("Comet shuffle changes the metrics")) { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| val (_, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT key FROM testData GROUP BY key") |
| @@ -1625,7 +1655,7 @@ class AdaptiveQueryExecSuite |
| val (_, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT id FROM v1 GROUP BY id DISTRIBUTE BY id") |
| assert(collect(adaptivePlan) { |
| - case s: ShuffleExchangeExec => s |
| + case s: ShuffleExchangeLike => s |
| }.length == 1) |
| } |
| } |
| @@ -1705,7 +1735,8 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("SPARK-33551: Do not use AQE shuffle read for repartition") { |
| + test("SPARK-33551: Do not use AQE shuffle read for repartition", |
| + IgnoreComet("Comet shuffle changes partition size")) { |
| def hasRepartitionShuffle(plan: SparkPlan): Boolean = { |
| find(plan) { |
| case s: ShuffleExchangeLike => |
| @@ -1890,6 +1921,9 @@ class AdaptiveQueryExecSuite |
| def checkNoCoalescePartitions(ds: Dataset[Row], origin: ShuffleOrigin): Unit = { |
| assert(collect(ds.queryExecution.executedPlan) { |
| case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s |
| + case c: CometShuffleExchangeExec |
| + if c.originalPlan.shuffleOrigin == origin && |
| + c.originalPlan.numPartitions == 2 => c |
| }.size == 1) |
| ds.collect() |
| val plan = ds.queryExecution.executedPlan |
| @@ -1898,6 +1932,9 @@ class AdaptiveQueryExecSuite |
| }.isEmpty) |
| assert(collect(plan) { |
| case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s |
| + case c: CometShuffleExchangeExec |
| + if c.originalPlan.shuffleOrigin == origin && |
| + c.originalPlan.numPartitions == 2 => c |
| }.size == 1) |
| checkAnswer(ds, testData) |
| } |
| @@ -2054,7 +2091,8 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("SPARK-35264: Support AQE side shuffled hash join formula") { |
| + test("SPARK-35264: Support AQE side shuffled hash join formula", |
| + IgnoreComet("Comet shuffle changes the partition size")) { |
| withTempView("t1", "t2") { |
| def checkJoinStrategy(shouldShuffleHashJoin: Boolean): Unit = { |
| Seq("100", "100000").foreach { size => |
| @@ -2140,7 +2178,8 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("SPARK-35725: Support optimize skewed partitions in RebalancePartitions") { |
| + test("SPARK-35725: Support optimize skewed partitions in RebalancePartitions", |
| + IgnoreComet("Comet shuffle changes shuffle metrics")) { |
| withTempView("v") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| @@ -2239,7 +2278,7 @@ class AdaptiveQueryExecSuite |
| runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " + |
| s"JOIN skewData2 ON key1 = key2 GROUP BY key1") |
| val shuffles1 = collect(adaptive1) { |
| - case s: ShuffleExchangeExec => s |
| + case s: ShuffleExchangeLike => s |
| } |
| assert(shuffles1.size == 3) |
| // shuffles1.head is the top-level shuffle under the Aggregate operator |
| @@ -2252,7 +2291,7 @@ class AdaptiveQueryExecSuite |
| runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " + |
| s"JOIN skewData2 ON key1 = key2") |
| val shuffles2 = collect(adaptive2) { |
| - case s: ShuffleExchangeExec => s |
| + case s: ShuffleExchangeLike => s |
| } |
| if (hasRequiredDistribution) { |
| assert(shuffles2.size == 3) |
| @@ -2286,7 +2325,8 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("SPARK-35794: Allow custom plugin for cost evaluator") { |
| + test("SPARK-35794: Allow custom plugin for cost evaluator", |
| + IgnoreComet("Comet shuffle changes shuffle metrics")) { |
| CostEvaluator.instantiate( |
| classOf[SimpleShuffleSortCostEvaluator].getCanonicalName, spark.sparkContext.getConf) |
| intercept[IllegalArgumentException] { |
| @@ -2417,7 +2457,8 @@ class AdaptiveQueryExecSuite |
| } |
| |
| test("SPARK-48037: Fix SortShuffleWriter lacks shuffle write related metrics " + |
| - "resulting in potentially inaccurate data") { |
| + "resulting in potentially inaccurate data", |
| + IgnoreComet("https://github.com/apache/datafusion-comet/issues/1501")) { |
| withTable("t3") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| @@ -2452,6 +2493,7 @@ class AdaptiveQueryExecSuite |
| val (_, adaptive) = runAdaptiveAndVerifyResult(query) |
| assert(adaptive.collect { |
| case sort: SortExec => sort |
| + case sort: CometSortExec => sort |
| }.size == 1) |
| val read = collect(adaptive) { |
| case read: AQEShuffleReadExec => read |
| @@ -2469,7 +2511,8 @@ class AdaptiveQueryExecSuite |
| } |
| } |
| |
| - test("SPARK-37357: Add small partition factor for rebalance partitions") { |
| + test("SPARK-37357: Add small partition factor for rebalance partitions", |
| + IgnoreComet("Comet shuffle changes shuffle metrics")) { |
| withTempView("v") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true", |
| @@ -2581,7 +2624,7 @@ class AdaptiveQueryExecSuite |
| runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + |
| "JOIN skewData3 ON value2 = value3") |
| val shuffles1 = collect(adaptive1) { |
| - case s: ShuffleExchangeExec => s |
| + case s: ShuffleExchangeLike => s |
| } |
| assert(shuffles1.size == 4) |
| val smj1 = findTopLevelSortMergeJoin(adaptive1) |
| @@ -2592,7 +2635,7 @@ class AdaptiveQueryExecSuite |
| runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + |
| "JOIN skewData3 ON value1 = value3") |
| val shuffles2 = collect(adaptive2) { |
| - case s: ShuffleExchangeExec => s |
| + case s: ShuffleExchangeLike => s |
| } |
| assert(shuffles2.size == 4) |
| val smj2 = findTopLevelSortMergeJoin(adaptive2) |
| @@ -2850,6 +2893,7 @@ class AdaptiveQueryExecSuite |
| }.size == (if (firstAccess) 1 else 0)) |
| assert(collect(initialExecutedPlan) { |
| case s: SortExec => s |
| + case s: CometSortExec => s |
| }.size == (if (firstAccess) 2 else 0)) |
| assert(collect(initialExecutedPlan) { |
| case i: InMemoryTableScanLike => i |
| @@ -2980,7 +3024,9 @@ class AdaptiveQueryExecSuite |
| |
| val plan = df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec] |
| assert(plan.inputPlan.isInstanceOf[TakeOrderedAndProjectExec]) |
| - assert(plan.finalPhysicalPlan.isInstanceOf[WindowExec]) |
| + assert( |
| + plan.finalPhysicalPlan.isInstanceOf[WindowExec] || |
| + plan.finalPhysicalPlan.find(_.isInstanceOf[CometWindowExec]).nonEmpty) |
| plan.inputPlan.output.zip(plan.finalPhysicalPlan.output).foreach { case (o1, o2) => |
| assert(o1.semanticEquals(o2), "Different output column order after AQE optimization") |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala |
| index fd52d038ca6..154c800be67 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala |
| @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Concat |
| import org.apache.spark.sql.catalyst.parser.CatalystSqlParser |
| import org.apache.spark.sql.catalyst.plans.logical.Expand |
| import org.apache.spark.sql.catalyst.types.DataTypeUtils |
| +import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec} |
| import org.apache.spark.sql.execution.FileSourceScanExec |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper |
| import org.apache.spark.sql.functions._ |
| @@ -884,6 +885,8 @@ abstract class SchemaPruningSuite |
| val fileSourceScanSchemata = |
| collect(df.queryExecution.executedPlan) { |
| case scan: FileSourceScanExec => scan.requiredSchema |
| + case scan: CometScanExec => scan.requiredSchema |
| + case scan: CometNativeScanExec => scan.requiredSchema |
| } |
| assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size, |
| s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " + |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala |
| index 5fd27410dcb..468abb1543a 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala |
| @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources |
| import org.apache.spark.sql.{QueryTest, Row} |
| import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, NullsFirst, SortOrder} |
| import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sort} |
| +import org.apache.spark.sql.comet.CometSortExec |
| import org.apache.spark.sql.execution.{QueryExecution, SortExec} |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec |
| import org.apache.spark.sql.internal.SQLConf |
| @@ -243,6 +244,7 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write |
| // assert the outer most sort in the executed plan |
| assert(plan.collectFirst { |
| case s: SortExec => s |
| + case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec] |
| }.exists { |
| case SortExec(Seq( |
| SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _), |
| @@ -290,6 +292,7 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write |
| // assert the outer most sort in the executed plan |
| assert(plan.collectFirst { |
| case s: SortExec => s |
| + case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec] |
| }.exists { |
| case SortExec(Seq( |
| SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _), |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala |
| index 0b6fdef4f74..5b18c55da4b 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala |
| @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, GlobFilter, Path} |
| import org.mockito.Mockito.{mock, when} |
| |
| import org.apache.spark.SparkException |
| -import org.apache.spark.sql.{DataFrame, QueryTest, Row} |
| +import org.apache.spark.sql.{DataFrame, IgnoreCometSuite, QueryTest, Row} |
| import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder |
| import org.apache.spark.sql.execution.datasources.PartitionedFile |
| import org.apache.spark.sql.functions.col |
| @@ -38,7 +38,9 @@ import org.apache.spark.sql.test.SharedSparkSession |
| import org.apache.spark.sql.types._ |
| import org.apache.spark.util.Utils |
| |
| -class BinaryFileFormatSuite extends QueryTest with SharedSparkSession { |
| +// For some reason this suite is flaky with or without Comet when running in GitHub workflows. |
| +// Since the flakiness isn't related to Comet, we disable the suite for now. |
| +class BinaryFileFormatSuite extends QueryTest with SharedSparkSession with IgnoreCometSuite { |
| import BinaryFileFormat._ |
| |
| private var testDir: String = _ |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala |
| index 07e2849ce6f..3e73645b638 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala |
| @@ -28,7 +28,7 @@ import org.apache.parquet.hadoop.ParquetOutputFormat |
| |
| import org.apache.spark.TestUtils |
| import org.apache.spark.memory.MemoryMode |
| -import org.apache.spark.sql.Row |
| +import org.apache.spark.sql.{IgnoreComet, Row} |
| import org.apache.spark.sql.catalyst.util.DateTimeUtils |
| import org.apache.spark.sql.internal.SQLConf |
| import org.apache.spark.sql.test.SharedSparkSession |
| @@ -201,7 +201,8 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSparkSess |
| } |
| } |
| |
| - test("parquet v2 pages - rle encoding for boolean value columns") { |
| + test("parquet v2 pages - rle encoding for boolean value columns", |
| + IgnoreComet("Comet doesn't support RLE encoding yet")) { |
| val extraOptions = Map[String, String]( |
| ParquetOutputFormat.WRITER_VERSION -> ParquetProperties.WriterVersion.PARQUET_2_0.toString |
| ) |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala |
| index 8e88049f51e..49f2001dc6b 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala |
| @@ -1095,7 +1095,11 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared |
| // When a filter is pushed to Parquet, Parquet can apply it to every row. |
| // So, we can check the number of rows returned from the Parquet |
| // to make sure our filter pushdown work. |
| - assert(stripSparkFilter(df).count == 1) |
| + // Similar to Spark's vectorized reader, Comet doesn't do row-level filtering but relies |
| + // on Spark to apply the data filters after columnar batches are returned |
| + if (!isCometEnabled) { |
| + assert(stripSparkFilter(df).count == 1) |
| + } |
| } |
| } |
| } |
| @@ -1498,7 +1502,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared |
| } |
| } |
| |
| - test("Filters should be pushed down for vectorized Parquet reader at row group level") { |
| + test("Filters should be pushed down for vectorized Parquet reader at row group level", |
| + IgnoreCometNativeScan("Native scans do not support the tested accumulator")) { |
| import testImplicits._ |
| |
| withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true", |
| @@ -1548,7 +1553,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared |
| } |
| } |
| |
| - test("SPARK-31026: Parquet predicate pushdown for fields having dots in the names") { |
| + test("SPARK-31026: Parquet predicate pushdown for fields having dots in the names", |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3320")) { |
| import testImplicits._ |
| |
| withAllParquetReaders { |
| @@ -1580,13 +1586,18 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared |
| // than the total length but should not be a single record. |
| // Note that, if record level filtering is enabled, it should be a single record. |
| // If no filter is pushed down to Parquet, it should be the total length of data. |
| - assert(actual > 1 && actual < data.length) |
| +      // Only run this assertion with Comet in scan-only mode, since with native execution |
| +      // `stripSparkFilter` can't remove the native filter. |
| + if (!isCometEnabled || isCometScanOnly) { |
| + assert(actual > 1 && actual < data.length) |
| + } |
| } |
| } |
| } |
| } |
| |
| - test("Filters should be pushed down for Parquet readers at row group level") { |
| + test("Filters should be pushed down for Parquet readers at row group level", |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3320")) { |
| import testImplicits._ |
| |
| withSQLConf( |
| @@ -1607,7 +1618,11 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared |
| // than the total length but should not be a single record. |
| // Note that, if record level filtering is enabled, it should be a single record. |
| // If no filter is pushed down to Parquet, it should be the total length of data. |
| - assert(actual > 1 && actual < data.length) |
| +      // Only run this assertion with Comet in scan-only mode, since with native execution |
| +      // `stripSparkFilter` can't remove the native filter. |
| + if (!isCometEnabled || isCometScanOnly) { |
| + assert(actual > 1 && actual < data.length) |
| + } |
| } |
| } |
| } |
| @@ -1699,7 +1714,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared |
| (attr, value) => sources.StringContains(attr, value)) |
| } |
| |
| - test("filter pushdown - StringPredicate") { |
| + test("filter pushdown - StringPredicate", IgnoreCometNativeScan("cannot be pushed down")) { |
| import testImplicits._ |
| // keep() should take effect on StartsWith/EndsWith/Contains |
| Seq( |
| @@ -1743,7 +1758,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared |
| } |
| } |
| |
| - test("SPARK-17091: Convert IN predicate to Parquet filter push-down") { |
| + test("SPARK-17091: Convert IN predicate to Parquet filter push-down", |
| + IgnoreCometNativeScan("Comet has different push-down behavior")) { |
| val schema = StructType(Seq( |
| StructField("a", IntegerType, nullable = false) |
| )) |
| @@ -1933,7 +1949,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared |
| } |
| } |
| |
| - test("SPARK-25207: exception when duplicate fields in case-insensitive mode") { |
| + test("SPARK-25207: exception when duplicate fields in case-insensitive mode", |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3311")) { |
| withTempPath { dir => |
| val count = 10 |
| val tableName = "spark_25207" |
| @@ -1984,7 +2001,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared |
| } |
| } |
| |
| - test("Support Parquet column index") { |
| + test("Support Parquet column index", |
| + IgnoreComet("Comet doesn't support Parquet column index yet")) { |
| // block 1: |
| // null count min max |
| // page-0 0 0 99 |
| @@ -2044,7 +2062,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared |
| } |
| } |
| |
| - test("SPARK-34562: Bloom filter push down") { |
| + test("SPARK-34562: Bloom filter push down", |
| + IgnoreCometNativeScan("Native scans do not support the tested accumulator")) { |
| withTempPath { dir => |
| val path = dir.getCanonicalPath |
| spark.range(100).selectExpr("id * 2 AS id") |
| @@ -2276,7 +2295,11 @@ class ParquetV1FilterSuite extends ParquetFilterSuite { |
| assert(pushedParquetFilters.exists(_.getClass === filterClass), |
| s"${pushedParquetFilters.map(_.getClass).toList} did not contain ${filterClass}.") |
| |
| - checker(stripSparkFilter(query), expected) |
| + // Similar to Spark's vectorized reader, Comet doesn't do row-level filtering but relies |
| + // on Spark to apply the data filters after columnar batches are returned |
| + if (!isCometEnabled) { |
| + checker(stripSparkFilter(query), expected) |
| + } |
| } else { |
| assert(selectedFilters.isEmpty, "There is filter pushed down") |
| } |
| @@ -2336,7 +2359,11 @@ class ParquetV2FilterSuite extends ParquetFilterSuite { |
| assert(pushedParquetFilters.exists(_.getClass === filterClass), |
| s"${pushedParquetFilters.map(_.getClass).toList} did not contain ${filterClass}.") |
| |
| - checker(stripSparkFilter(query), expected) |
| + // Similar to Spark's vectorized reader, Comet doesn't do row-level filtering but relies |
| + // on Spark to apply the data filters after columnar batches are returned |
| + if (!isCometEnabled) { |
| + checker(stripSparkFilter(query), expected) |
| + } |
| |
| case _ => |
| throw new AnalysisException("Can not match ParquetTable in the query.") |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala |
| index 8ed9ef1630e..f312174b182 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala |
| @@ -1064,7 +1064,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession |
| } |
| } |
| |
| - test("SPARK-35640: read binary as timestamp should throw schema incompatible error") { |
| + test("SPARK-35640: read binary as timestamp should throw schema incompatible error", |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3311")) { |
| val data = (1 to 4).map(i => Tuple1(i.toString)) |
| val readSchema = StructType(Seq(StructField("_1", DataTypes.TimestampType))) |
| |
| @@ -1075,7 +1076,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession |
| } |
| } |
| |
| - test("SPARK-35640: int as long should throw schema incompatible error") { |
| + test("SPARK-35640: int as long should throw schema incompatible error", |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3311")) { |
| val data = (1 to 4).map(i => Tuple1(i)) |
| val readSchema = StructType(Seq(StructField("_1", DataTypes.LongType))) |
| |
| @@ -1345,7 +1347,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession |
| } |
| } |
| |
| - test("SPARK-40128 read DELTA_LENGTH_BYTE_ARRAY encoded strings") { |
| + test("SPARK-40128 read DELTA_LENGTH_BYTE_ARRAY encoded strings", |
| + IgnoreComet("Comet doesn't support DELTA encoding yet")) { |
| withAllParquetReaders { |
| checkAnswer( |
| // "fruit" column in this file is encoded using DELTA_LENGTH_BYTE_ARRAY. |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala |
| index f6472ba3d9d..ce39ebb52e6 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala |
| @@ -185,7 +185,8 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS |
| } |
| } |
| |
| - test("SPARK-36182: can't read TimestampLTZ as TimestampNTZ") { |
| + test("SPARK-36182: can't read TimestampLTZ as TimestampNTZ", |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3311")) { |
| val data = (1 to 1000).map { i => |
| val ts = new java.sql.Timestamp(i) |
| Row(ts) |
| @@ -998,7 +999,8 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS |
| } |
| } |
| |
| - test("SPARK-26677: negated null-safe equality comparison should not filter matched row groups") { |
| + test("SPARK-26677: negated null-safe equality comparison should not filter matched row groups", |
| + IgnoreCometNativeScan("Native scans had the filter pushed into DF operator, cannot strip")) { |
| withAllParquetReaders { |
| withTempPath { path => |
| // Repeated values for dictionary encoding. |
| @@ -1051,7 +1053,8 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS |
| testMigration(fromTsType = "TIMESTAMP_MICROS", toTsType = "INT96") |
| } |
| |
| - test("SPARK-34212 Parquet should read decimals correctly") { |
| + test("SPARK-34212 Parquet should read decimals correctly", |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3311")) { |
| def readParquet(schema: String, path: File): DataFrame = { |
| spark.read.schema(schema).parquet(path.toString) |
| } |
| @@ -1067,7 +1070,8 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS |
| checkAnswer(readParquet(schema, path), df) |
| } |
| |
| - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { |
| + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false", |
| + "spark.comet.enabled" -> "false") { |
| val schema1 = "a DECIMAL(3, 2), b DECIMAL(18, 3), c DECIMAL(37, 3)" |
| checkAnswer(readParquet(schema1, path), df) |
| val schema2 = "a DECIMAL(3, 0), b DECIMAL(18, 1), c DECIMAL(37, 1)" |
| @@ -1089,7 +1093,8 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS |
| val df = sql(s"SELECT 1 a, 123456 b, ${Int.MaxValue.toLong * 10} c, CAST('1.2' AS BINARY) d") |
| df.write.parquet(path.toString) |
| |
| - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { |
| + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false", |
| + "spark.comet.enabled" -> "false") { |
| checkAnswer(readParquet("a DECIMAL(3, 2)", path), sql("SELECT 1.00")) |
| checkAnswer(readParquet("b DECIMAL(3, 2)", path), Row(null)) |
| checkAnswer(readParquet("b DECIMAL(11, 1)", path), sql("SELECT 123456.0")) |
| @@ -1133,7 +1138,8 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS |
| } |
| } |
| |
| - test("row group skipping doesn't overflow when reading into larger type") { |
| + test("row group skipping doesn't overflow when reading into larger type", |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3311")) { |
| withTempPath { path => |
| Seq(0).toDF("a").write.parquet(path.toString) |
| // The vectorized and non-vectorized readers will produce different exceptions, we don't need |
| @@ -1148,7 +1154,7 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS |
| .where(s"a < ${Long.MaxValue}") |
| .collect() |
| } |
| - assert(exception.getCause.getCause.isInstanceOf[SchemaColumnConvertNotSupportedException]) |
| + assert(exception.getMessage.contains("Column: [a], Expected: bigint, Found: INT32")) |
| } |
| } |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala |
| index 4f906411345..6cc69f7e915 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala |
| @@ -21,7 +21,7 @@ import java.nio.file.{Files, Paths, StandardCopyOption} |
| import java.sql.{Date, Timestamp} |
| |
| import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkUpgradeException} |
| -import org.apache.spark.sql.{QueryTest, Row, SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY} |
| +import org.apache.spark.sql.{IgnoreCometSuite, QueryTest, Row, SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY} |
| import org.apache.spark.sql.catalyst.util.DateTimeTestUtils |
| import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} |
| import org.apache.spark.sql.internal.LegacyBehaviorPolicy.{CORRECTED, EXCEPTION, LEGACY} |
| @@ -30,9 +30,11 @@ import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType.{INT96, |
| import org.apache.spark.sql.test.SharedSparkSession |
| import org.apache.spark.tags.SlowSQLTest |
| |
| +// Comet is disabled for this suite because it doesn't support datetime rebase mode |
| abstract class ParquetRebaseDatetimeSuite |
| extends QueryTest |
| with ParquetTest |
| + with IgnoreCometSuite |
| with SharedSparkSession { |
| |
| import testImplicits._ |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala |
| index 27c2a2148fd..df04a15fb1f 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala |
| @@ -20,12 +20,14 @@ import java.io.File |
| |
| import scala.collection.JavaConverters._ |
| |
| +import org.apache.comet.CometConf |
| import org.apache.hadoop.fs.Path |
| import org.apache.parquet.column.ParquetProperties._ |
| import org.apache.parquet.hadoop.{ParquetFileReader, ParquetOutputFormat} |
| import org.apache.parquet.hadoop.ParquetWriter.DEFAULT_BLOCK_SIZE |
| |
| import org.apache.spark.sql.QueryTest |
| +import org.apache.spark.sql.comet.{CometBatchScanExec, CometScanExec} |
| import org.apache.spark.sql.execution.FileSourceScanExec |
| import org.apache.spark.sql.execution.datasources.FileFormat |
| import org.apache.spark.sql.execution.datasources.v2.BatchScanExec |
| @@ -172,6 +174,8 @@ class ParquetRowIndexSuite extends QueryTest with SharedSparkSession { |
| |
| private def testRowIndexGeneration(label: String, conf: RowIndexTestConf): Unit = { |
| test (s"$label - ${conf.desc}") { |
| + // native_datafusion Parquet scan does not support row index generation. |
| + assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION) |
| withSQLConf(conf.sqlConfs: _*) { |
| withTempPath { path => |
| // Read row index using _metadata.row_index if that is supported by the file format. |
| @@ -243,6 +247,12 @@ class ParquetRowIndexSuite extends QueryTest with SharedSparkSession { |
| case f: FileSourceScanExec => |
| numPartitions += f.inputRDD.partitions.length |
| numOutputRows += f.metrics("numOutputRows").value |
| + case b: CometScanExec => |
| + numPartitions += b.inputRDD.partitions.length |
| + numOutputRows += b.metrics("numOutputRows").value |
| + case b: CometBatchScanExec => |
| + numPartitions += b.inputRDD.partitions.length |
| + numOutputRows += b.metrics("numOutputRows").value |
| case _ => |
| } |
| assert(numPartitions > 0) |
| @@ -301,6 +311,8 @@ class ParquetRowIndexSuite extends QueryTest with SharedSparkSession { |
| val conf = RowIndexTestConf(useDataSourceV2 = useDataSourceV2) |
| |
| test(s"invalid row index column type - ${conf.desc}") { |
| + // native_datafusion Parquet scan does not support row index generation. |
| + assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION) |
| withSQLConf(conf.sqlConfs: _*) { |
| withTempPath{ path => |
| val df = spark.range(0, 10, 1, 1).toDF("id") |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala |
| index 5c0b7def039..151184bc98c 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala |
| @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet |
| import org.apache.spark.SparkConf |
| import org.apache.spark.sql.DataFrame |
| import org.apache.spark.sql.catalyst.parser.CatalystSqlParser |
| +import org.apache.spark.sql.comet.CometBatchScanExec |
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper |
| import org.apache.spark.sql.execution.datasources.SchemaPruningSuite |
| import org.apache.spark.sql.execution.datasources.v2.BatchScanExec |
| @@ -56,6 +57,7 @@ class ParquetV2SchemaPruningSuite extends ParquetSchemaPruningSuite { |
| val fileSourceScanSchemata = |
| collect(df.queryExecution.executedPlan) { |
| case scan: BatchScanExec => scan.scan.asInstanceOf[ParquetScan].readDataSchema |
| + case scan: CometBatchScanExec => scan.scan.asInstanceOf[ParquetScan].readDataSchema |
| } |
| assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size, |
| s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " + |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala |
| index 3f47c5e506f..92a5eafec84 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala |
| @@ -27,6 +27,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName |
| import org.apache.parquet.schema.Type._ |
| |
| import org.apache.spark.SparkException |
| +import org.apache.spark.sql.{IgnoreComet, IgnoreCometNativeDataFusion} |
| import org.apache.spark.sql.catalyst.expressions.Cast.toSQLType |
| import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException |
| import org.apache.spark.sql.functions.desc |
| @@ -1036,7 +1037,8 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |
| e |
| } |
| |
| - test("schema mismatch failure error message for parquet reader") { |
| + test("schema mismatch failure error message for parquet reader", |
| + IgnoreComet("Comet doesn't work with vectorizedReaderEnabled = false")) { |
| withTempPath { dir => |
| val e = testSchemaMismatch(dir.getCanonicalPath, vectorizedReaderEnabled = false) |
| val expectedMessage = "Encountered error while reading file" |
| @@ -1046,7 +1048,8 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |
| } |
| } |
| |
| - test("schema mismatch failure error message for parquet vectorized reader") { |
| + test("schema mismatch failure error message for parquet vectorized reader", |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3311")) { |
| withTempPath { dir => |
| val e = testSchemaMismatch(dir.getCanonicalPath, vectorizedReaderEnabled = true) |
| assert(e.getCause.isInstanceOf[SparkException]) |
| @@ -1087,7 +1090,8 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |
| } |
| } |
| |
| - test("SPARK-45604: schema mismatch failure error on timestamp_ntz to array<timestamp_ntz>") { |
| + test("SPARK-45604: schema mismatch failure error on timestamp_ntz to array<timestamp_ntz>", |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3311")) { |
| import testImplicits._ |
| |
| withTempPath { dir => |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala |
| index b8f3ea3c6f3..bbd44221288 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala |
| @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.debug |
| import java.io.ByteArrayOutputStream |
| |
| import org.apache.spark.rdd.RDD |
| +import org.apache.spark.sql.IgnoreComet |
| import org.apache.spark.sql.catalyst.InternalRow |
| import org.apache.spark.sql.catalyst.expressions.Attribute |
| import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext |
| @@ -125,7 +126,8 @@ class DebuggingSuite extends DebuggingSuiteBase with DisableAdaptiveExecutionSui |
| | id LongType: {}""".stripMargin)) |
| } |
| |
| - test("SPARK-28537: DebugExec cannot debug columnar related queries") { |
| + test("SPARK-28537: DebugExec cannot debug columnar related queries", |
| + IgnoreComet("Comet does not use FileScan")) { |
| withTempPath { workDir => |
| val workDirPath = workDir.getAbsolutePath |
| val input = spark.range(5).toDF("id") |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala |
| index 5cdbdc27b32..307fba16578 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala |
| @@ -46,8 +46,10 @@ import org.apache.spark.sql.util.QueryExecutionListener |
| import org.apache.spark.util.{AccumulatorContext, JsonProtocol} |
| |
| // Disable AQE because metric info is different with AQE on/off |
| +// This suite tests the metrics reported by physical operators.
| +// It is disabled for Comet because the reported metrics differ when Comet is enabled.
| class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils |
| - with DisableAdaptiveExecutionSuite { |
| + with DisableAdaptiveExecutionSuite with IgnoreCometSuite { |
| import testImplicits._ |
| |
| /** |
| @@ -765,7 +767,8 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils |
| } |
| } |
| |
| - test("SPARK-26327: FileSourceScanExec metrics") { |
| + test("SPARK-26327: FileSourceScanExec metrics", |
| + IgnoreComet("Spark uses row-based Parquet reader while Comet is vectorized")) { |
| withTable("testDataForScan") { |
| spark.range(10).selectExpr("id", "id % 3 as p") |
| .write.partitionBy("p").saveAsTable("testDataForScan") |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala |
| index 0ab8691801d..7b81f3a8f6d 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala |
| @@ -17,7 +17,9 @@ |
| |
| package org.apache.spark.sql.execution.python |
| |
| +import org.apache.spark.sql.IgnoreCometNativeDataFusion |
| import org.apache.spark.sql.catalyst.plans.logical.{ArrowEvalPython, BatchEvalPython, Limit, LocalLimit} |
| +import org.apache.spark.sql.comet._ |
| import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan, SparkPlanTest} |
| import org.apache.spark.sql.execution.datasources.v2.BatchScanExec |
| import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan |
| @@ -93,7 +95,8 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession { |
| assert(arrowEvalNodes.size == 2) |
| } |
| |
| - test("Python UDF should not break column pruning/filter pushdown -- Parquet V1") { |
| + test("Python UDF should not break column pruning/filter pushdown -- Parquet V1", |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3312")) { |
| withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "parquet") { |
| withTempPath { f => |
| spark.range(10).select($"id".as("a"), $"id".as("b")) |
| @@ -108,6 +111,7 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession { |
| |
| val scanNodes = query.queryExecution.executedPlan.collect { |
| case scan: FileSourceScanExec => scan |
| + case scan: CometScanExec => scan |
| } |
| assert(scanNodes.length == 1) |
| assert(scanNodes.head.output.map(_.name) == Seq("a")) |
| @@ -120,11 +124,16 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession { |
| |
| val scanNodes = query.queryExecution.executedPlan.collect { |
| case scan: FileSourceScanExec => scan |
| + case scan: CometScanExec => scan |
| } |
| assert(scanNodes.length == 1) |
| // $"a" is not null and $"a" > 1 |
| - assert(scanNodes.head.dataFilters.length == 2) |
| - assert(scanNodes.head.dataFilters.flatMap(_.references.map(_.name)).distinct == Seq("a")) |
| + val dataFilters = scanNodes.head match { |
| + case scan: FileSourceScanExec => scan.dataFilters |
| + case scan: CometScanExec => scan.dataFilters |
| + } |
| + assert(dataFilters.length == 2) |
| + assert(dataFilters.flatMap(_.references.map(_.name)).distinct == Seq("a")) |
| } |
| } |
| } |
| @@ -145,6 +154,7 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession { |
| |
| val scanNodes = query.queryExecution.executedPlan.collect { |
| case scan: BatchScanExec => scan |
| + case scan: CometBatchScanExec => scan |
| } |
| assert(scanNodes.length == 1) |
| assert(scanNodes.head.output.map(_.name) == Seq("a")) |
| @@ -157,6 +167,7 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession { |
| |
| val scanNodes = query.queryExecution.executedPlan.collect { |
| case scan: BatchScanExec => scan |
| + case scan: CometBatchScanExec => scan |
| } |
| assert(scanNodes.length == 1) |
| // $"a" is not null and $"a" > 1 |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala |
| index d083cac48ff..3c11bcde807 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala |
| @@ -37,8 +37,10 @@ import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, |
| import org.apache.spark.sql.streaming.util.StreamManualClock |
| import org.apache.spark.util.Utils |
| |
| +// For some reason this suite is flaky with or without Comet when running in a GitHub workflow.
| +// Since the flakiness isn't related to Comet, we disable the suite for now.
| class AsyncProgressTrackingMicroBatchExecutionSuite |
| - extends StreamTest with BeforeAndAfter with Matchers { |
| + extends StreamTest with BeforeAndAfter with Matchers with IgnoreCometSuite { |
| |
| import testImplicits._ |
| |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala |
| index 746f289c393..7a6a88a9fce 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala |
| @@ -19,16 +19,19 @@ package org.apache.spark.sql.sources |
| |
| import scala.util.Random |
| |
| +import org.apache.comet.CometConf |
| + |
| import org.apache.spark.sql._ |
| import org.apache.spark.sql.catalyst.catalog.BucketSpec |
| import org.apache.spark.sql.catalyst.expressions |
| import org.apache.spark.sql.catalyst.expressions._ |
| import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning |
| import org.apache.spark.sql.catalyst.types.DataTypeUtils |
| -import org.apache.spark.sql.execution.{FileSourceScanExec, SortExec, SparkPlan} |
| +import org.apache.spark.sql.comet._ |
| +import org.apache.spark.sql.execution.{ColumnarToRowExec, FileSourceScanExec, SortExec, SparkPlan} |
| import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper} |
| import org.apache.spark.sql.execution.datasources.BucketingUtils |
| -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec |
| +import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} |
| import org.apache.spark.sql.execution.joins.SortMergeJoinExec |
| import org.apache.spark.sql.functions._ |
| import org.apache.spark.sql.internal.SQLConf |
| @@ -102,12 +105,22 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti |
| } |
| } |
| |
| - private def getFileScan(plan: SparkPlan): FileSourceScanExec = { |
| - val fileScan = collect(plan) { case f: FileSourceScanExec => f } |
| + private def getFileScan(plan: SparkPlan): SparkPlan = { |
| + val fileScan = collect(plan) { |
| + case f: FileSourceScanExec => f |
| + case f: CometScanExec => f |
| + case f: CometNativeScanExec => f |
| + } |
| assert(fileScan.nonEmpty, plan) |
| fileScan.head |
| } |
| |
| + private def getBucketScan(plan: SparkPlan): Boolean = getFileScan(plan) match { |
| + case fs: FileSourceScanExec => fs.bucketedScan |
| + case bs: CometScanExec => bs.bucketedScan |
| + case ns: CometNativeScanExec => ns.bucketedScan |
| + } |
| + |
| // To verify if the bucket pruning works, this function checks two conditions: |
| // 1) Check if the pruned buckets (before filtering) are empty. |
| // 2) Verify the final result is the same as the expected one |
| @@ -156,7 +169,8 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti |
| val planWithoutBucketedScan = bucketedDataFrame.filter(filterCondition) |
| .queryExecution.executedPlan |
| val fileScan = getFileScan(planWithoutBucketedScan) |
| - assert(!fileScan.bucketedScan, s"except no bucketed scan but found\n$fileScan") |
| + val bucketedScan = getBucketScan(planWithoutBucketedScan) |
| + assert(!bucketedScan, s"except no bucketed scan but found\n$fileScan") |
| |
| val bucketColumnType = bucketedDataFrame.schema.apply(bucketColumnIndex).dataType |
| val rowsWithInvalidBuckets = fileScan.execute().filter(row => { |
| @@ -452,28 +466,54 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti |
| val joinOperator = if (joined.sqlContext.conf.adaptiveExecutionEnabled) { |
| val executedPlan = |
| joined.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan |
| - assert(executedPlan.isInstanceOf[SortMergeJoinExec]) |
| - executedPlan.asInstanceOf[SortMergeJoinExec] |
| + executedPlan match { |
| + case s: SortMergeJoinExec => s |
| + case b: CometSortMergeJoinExec => |
| + b.originalPlan match { |
| + case s: SortMergeJoinExec => s |
| + case o => fail(s"expected SortMergeJoinExec, but found\n$o") |
| + } |
| + case o => fail(s"expected SortMergeJoinExec, but found\n$o") |
| + } |
| } else { |
| val executedPlan = joined.queryExecution.executedPlan |
| - assert(executedPlan.isInstanceOf[SortMergeJoinExec]) |
| - executedPlan.asInstanceOf[SortMergeJoinExec] |
| + executedPlan match { |
| + case s: SortMergeJoinExec => s |
| + case ColumnarToRowExec(child) => |
| + child.asInstanceOf[CometSortMergeJoinExec].originalPlan match { |
| + case s: SortMergeJoinExec => s |
| + case o => fail(s"expected SortMergeJoinExec, but found\n$o") |
| + } |
| + case CometColumnarToRowExec(child) => |
| + child.asInstanceOf[CometSortMergeJoinExec].originalPlan match { |
| + case s: SortMergeJoinExec => s |
| + case o => fail(s"expected SortMergeJoinExec, but found\n$o") |
| + } |
| + case CometNativeColumnarToRowExec(child) => |
| + child.asInstanceOf[CometSortMergeJoinExec].originalPlan match { |
| + case s: SortMergeJoinExec => s |
| + case o => fail(s"expected SortMergeJoinExec, but found\n$o") |
| + } |
| + case o => fail(s"expected SortMergeJoinExec, but found\n$o") |
| + } |
| } |
| |
| // check existence of shuffle |
| assert( |
| - joinOperator.left.exists(_.isInstanceOf[ShuffleExchangeExec]) == shuffleLeft, |
| + joinOperator.left.exists(op => op.isInstanceOf[ShuffleExchangeLike]) == shuffleLeft, |
| s"expected shuffle in plan to be $shuffleLeft but found\n${joinOperator.left}") |
| assert( |
| - joinOperator.right.exists(_.isInstanceOf[ShuffleExchangeExec]) == shuffleRight, |
| + joinOperator.right.exists(op => op.isInstanceOf[ShuffleExchangeLike]) == shuffleRight, |
| s"expected shuffle in plan to be $shuffleRight but found\n${joinOperator.right}") |
| |
| // check existence of sort |
| assert( |
| - joinOperator.left.exists(_.isInstanceOf[SortExec]) == sortLeft, |
| + joinOperator.left.exists(op => op.isInstanceOf[SortExec] || op.isInstanceOf[CometExec] && |
| + op.asInstanceOf[CometExec].originalPlan.isInstanceOf[SortExec]) == sortLeft, |
| s"expected sort in the left child to be $sortLeft but found\n${joinOperator.left}") |
| assert( |
| - joinOperator.right.exists(_.isInstanceOf[SortExec]) == sortRight, |
| + joinOperator.right.exists(op => op.isInstanceOf[SortExec] || op.isInstanceOf[CometExec] && |
| + op.asInstanceOf[CometExec].originalPlan.isInstanceOf[SortExec]) == sortRight, |
| s"expected sort in the right child to be $sortRight but found\n${joinOperator.right}") |
| |
| // check the output partitioning |
| @@ -836,11 +876,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti |
| df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") |
| |
| val scanDF = spark.table("bucketed_table").select("j") |
| - assert(!getFileScan(scanDF.queryExecution.executedPlan).bucketedScan) |
| + assert(!getBucketScan(scanDF.queryExecution.executedPlan)) |
| checkAnswer(scanDF, df1.select("j")) |
| |
| val aggDF = spark.table("bucketed_table").groupBy("j").agg(max("k")) |
| - assert(!getFileScan(aggDF.queryExecution.executedPlan).bucketedScan) |
| + assert(!getBucketScan(aggDF.queryExecution.executedPlan)) |
| checkAnswer(aggDF, df1.groupBy("j").agg(max("k"))) |
| } |
| } |
| @@ -895,7 +935,10 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti |
| } |
| |
| test("SPARK-29655 Read bucketed tables obeys spark.sql.shuffle.partitions") { |
| + // Range partitioning uses random samples, so per-partition comparisons do not always yield |
| + // the same results. Disable Comet native range partitioning. |
| withSQLConf( |
| + CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.key -> "false", |
| SQLConf.SHUFFLE_PARTITIONS.key -> "5", |
| SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "7") { |
| val bucketSpec = Some(BucketSpec(6, Seq("i", "j"), Nil)) |
| @@ -914,7 +957,10 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti |
| } |
| |
| test("SPARK-32767 Bucket join should work if SHUFFLE_PARTITIONS larger than bucket number") { |
| + // Range partitioning uses random samples, so per-partition comparisons do not always yield |
| + // the same results. Disable Comet native range partitioning. |
| withSQLConf( |
| + CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.key -> "false", |
| SQLConf.SHUFFLE_PARTITIONS.key -> "9", |
| SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10") { |
| |
| @@ -944,7 +990,10 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti |
| } |
| |
| test("bucket coalescing eliminates shuffle") { |
| + // Range partitioning uses random samples, so per-partition comparisons do not always yield |
| + // the same results. Disable Comet native range partitioning. |
| withSQLConf( |
| + CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.key -> "false", |
| SQLConf.COALESCE_BUCKETS_IN_JOIN_ENABLED.key -> "true", |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { |
| // The side with bucketedTableTestSpec1 will be coalesced to have 4 output partitions. |
| @@ -1029,15 +1078,24 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti |
| Seq(true, false).foreach { aqeEnabled => |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled.toString) { |
| val plan = sql(query).queryExecution.executedPlan |
| - val shuffles = collect(plan) { case s: ShuffleExchangeExec => s } |
| + val shuffles = collect(plan) { case s: ShuffleExchangeLike => s } |
| assert(shuffles.length == expectedNumShuffles) |
| |
| val scans = collect(plan) { |
| case f: FileSourceScanExec if f.optionalNumCoalescedBuckets.isDefined => f |
| + case b: CometScanExec if b.optionalNumCoalescedBuckets.isDefined => b |
| + case b: CometNativeScanExec if b.optionalNumCoalescedBuckets.isDefined => b |
| } |
| if (expectedCoalescedNumBuckets.isDefined) { |
| assert(scans.length == 1) |
| - assert(scans.head.optionalNumCoalescedBuckets == expectedCoalescedNumBuckets) |
| + scans.head match { |
| + case f: FileSourceScanExec => |
| + assert(f.optionalNumCoalescedBuckets == expectedCoalescedNumBuckets) |
| + case b: CometScanExec => |
| + assert(b.optionalNumCoalescedBuckets == expectedCoalescedNumBuckets) |
| + case b: CometNativeScanExec => |
| + assert(b.optionalNumCoalescedBuckets == expectedCoalescedNumBuckets) |
| + } |
| } else { |
| assert(scans.isEmpty) |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala |
| index 6f897a9c0b7..b0723634f68 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala |
| @@ -20,6 +20,7 @@ package org.apache.spark.sql.sources |
| import java.io.File |
| |
| import org.apache.spark.SparkException |
| +import org.apache.spark.sql.IgnoreCometSuite |
| import org.apache.spark.sql.catalyst.TableIdentifier |
| import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTableType} |
| import org.apache.spark.sql.catalyst.parser.ParseException |
| @@ -27,7 +28,10 @@ import org.apache.spark.sql.internal.SQLConf.BUCKETING_MAX_BUCKETS |
| import org.apache.spark.sql.test.SharedSparkSession |
| import org.apache.spark.util.Utils |
| |
| -class CreateTableAsSelectSuite extends DataSourceTest with SharedSparkSession { |
| +// For some reason this suite is flaky with or without Comet when running in a GitHub workflow.
| +// Since the flakiness isn't related to Comet, we disable the suite for now.
| +class CreateTableAsSelectSuite extends DataSourceTest with SharedSparkSession |
| + with IgnoreCometSuite { |
| import testImplicits._ |
| |
| protected override lazy val sql = spark.sql _ |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala |
| index d675503a8ba..f220892396e 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala |
| @@ -18,6 +18,7 @@ |
| package org.apache.spark.sql.sources |
| |
| import org.apache.spark.sql.QueryTest |
| +import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec} |
| import org.apache.spark.sql.execution.FileSourceScanExec |
| import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} |
| import org.apache.spark.sql.internal.SQLConf |
| @@ -68,7 +69,11 @@ abstract class DisableUnnecessaryBucketedScanSuite |
| |
| def checkNumBucketedScan(query: String, expectedNumBucketedScan: Int): Unit = { |
| val plan = sql(query).queryExecution.executedPlan |
| - val bucketedScan = collect(plan) { case s: FileSourceScanExec if s.bucketedScan => s } |
| + val bucketedScan = collect(plan) { |
| + case s: FileSourceScanExec if s.bucketedScan => s |
| + case s: CometScanExec if s.bucketedScan => s |
| + case s: CometNativeScanExec if s.bucketedScan => s |
| + } |
| assert(bucketedScan.length == expectedNumBucketedScan) |
| } |
| |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala |
| index 7f6fa2a123e..c778b4e2c48 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala |
| @@ -35,6 +35,7 @@ import org.apache.spark.paths.SparkPath |
| import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} |
| import org.apache.spark.sql.{AnalysisException, DataFrame} |
| import org.apache.spark.sql.catalyst.util.stringToFile |
| +import org.apache.spark.sql.comet.CometBatchScanExec |
| import org.apache.spark.sql.execution.DataSourceScanExec |
| import org.apache.spark.sql.execution.datasources._ |
| import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, FileScan, FileTable} |
| @@ -777,6 +778,8 @@ class FileStreamSinkV2Suite extends FileStreamSinkSuite { |
| val fileScan = df.queryExecution.executedPlan.collect { |
| case batch: BatchScanExec if batch.scan.isInstanceOf[FileScan] => |
| batch.scan.asInstanceOf[FileScan] |
| + case batch: CometBatchScanExec if batch.scan.isInstanceOf[FileScan] => |
| + batch.scan.asInstanceOf[FileScan] |
| }.headOption.getOrElse { |
| fail(s"No FileScan in query\n${df.queryExecution}") |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala |
| index c97979a57a5..45a998db0e0 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala |
| @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Range, RepartitionByExpressi |
| import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRelationV2} |
| import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes |
| import org.apache.spark.sql.catalyst.util.DateTimeUtils |
| +import org.apache.spark.sql.comet.CometLocalLimitExec |
| import org.apache.spark.sql.execution.{LocalLimitExec, SimpleMode, SparkPlan} |
| import org.apache.spark.sql.execution.command.ExplainCommand |
| import org.apache.spark.sql.execution.streaming._ |
| @@ -1114,11 +1115,12 @@ class StreamSuite extends StreamTest { |
| val localLimits = execPlan.collect { |
| case l: LocalLimitExec => l |
| case l: StreamingLocalLimitExec => l |
| + case l: CometLocalLimitExec => l |
| } |
| |
| require( |
| localLimits.size == 1, |
| - s"Cant verify local limit optimization with this plan:\n$execPlan") |
| + s"Cant verify local limit optimization ${localLimits.size} with this plan:\n$execPlan") |
| |
| if (expectStreamingLimit) { |
| assert( |
| @@ -1126,7 +1128,8 @@ class StreamSuite extends StreamTest { |
| s"Local limit was not StreamingLocalLimitExec:\n$execPlan") |
| } else { |
| assert( |
| - localLimits.head.isInstanceOf[LocalLimitExec], |
| + localLimits.head.isInstanceOf[LocalLimitExec] || |
| + localLimits.head.isInstanceOf[CometLocalLimitExec], |
| s"Local limit was not LocalLimitExec:\n$execPlan") |
| } |
| } |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala |
| index b4c4ec7acbf..20579284856 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala |
| @@ -23,6 +23,7 @@ import org.apache.commons.io.FileUtils |
| import org.scalatest.Assertions |
| |
| import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution |
| +import org.apache.spark.sql.comet.CometHashAggregateExec |
| import org.apache.spark.sql.execution.aggregate.BaseAggregateExec |
| import org.apache.spark.sql.execution.streaming.{MemoryStream, StateStoreRestoreExec, StateStoreSaveExec} |
| import org.apache.spark.sql.functions.count |
| @@ -67,6 +68,7 @@ class StreamingAggregationDistributionSuite extends StreamTest |
| // verify aggregations in between, except partial aggregation |
| val allAggregateExecs = query.lastExecution.executedPlan.collect { |
| case a: BaseAggregateExec => a |
| + case c: CometHashAggregateExec => c.originalPlan |
| } |
| |
| val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter { |
| @@ -201,6 +203,7 @@ class StreamingAggregationDistributionSuite extends StreamTest |
| // verify aggregations in between, except partial aggregation |
| val allAggregateExecs = executedPlan.collect { |
| case a: BaseAggregateExec => a |
| + case c: CometHashAggregateExec => c.originalPlan |
| } |
| |
| val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter { |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala |
| index aad91601758..201083bd621 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala |
| @@ -31,7 +31,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation |
| import org.apache.spark.sql.{DataFrame, Row, SparkSession} |
| import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} |
| import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning |
| -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec |
| +import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike |
| import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec, StreamingSymmetricHashJoinHelper} |
| import org.apache.spark.sql.execution.streaming.state.{RocksDBStateStoreProvider, StateStore, StateStoreProviderId} |
| import org.apache.spark.sql.functions._ |
| @@ -619,14 +619,27 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { |
| |
| val numPartitions = spark.sqlContext.conf.getConf(SQLConf.SHUFFLE_PARTITIONS) |
| |
| - assert(query.lastExecution.executedPlan.collect { |
| - case j @ StreamingSymmetricHashJoinExec(_, _, _, _, _, _, _, _, _, |
| - ShuffleExchangeExec(opA: HashPartitioning, _, _, _), |
| - ShuffleExchangeExec(opB: HashPartitioning, _, _, _)) |
| - if partitionExpressionsColumns(opA.expressions) === Seq("a", "b") |
| - && partitionExpressionsColumns(opB.expressions) === Seq("a", "b") |
| - && opA.numPartitions == numPartitions && opB.numPartitions == numPartitions => j |
| - }.size == 1) |
| + val join = query.lastExecution.executedPlan.collect { |
| + case j: StreamingSymmetricHashJoinExec => j |
| + }.head |
| + val opA = join.left.collect { |
| + case s: ShuffleExchangeLike |
| + if s.outputPartitioning.isInstanceOf[HashPartitioning] && |
| + partitionExpressionsColumns( |
| + s.outputPartitioning |
| + .asInstanceOf[HashPartitioning].expressions) === Seq("a", "b") => |
| + s.outputPartitioning.asInstanceOf[HashPartitioning] |
| + }.head |
| + val opB = join.right.collect { |
| + case s: ShuffleExchangeLike |
| + if s.outputPartitioning.isInstanceOf[HashPartitioning] && |
| + partitionExpressionsColumns( |
| + s.outputPartitioning |
| + .asInstanceOf[HashPartitioning].expressions) === Seq("a", "b") => |
| + s.outputPartitioning |
| + .asInstanceOf[HashPartitioning] |
| + }.head |
| + assert(opA.numPartitions == numPartitions && opB.numPartitions == numPartitions) |
| }) |
| } |
| |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala |
| index b5cf13a9c12..ac17603fb7f 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala |
| @@ -36,7 +36,7 @@ import org.scalatestplus.mockito.MockitoSugar |
| |
| import org.apache.spark.{SparkException, TestUtils} |
| import org.apache.spark.internal.Logging |
| -import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset, Row, SaveMode} |
| +import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset, IgnoreCometNativeDataFusion, Row, SaveMode} |
| import org.apache.spark.sql.catalyst.InternalRow |
| import org.apache.spark.sql.catalyst.expressions.{Literal, Rand, Randn, Shuffle, Uuid} |
| import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, CTERelationRef, LocalRelation} |
| @@ -660,7 +660,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi |
| ) |
| } |
| |
| - test("SPARK-41198: input row calculation with CTE") { |
| + test("SPARK-41198: input row calculation with CTE", |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3315")) { |
| withTable("parquet_tbl", "parquet_streaming_tbl") { |
| spark.range(0, 10).selectExpr("id AS col1", "id AS col2") |
| .write.format("parquet").saveAsTable("parquet_tbl") |
| @@ -712,7 +713,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi |
| } |
| } |
| |
| - test("SPARK-41199: input row calculation with mixed-up of DSv1 and DSv2 streaming sources") { |
| + test("SPARK-41199: input row calculation with mixed-up of DSv1 and DSv2 streaming sources", |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3315")) { |
| withTable("parquet_streaming_tbl") { |
| val streamInput = MemoryStream[Int] |
| val streamDf = streamInput.toDF().selectExpr("value AS key", "value AS value_stream") |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSelfUnionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSelfUnionSuite.scala |
| index 8f099c31e6b..ce4b7ad25b3 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSelfUnionSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSelfUnionSuite.scala |
| @@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming |
| import org.scalatest.BeforeAndAfter |
| import org.scalatest.concurrent.PatienceConfiguration.Timeout |
| |
| -import org.apache.spark.sql.SaveMode |
| +import org.apache.spark.sql.{IgnoreCometNativeDataFusion, SaveMode} |
| import org.apache.spark.sql.connector.catalog.Identifier |
| import org.apache.spark.sql.execution.streaming.MemoryStream |
| import org.apache.spark.sql.streaming.test.{InMemoryStreamTable, InMemoryStreamTableCatalog} |
| @@ -42,7 +42,8 @@ class StreamingSelfUnionSuite extends StreamTest with BeforeAndAfter { |
| sqlContext.streams.active.foreach(_.stop()) |
| } |
| |
| - test("self-union, DSv1, read via DataStreamReader API") { |
| + test("self-union, DSv1, read via DataStreamReader API", |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3401")) { |
| withTempPath { dir => |
| val dataLocation = dir.getAbsolutePath |
| spark.range(1, 4).write.format("parquet").save(dataLocation) |
| @@ -66,7 +67,8 @@ class StreamingSelfUnionSuite extends StreamTest with BeforeAndAfter { |
| } |
| } |
| |
| - test("self-union, DSv1, read via table API") { |
| + test("self-union, DSv1, read via table API", |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3401")) { |
| withTable("parquet_streaming_tbl") { |
| spark.sql("CREATE TABLE parquet_streaming_tbl (key integer) USING parquet") |
| |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala |
| index abe606ad9c1..2d930b64cca 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala |
| @@ -22,7 +22,7 @@ import java.util |
| |
| import org.scalatest.BeforeAndAfter |
| |
| -import org.apache.spark.sql.{AnalysisException, Row, SaveMode} |
| +import org.apache.spark.sql.{AnalysisException, IgnoreComet, Row, SaveMode} |
| import org.apache.spark.sql.catalyst.TableIdentifier |
| import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException |
| import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} |
| @@ -327,7 +327,8 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter { |
| } |
| } |
| |
| - test("explain with table on DSv1 data source") { |
| + test("explain with table on DSv1 data source", |
| + IgnoreComet("Comet explain output is different")) { |
| val tblSourceName = "tbl_src" |
| val tblTargetName = "tbl_target" |
| val tblSourceQualified = s"default.$tblSourceName" |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala |
| index e937173a590..7d20538bc68 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala |
| @@ -27,6 +27,7 @@ import scala.concurrent.duration._ |
| import scala.language.implicitConversions |
| import scala.util.control.NonFatal |
| |
| +import org.apache.comet.CometConf |
| import org.apache.hadoop.fs.Path |
| import org.scalactic.source.Position |
| import org.scalatest.{BeforeAndAfterAll, Suite, Tag} |
| @@ -41,6 +42,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTest |
| import org.apache.spark.sql.catalyst.plans.PlanTestBase |
| import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan |
| import org.apache.spark.sql.catalyst.util._ |
| +import org.apache.spark.sql.comet._ |
| import org.apache.spark.sql.execution.FilterExec |
| import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution |
| import org.apache.spark.sql.execution.datasources.DataSourceUtils |
| @@ -119,6 +121,34 @@ private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with |
| |
| override protected def test(testName: String, testTags: Tag*)(testFun: => Any) |
| (implicit pos: Position): Unit = { |
| + // Check Comet skip tags first, before DisableAdaptiveExecution handling |
| + if (isCometEnabled && testTags.exists(_.isInstanceOf[IgnoreComet])) { |
| + ignore(testName + " (disabled when Comet is on)", testTags: _*)(testFun) |
| + return |
| + } |
| + if (isCometEnabled) { |
| + val cometScanImpl = CometConf.COMET_NATIVE_SCAN_IMPL.get(conf) |
| + val isNativeIcebergCompat = cometScanImpl == CometConf.SCAN_NATIVE_ICEBERG_COMPAT || |
| + cometScanImpl == CometConf.SCAN_AUTO |
| + val isNativeDataFusion = cometScanImpl == CometConf.SCAN_NATIVE_DATAFUSION || |
| + cometScanImpl == CometConf.SCAN_AUTO |
| + if (isNativeIcebergCompat && |
| + testTags.exists(_.isInstanceOf[IgnoreCometNativeIcebergCompat])) { |
| + ignore(testName + " (disabled for NATIVE_ICEBERG_COMPAT)", testTags: _*)(testFun) |
| + return |
| + } |
| + if (isNativeDataFusion && |
| + testTags.exists(_.isInstanceOf[IgnoreCometNativeDataFusion])) { |
| + ignore(testName + " (disabled for NATIVE_DATAFUSION)", testTags: _*)(testFun) |
| + return |
| + } |
| + if ((isNativeDataFusion || isNativeIcebergCompat) && |
| + testTags.exists(_.isInstanceOf[IgnoreCometNativeScan])) { |
| + ignore(testName + " (disabled for NATIVE_DATAFUSION and NATIVE_ICEBERG_COMPAT)", |
| + testTags: _*)(testFun) |
| + return |
| + } |
| + } |
| if (testTags.exists(_.isInstanceOf[DisableAdaptiveExecution])) { |
| super.test(testName, testTags: _*) { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { |
| @@ -242,6 +272,29 @@ private[sql] trait SQLTestUtilsBase |
| protected override def _sqlContext: SQLContext = self.spark.sqlContext |
| } |
| |
| + /** |
| + * Whether the Comet extension is enabled
| + */ |
| + protected def isCometEnabled: Boolean = SparkSession.isCometEnabled |
| + |
| + /** |
| + * Whether to enable ANSI mode. This is only effective when
| + * [[isCometEnabled]] returns true. |
| + */ |
| + protected def enableCometAnsiMode: Boolean = { |
| + val v = System.getenv("ENABLE_COMET_ANSI_MODE") |
| + v != null && v.toBoolean |
| + } |
| + |
| + /** |
| + * Whether Spark should only apply Comet scan optimization. This is only effective when |
| + * [[isCometEnabled]] returns true. |
| + */ |
| + protected def isCometScanOnly: Boolean = { |
| + val v = System.getenv("ENABLE_COMET_SCAN_ONLY") |
| + v != null && v.toBoolean |
| + } |
| + |
| protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { |
| SparkSession.setActiveSession(spark) |
| super.withSQLConf(pairs: _*)(f) |
| @@ -435,6 +488,8 @@ private[sql] trait SQLTestUtilsBase |
| val schema = df.schema |
| val withoutFilters = df.queryExecution.executedPlan.transform { |
| case FilterExec(_, child) => child |
| + case CometFilterExec(_, _, _, _, child, _) => child |
| + case CometProjectExec(_, _, _, _, CometFilterExec(_, _, _, _, child, _), _) => child |
| } |
| |
| spark.internalCreateDataFrame(withoutFilters.execute(), schema) |
| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala |
| index ed2e309fa07..a5ea58146ad 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala |
| @@ -74,6 +74,31 @@ trait SharedSparkSessionBase |
| // this rule may potentially block testing of other optimization rules such as |
| // ConstantPropagation etc. |
| .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName) |
| + // Enable Comet if `ENABLE_COMET` environment variable is set |
| + if (isCometEnabled) { |
| + conf |
| + .set("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions") |
| + .set("spark.comet.enabled", "true") |
| + .set("spark.comet.parquet.respectFilterPushdown", "true") |
| + |
| + if (!isCometScanOnly) { |
| + conf |
| + .set("spark.comet.exec.enabled", "true") |
| + .set("spark.shuffle.manager", |
| + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") |
| + .set("spark.comet.exec.shuffle.enabled", "true") |
| + .set("spark.comet.memoryOverhead", "10g") |
| + } else { |
| + conf |
| + .set("spark.comet.exec.enabled", "false") |
| + .set("spark.comet.exec.shuffle.enabled", "false") |
| + } |
| + |
| + if (enableCometAnsiMode) { |
| + conf |
| + .set("spark.sql.ansi.enabled", "true") |
| + } |
| + } |
| conf.set( |
| StaticSQLConf.WAREHOUSE_PATH, |
| conf.get(StaticSQLConf.WAREHOUSE_PATH) + "/" + getClass.getCanonicalName) |
| diff --git a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala |
| index c63c748953f..7edca9c93a6 100644 |
| --- a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala |
| +++ b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala |
| @@ -45,7 +45,7 @@ class SqlResourceWithActualMetricsSuite |
| import testImplicits._ |
| |
| // Exclude nodes which may not have the metrics |
| - val excludedNodes = List("WholeStageCodegen", "Project", "SerializeFromObject") |
| + val excludedNodes = List("WholeStageCodegen", "Project", "SerializeFromObject", "RowToColumnar") |
| |
| implicit val formats = new DefaultFormats { |
| override def dateFormatter = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss") |
| diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala |
| index 52abd248f3a..7a199931a08 100644 |
| --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala |
| +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala |
| @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive |
| |
| import org.apache.spark.sql._ |
| import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression} |
| +import org.apache.spark.sql.comet._ |
| import org.apache.spark.sql.execution._ |
| import org.apache.spark.sql.execution.adaptive.{DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} |
| import org.apache.spark.sql.hive.execution.HiveTableScanExec |
| @@ -35,6 +36,9 @@ abstract class DynamicPartitionPruningHiveScanSuiteBase |
| case s: FileSourceScanExec => s.partitionFilters.collect { |
| case d: DynamicPruningExpression => d.child |
| } |
| + case s: CometScanExec => s.partitionFilters.collect { |
| + case d: DynamicPruningExpression => d.child |
| + } |
| case h: HiveTableScanExec => h.partitionPruningPred.collect { |
| case d: DynamicPruningExpression => d.child |
| } |
| diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala |
| index de3b1ffccf0..2a76d127093 100644 |
| --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala |
| +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala |
| @@ -23,14 +23,15 @@ import java.util.concurrent.{Executors, TimeUnit} |
| import org.scalatest.BeforeAndAfterEach |
| |
| import org.apache.spark.metrics.source.HiveCatalogMetrics |
| -import org.apache.spark.sql.QueryTest |
| +import org.apache.spark.sql.{IgnoreCometSuite, QueryTest} |
| import org.apache.spark.sql.execution.datasources.FileStatusCache |
| import org.apache.spark.sql.hive.test.TestHiveSingleton |
| import org.apache.spark.sql.internal.SQLConf |
| import org.apache.spark.sql.test.SQLTestUtils |
| |
| class PartitionedTablePerfStatsSuite |
| - extends QueryTest with TestHiveSingleton with SQLTestUtils with BeforeAndAfterEach { |
| + extends QueryTest with TestHiveSingleton with SQLTestUtils with BeforeAndAfterEach |
| + with IgnoreCometSuite { |
| |
| override def beforeEach(): Unit = { |
| super.beforeEach() |
| diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala |
| index f3be79f9022..b4b1ea8dbc4 100644 |
| --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala |
| +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala |
| @@ -34,7 +34,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn |
| import org.apache.hadoop.io.{LongWritable, Writable} |
| |
| import org.apache.spark.{SparkException, SparkFiles, TestUtils} |
| -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} |
| +import org.apache.spark.sql.{AnalysisException, IgnoreCometNativeDataFusion, QueryTest, Row} |
| import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode |
| import org.apache.spark.sql.catalyst.plans.logical.Project |
| import org.apache.spark.sql.execution.WholeStageCodegenExec |
| @@ -448,7 +448,8 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { |
| } |
| } |
| |
| - test("SPARK-11522 select input_file_name from non-parquet table") { |
| + test("SPARK-11522 select input_file_name from non-parquet table", |
| + IgnoreCometNativeDataFusion("https://github.com/apache/datafusion-comet/issues/3312")) { |
| |
| withTempDir { tempDir => |
| |
| diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala |
| index 6160c3e5f6c..0956d7d9edc 100644 |
| --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala |
| +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala |
| @@ -24,6 +24,7 @@ import java.sql.{Date, Timestamp} |
| import java.util.{Locale, Set} |
| |
| import com.google.common.io.Files |
| +import org.apache.comet.CometConf |
| import org.apache.hadoop.fs.{FileSystem, Path} |
| |
| import org.apache.spark.{SparkException, TestUtils} |
| @@ -838,8 +839,13 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi |
| } |
| |
| test("SPARK-2554 SumDistinct partial aggregation") { |
| - checkAnswer(sql("SELECT sum( distinct key) FROM src group by key order by key"), |
| - sql("SELECT distinct key FROM src order by key").collect().toSeq) |
| + // Range partitioning uses random samples, so per-partition comparisons do not always yield |
| + // the same results. Disable Comet native range partitioning. |
| + withSQLConf(CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.key -> "false") |
| + { |
| + checkAnswer(sql("SELECT sum( distinct key) FROM src group by key order by key"), |
| + sql("SELECT distinct key FROM src order by key").collect().toSeq) |
| + } |
| } |
| |
| test("SPARK-4963 DataFrame sample on mutable row return wrong result") { |
| diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala |
| index 1d646f40b3e..5babe505301 100644 |
| --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala |
| +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala |
| @@ -53,25 +53,54 @@ object TestHive |
| new SparkContext( |
| System.getProperty("spark.sql.test.master", "local[1]"), |
| "TestSQLContext", |
| - new SparkConf() |
| - .set("spark.sql.test", "") |
| - .set(SQLConf.CODEGEN_FALLBACK.key, "false") |
| - .set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString) |
| - .set(HiveUtils.HIVE_METASTORE_BARRIER_PREFIXES.key, |
| - "org.apache.spark.sql.hive.execution.PairSerDe") |
| - .set(WAREHOUSE_PATH.key, TestHiveContext.makeWarehouseDir().toURI.getPath) |
| - // SPARK-8910 |
| - .set(UI_ENABLED, false) |
| - .set(config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true) |
| - // Hive changed the default of hive.metastore.disallow.incompatible.col.type.changes |
| - // from false to true. For details, see the JIRA HIVE-12320 and HIVE-17764. |
| - .set("spark.hadoop.hive.metastore.disallow.incompatible.col.type.changes", "false") |
| - // Disable ConvertToLocalRelation for better test coverage. Test cases built on |
| - // LocalRelation will exercise the optimization rules better by disabling it as |
| - // this rule may potentially block testing of other optimization rules such as |
| - // ConstantPropagation etc. |
| - .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName))) |
| + { |
| + val conf = new SparkConf() |
| + .set("spark.sql.test", "") |
| + .set(SQLConf.CODEGEN_FALLBACK.key, "false") |
| + .set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString) |
| + .set(HiveUtils.HIVE_METASTORE_BARRIER_PREFIXES.key, |
| + "org.apache.spark.sql.hive.execution.PairSerDe") |
| + .set(WAREHOUSE_PATH.key, TestHiveContext.makeWarehouseDir().toURI.getPath) |
| + // SPARK-8910 |
| + .set(UI_ENABLED, false) |
| + .set(config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true) |
| + // Hive changed the default of hive.metastore.disallow.incompatible.col.type.changes |
| + // from false to true. For details, see the JIRA HIVE-12320 and HIVE-17764. |
| + .set("spark.hadoop.hive.metastore.disallow.incompatible.col.type.changes", "false") |
| + // Disable ConvertToLocalRelation for better test coverage. Test cases built on |
| + // LocalRelation will exercise the optimization rules better by disabling it as |
| + // this rule may potentially block testing of other optimization rules such as |
| + // ConstantPropagation etc. |
| + .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName) |
| + |
| + if (SparkSession.isCometEnabled) { |
| + conf |
| + .set("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions") |
| + .set("spark.comet.enabled", "true") |
| + |
| + val v = System.getenv("ENABLE_COMET_SCAN_ONLY") |
| + if (v == null || !v.toBoolean) { |
| + conf |
| + .set("spark.comet.exec.enabled", "true") |
| + .set("spark.shuffle.manager", |
| + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") |
| + .set("spark.comet.exec.shuffle.enabled", "true") |
| + } else { |
| + conf |
| + .set("spark.comet.exec.enabled", "false") |
| + .set("spark.comet.exec.shuffle.enabled", "false") |
| + } |
| + |
| + val a = System.getenv("ENABLE_COMET_ANSI_MODE") |
| + if (a != null && a.toBoolean) { |
| + conf |
| + .set("spark.sql.ansi.enabled", "true") |
| + } |
| + } |
| |
| + conf |
| + } |
| + )) |
| |
| case class TestHiveVersion(hiveClient: HiveClient) |
| extends TestHiveContext(TestHive.sparkContext, hiveClient) |