blob: 406e56cdaef870bff1fd8b927343ff84c224928d [file]
diff --git a/pom.xml b/pom.xml
index 22922143fc3..7c56e5e8641 100644
--- a/pom.xml
+++ b/pom.xml
@@ -148,6 +148,8 @@
<kryo.version>4.0.3</kryo.version>
<ivy.version>2.5.3</ivy.version>
<oro.version>2.0.8</oro.version>
+ <spark.version.short>4.0</spark.version.short>
+ <comet.version>0.14.0-SNAPSHOT</comet.version>
<!--
If you change codahale.metrics.version, you also need to change
the link to metrics.dropwizard.io in docs/monitoring.md.
@@ -2596,6 +2598,25 @@
<artifactId>arpack</artifactId>
<version>${netlib.ludovic.dev.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.datafusion</groupId>
+ <artifactId>comet-spark-spark${spark.version.short}_${scala.binary.version}</artifactId>
+ <version>${comet.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_${scala.binary.version}</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
<!-- SPARK-16484 add `datasketches-java` for support Datasketches HllSketch -->
<dependency>
<groupId>org.apache.datasketches</groupId>
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index dcf6223a98b..0458a5bb640 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -90,6 +90,10 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-tags_${scala.binary.version}</artifactId>
</dependency>
+ <dependency>
+ <groupId>org.apache.datafusion</groupId>
+ <artifactId>comet-spark-spark${spark.version.short}_${scala.binary.version}</artifactId>
+ </dependency>
<!--
This spark-tags test-dep is needed even though it isn't used in this module, otherwise testing-cmds that exclude
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala
index 0015d7ff99e..c9dd85e72c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala
@@ -1042,6 +1042,23 @@ object SparkSession extends SparkSessionCompanion with Logging {
extensions
}
+ /**
+ * Whether Comet extension is enabled
+ */
+ def isCometEnabled: Boolean = {
+ val v = System.getenv("ENABLE_COMET")
+ v == null || v.toBoolean
+ }
+
+
+ private def loadCometExtension(sparkContext: SparkContext): Seq[String] = {
+ if (sparkContext.getConf.getBoolean("spark.comet.enabled", isCometEnabled)) {
+ Seq("org.apache.comet.CometSparkSessionExtensions")
+ } else {
+ Seq.empty
+ }
+ }
+
/**
* Initialize extensions specified in [[StaticSQLConf]]. The classes will be applied to the
* extensions passed into this function.
@@ -1051,7 +1068,8 @@ object SparkSession extends SparkSessionCompanion with Logging {
extensions: SparkSessionExtensions): SparkSessionExtensions = {
val extensionConfClassNames = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
.getOrElse(Seq.empty)
- extensionConfClassNames.foreach { extensionConfClassName =>
+ val extensionClassNames = extensionConfClassNames ++ loadCometExtension(sparkContext)
+ extensionClassNames.foreach { extensionConfClassName =>
try {
val extensionConfClass = Utils.classForName(extensionConfClassName)
val extensionConf = extensionConfClass.getConstructor().newInstance()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index 4410fe50912..43bcce2a038 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.plans.logical.{EmptyRelation, LogicalPlan}
+import org.apache.spark.sql.comet.CometScanExec
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
import org.apache.spark.sql.execution.adaptive.LogicalQueryStage
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
@@ -84,6 +85,7 @@ private[execution] object SparkPlanInfo {
// dump the file scan metadata (e.g file path) to event log
val metadata = plan match {
case fileScan: FileSourceScanLike => fileScan.metadata
+ case cometScan: CometScanExec => cometScan.metadata
case _ => Map[String, String]()
}
val childrenInfo = children.flatMap {
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out
index 7aca17dcb25..8afeb3b4a2f 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out
@@ -64,15 +64,6 @@ WithCTE
+- CTERelationRef xxxx, true, [c1#x], false, false
--- !query
-SELECT lower(listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase)) FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1)
--- !query analysis
-Aggregate [lower(listagg(distinct collate(c1#x, utf8_lcase), null, collate(c1#x, utf8_lcase) ASC NULLS FIRST, 0, 0)) AS lower(listagg(DISTINCT collate(c1, utf8_lcase), NULL) WITHIN GROUP (ORDER BY collate(c1, utf8_lcase) ASC NULLS FIRST))#x]
-+- SubqueryAlias t
- +- Project [col1#x AS c1#x]
- +- LocalRelation [col1#x]
-
-
-- !query
WITH t(c1) AS (SELECT replace(listagg(DISTINCT col1 COLLATE unicode_rtrim) COLLATE utf8_binary, ' ', '') FROM (VALUES ('xbc '), ('xbc '), ('a'), ('xbc'))) SELECT len(c1), regexp_count(c1, 'a'), regexp_count(c1, 'xbc') FROM t
-- !query analysis
diff --git a/sql/core/src/test/resources/sql-tests/inputs/collations.sql b/sql/core/src/test/resources/sql-tests/inputs/collations.sql
index 17815ed5dde..baad440b1ce 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/collations.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/collations.sql
@@ -1,3 +1,6 @@
+-- TODO: https://github.com/apache/datafusion-comet/issues/551
+--SET spark.comet.enabled = false
+
-- test cases for collation support
-- Create a test table with data
diff --git a/sql/core/src/test/resources/sql-tests/inputs/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/decimalArithmeticOperations.sql
index 13bbd9d81b7..541cdfb1e04 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/decimalArithmeticOperations.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/decimalArithmeticOperations.sql
@@ -15,6 +15,12 @@
-- limitations under the License.
--
+-- TODO: Disabled due to one of the test failed for Spark4.0
+-- TODO: https://github.com/apache/datafusion-comet/issues/1948
+-- The following query failed
+-- select /*+ COALESCE(1) */ id, a+b, a-b, a*b, a/b from decimals_test order by id
+--SET spark.comet.enabled = false
+
CREATE TEMPORARY VIEW t AS SELECT 1.0 as a, 0.0 as b;
-- division, remainder and pmod by 0 return NULL
diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql b/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql
index 7aef901da4f..f3d6e18926d 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql
@@ -2,3 +2,4 @@
--SET spark.sql.adaptive.enabled=true
--SET spark.sql.maxMetadataStringLength = 500
+--SET spark.comet.enabled = false
diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain-cbo.sql b/sql/core/src/test/resources/sql-tests/inputs/explain-cbo.sql
index eeb2180f7a5..afd1b5ec289 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/explain-cbo.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/explain-cbo.sql
@@ -1,5 +1,6 @@
--SET spark.sql.cbo.enabled=true
--SET spark.sql.maxMetadataStringLength = 500
+--SET spark.comet.enabled = false
CREATE TABLE explain_temp1(a INT, b INT) USING PARQUET;
CREATE TABLE explain_temp2(c INT, d INT) USING PARQUET;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain.sql b/sql/core/src/test/resources/sql-tests/inputs/explain.sql
index 698ca009b4f..57d774a3617 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/explain.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/explain.sql
@@ -1,6 +1,7 @@
--SET spark.sql.codegen.wholeStage = true
--SET spark.sql.adaptive.enabled = false
--SET spark.sql.maxMetadataStringLength = 500
+--SET spark.comet.enabled = false
-- Test tables
CREATE table explain_temp1 (key int, val int) USING PARQUET;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/listagg-collations.sql b/sql/core/src/test/resources/sql-tests/inputs/listagg-collations.sql
index aa3d02dc2fb..c4f878d9908 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/listagg-collations.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/listagg-collations.sql
@@ -5,7 +5,9 @@ WITH t(c1) AS (SELECT listagg(col1) WITHIN GROUP (ORDER BY col1) FROM (VALUES ('
-- Test cases with utf8_lcase. Lower expression added for determinism
SELECT lower(listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase)) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1);
WITH t(c1) AS (SELECT lower(listagg(DISTINCT col1 COLLATE utf8_lcase)) FROM (VALUES ('a'), ('A'), ('b'), ('B'))) SELECT len(c1), regexp_count(c1, 'a'), regexp_count(c1, 'b') FROM t;
-SELECT lower(listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase)) FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1);
+-- TODO https://github.com/apache/datafusion-comet/issues/1947
+-- TODO fix Comet for this query
+-- SELECT lower(listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase)) FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1);
-- Test cases with unicode_rtrim.
WITH t(c1) AS (SELECT replace(listagg(DISTINCT col1 COLLATE unicode_rtrim) COLLATE utf8_binary, ' ', '') FROM (VALUES ('xbc '), ('xbc '), ('a'), ('xbc'))) SELECT len(c1), regexp_count(c1, 'a'), regexp_count(c1, 'xbc') FROM t;
WITH t(c1) AS (SELECT listagg(col1) WITHIN GROUP (ORDER BY col1 COLLATE unicode_rtrim) FROM (VALUES ('abc '), ('abc\n'), ('abc'), ('x'))) SELECT replace(replace(c1, ' ', ''), '\n', '$') FROM t;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
index 3a409eea348..26e9aaf215c 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
@@ -6,6 +6,9 @@
-- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/int4.sql
--
+-- TODO: https://github.com/apache/datafusion-comet/issues/551
+--SET spark.comet.enabled = false
+
CREATE TABLE INT4_TBL(f1 int) USING parquet;
-- [SPARK-28023] Trim the string when cast string type to other types
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
index fac23b4a26f..98b12ae5ccc 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
@@ -6,6 +6,10 @@
-- Test int8 64-bit integers.
-- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/int8.sql
--
+
+-- TODO: https://github.com/apache/datafusion-comet/issues/551
+--SET spark.comet.enabled = false
+
CREATE TABLE INT8_TBL(q1 bigint, q2 bigint) USING parquet;
-- PostgreSQL implicitly casts string literals to data with integral types, but
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/select_having.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/select_having.sql
index 0efe0877e9b..f9df0400c99 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/select_having.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/select_having.sql
@@ -6,6 +6,9 @@
-- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/select_having.sql
--
+-- TODO: https://github.com/apache/datafusion-comet/issues/551
+--SET spark.comet.enabled = false
+
-- load test data
CREATE TABLE test_having (a int, b int, c string, d string) USING parquet;
INSERT INTO test_having VALUES (0, 1, 'XXXX', 'A');
diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql
index 7c816d8a416..b1551a2b296 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql
@@ -1,6 +1,23 @@
-- A test suite for IN LIMIT in parent side, subquery, and both predicate subquery
-- It includes correlated cases.
+-- TODO: Disabled due to one of the test failed for Spark4.0
+-- TODO: https://github.com/apache/datafusion-comet/issues/1948
+-- The following query failed
+-- SELECT Count(DISTINCT( t1a )),
+-- t1b
+-- FROM t1
+-- WHERE t1d NOT IN (SELECT t2d
+-- FROM t2
+-- WHERE t2b > t1b
+-- ORDER BY t2b DESC nulls first, t2d
+-- LIMIT 1
+-- OFFSET 1)
+-- GROUP BY t1b
+-- ORDER BY t1b NULLS last
+-- LIMIT 1
+-- OFFSET 1;
+--SET spark.comet.enabled = false
--CONFIG_DIM1 spark.sql.optimizeNullAwareAntiJoin=true
--CONFIG_DIM1 spark.sql.optimizeNullAwareAntiJoin=false
@@ -61,6 +78,7 @@ WHERE t1a IN (SELECT t2a
WHERE t1d = t2d)
LIMIT 2;
+--SET spark.sql.cbo.enabled=true
-- correlated IN subquery
-- LIMIT on both parent and subquery sides
SELECT *
diff --git a/sql/core/src/test/resources/sql-tests/inputs/view-schema-binding-config.sql b/sql/core/src/test/resources/sql-tests/inputs/view-schema-binding-config.sql
index e803254ea64..74db78aee38 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/view-schema-binding-config.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/view-schema-binding-config.sql
@@ -1,6 +1,9 @@
-- This test suits check the spark.sql.viewSchemaBindingMode configuration.
-- It can be DISABLED and COMPENSATION
+-- TODO: https://github.com/apache/datafusion-comet/issues/551
+--SET spark.comet.enabled = false
+
-- Verify the default binding is true
SET spark.sql.legacy.viewSchemaBindingMode;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/view-schema-compensation.sql b/sql/core/src/test/resources/sql-tests/inputs/view-schema-compensation.sql
index 21a3ce1e122..f4762ab98f0 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/view-schema-compensation.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/view-schema-compensation.sql
@@ -1,5 +1,9 @@
-- This test suite checks the WITH SCHEMA COMPENSATION clause
-- Disable ANSI mode to ensure we are forcing it explicitly in the CASTS
+
+-- TODO: https://github.com/apache/datafusion-comet/issues/551
+--SET spark.comet.enabled = false
+
SET spark.sql.ansi.enabled = false;
-- In COMPENSATION views get invalidated if the type can't cast
diff --git a/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out b/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out
index 1f8c5822e7d..b7de4e28813 100644
--- a/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out
@@ -40,14 +40,6 @@ struct<len(c1):int,regexp_count(c1, a):int,regexp_count(c1, b):int>
2 1 1
--- !query
-SELECT lower(listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase)) FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1)
--- !query schema
-struct<lower(listagg(DISTINCT collate(c1, utf8_lcase), NULL) WITHIN GROUP (ORDER BY collate(c1, utf8_lcase) ASC NULLS FIRST)):string collate UTF8_LCASE>
--- !query output
-ab
-
-
-- !query
WITH t(c1) AS (SELECT replace(listagg(DISTINCT col1 COLLATE unicode_rtrim) COLLATE utf8_binary, ' ', '') FROM (VALUES ('xbc '), ('xbc '), ('a'), ('xbc'))) SELECT len(c1), regexp_count(c1, 'a'), regexp_count(c1, 'xbc') FROM t
-- !query schema
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 0f42502f1d9..f616024a9c2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants
import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, RDDScanExec, SparkPlan, SparkPlanInfo}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEPropagateEmptyRelation}
import org.apache.spark.sql.execution.columnar._
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -520,7 +520,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
df.collect()
}
assert(
- collect(df.queryExecution.executedPlan) { case e: ShuffleExchangeExec => e }.size == expected)
+ collect(df.queryExecution.executedPlan) {
+ case _: ShuffleExchangeLike => 1 }.size == expected)
}
test("A cached table preserves the partitioning and ordering of its cached SparkPlan") {
@@ -1626,7 +1627,9 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
}
}
- test("SPARK-35332: Make cache plan disable configs configurable - check AQE") {
+ test("SPARK-35332: Make cache plan disable configs configurable - check AQE",
+ IgnoreComet("TODO: ignore for first stage of 4.0 " +
+ "https://github.com/apache/datafusion-comet/issues/1948")) {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "2",
SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
@@ -1661,7 +1664,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
_.nodeName.contains("AdaptiveSparkPlan"))
val aqePlanRoot = findNodeInSparkPlanInfo(inMemoryScanNode.get,
_.nodeName.contains("ResultQueryStage"))
- aqePlanRoot.get.children.head.nodeName == "AQEShuffleRead"
+ aqeNode.get.children.head.nodeName == "AQEShuffleRead" ||
+ (aqeNode.get.children.head.nodeName.contains("WholeStageCodegen") &&
+ aqeNode.get.children.head.children.head.nodeName == "ColumnarToRow" &&
+ aqeNode.get.children.head.children.head.children.head.nodeName == "InputAdapter" &&
+ aqeNode.get.children.head.children.head.children.head.children.head.nodeName ==
+ "AQEShuffleRead")
}
withTempView("t0", "t1", "t2") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 9db406ff12f..abbc91f5c11 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -855,7 +855,7 @@ class DataFrameAggregateSuite extends QueryTest
assert(objHashAggPlans.nonEmpty)
val exchangePlans = collect(aggPlan) {
- case shuffle: ShuffleExchangeExec => shuffle
+ case shuffle: ShuffleExchangeLike => shuffle
}
assert(exchangePlans.length == 1)
}
@@ -2241,7 +2241,9 @@ class DataFrameAggregateSuite extends QueryTest
}
}
- test("SPARK-47430 Support GROUP BY MapType") {
+ test("SPARK-47430 Support GROUP BY MapType",
+ IgnoreComet("TODO: ignore for first stage of 4.0 " +
+ "https://github.com/apache/datafusion-comet/issues/1948")) {
def genMapData(dataType: String): String = {
s"""
|case when id % 4 == 0 then map()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index ed182322aec..1ae6afa686a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -435,7 +435,9 @@ class DataFrameJoinSuite extends QueryTest
withTempDatabase { dbName =>
withTable(table1Name, table2Name) {
- withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ withSQLConf(
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ "spark.comet.enabled" -> "false") {
spark.range(50).write.saveAsTable(s"$dbName.$table1Name")
spark.range(100).write.saveAsTable(s"$dbName.$table2Name")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 5b88eeefeca..d4f07bc182a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -36,11 +36,12 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference,
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LocalRelation, LogicalPlan, OneRowRelation}
+import org.apache.spark.sql.comet.CometBroadcastExchangeExec
import org.apache.spark.sql.connector.FakeV2Provider
import org.apache.spark.sql.execution.{FilterExec, LogicalRDD, QueryExecution, SortExec, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.expressions.{Aggregator, Window}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -1493,7 +1494,7 @@ class DataFrameSuite extends QueryTest
fail("Should not have back to back Aggregates")
}
atFirstAgg = true
- case e: ShuffleExchangeExec => atFirstAgg = false
+ case e: ShuffleExchangeLike => atFirstAgg = false
case _ =>
}
}
@@ -1683,7 +1684,7 @@ class DataFrameSuite extends QueryTest
checkAnswer(join, df)
assert(
collect(join.queryExecution.executedPlan) {
- case e: ShuffleExchangeExec => true }.size === 1)
+ case _: ShuffleExchangeLike => true }.size === 1)
assert(
collect(join.queryExecution.executedPlan) { case e: ReusedExchangeExec => true }.size === 1)
val broadcasted = broadcast(join)
@@ -1691,10 +1692,12 @@ class DataFrameSuite extends QueryTest
checkAnswer(join2, df)
assert(
collect(join2.queryExecution.executedPlan) {
- case e: ShuffleExchangeExec => true }.size == 1)
+ case _: ShuffleExchangeLike => true }.size == 1)
assert(
collect(join2.queryExecution.executedPlan) {
- case e: BroadcastExchangeExec => true }.size === 1)
+ case e: BroadcastExchangeExec => true
+ case _: CometBroadcastExchangeExec => true
+ }.size === 1)
assert(
collect(join2.queryExecution.executedPlan) { case e: ReusedExchangeExec => true }.size == 4)
}
@@ -2092,7 +2095,7 @@ class DataFrameSuite extends QueryTest
// Assert that no extra shuffle introduced by cogroup.
val exchanges = collect(df3.queryExecution.executedPlan) {
- case h: ShuffleExchangeExec => h
+ case h: ShuffleExchangeLike => h
}
assert(exchanges.size == 2)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 01e72daead4..0a8d1e8b9b9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression
import org.apache.spark.sql.catalyst.optimizer.TransposeWindow
import org.apache.spark.sql.catalyst.plans.logical.{Window => LogicalWindow}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, Exchange, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, Exchange, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
import org.apache.spark.sql.functions._
@@ -1142,10 +1143,12 @@ class DataFrameWindowFunctionsSuite extends QueryTest
}
def isShuffleExecByRequirement(
- plan: ShuffleExchangeExec,
+ plan: ShuffleExchangeLike,
desiredClusterColumns: Seq[String]): Boolean = plan match {
case ShuffleExchangeExec(op: HashPartitioning, _, ENSURE_REQUIREMENTS, _) =>
partitionExpressionsColumns(op.expressions) === desiredClusterColumns
+ case CometShuffleExchangeExec(op: HashPartitioning, _, _, ENSURE_REQUIREMENTS, _, _) =>
+ partitionExpressionsColumns(op.expressions) === desiredClusterColumns
case _ => false
}
@@ -1168,7 +1171,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest
val shuffleByRequirement = windowed.queryExecution.executedPlan.exists {
case w: WindowExec =>
w.child.exists {
- case s: ShuffleExchangeExec => isShuffleExecByRequirement(s, Seq("key1", "key2"))
+ case s: ShuffleExchangeLike => isShuffleExecByRequirement(s, Seq("key1", "key2"))
case _ => false
}
case _ => false
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 81713c777bc..b5f92ed9742 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -46,7 +46,7 @@ import org.apache.spark.sql.catalyst.trees.DataFrameQueryContext
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
@@ -2415,7 +2415,7 @@ class DatasetSuite extends QueryTest
// Assert that no extra shuffle introduced by cogroup.
val exchanges = collect(df3.queryExecution.executedPlan) {
- case h: ShuffleExchangeExec => h
+ case h: ShuffleExchangeLike => h
}
assert(exchanges.size == 2)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
index 2c24cc7d570..21d36ebc6f5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
@@ -22,6 +22,7 @@ import org.scalatest.GivenWhenThen
import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression}
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._
import org.apache.spark.sql.catalyst.plans.ExistenceJoin
+import org.apache.spark.sql.comet.CometScanExec
import org.apache.spark.sql.connector.catalog.{InMemoryTableCatalog, InMemoryTableWithV2FilterCatalog}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive._
@@ -262,6 +263,9 @@ abstract class DynamicPartitionPruningSuiteBase
case s: BatchScanExec => s.runtimeFilters.collect {
case d: DynamicPruningExpression => d.child
}
+ case s: CometScanExec => s.partitionFilters.collect {
+ case d: DynamicPruningExpression => d.child
+ }
case _ => Nil
}
}
@@ -755,7 +759,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}
- test("partition pruning in broadcast hash joins") {
+ test("partition pruning in broadcast hash joins",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #1737")) {
Given("disable broadcast pruning and disable subquery duplication")
withSQLConf(
SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true",
@@ -1027,7 +1032,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}
- test("avoid reordering broadcast join keys to match input hash partitioning") {
+ test("avoid reordering broadcast join keys to match input hash partitioning",
+ IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/1839")) {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withTable("large", "dimTwo", "dimThree") {
@@ -1215,7 +1221,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
test("SPARK-32509: Unused Dynamic Pruning filter shouldn't affect " +
- "canonicalization and exchange reuse") {
+ "canonicalization and exchange reuse",
+ IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/1839")) {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
val df = sql(
@@ -1424,7 +1431,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}
- test("SPARK-34637: DPP side broadcast query stage is created firstly") {
+ test("SPARK-34637: DPP side broadcast query stage is created firstly",
+ IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/1839")) {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
val df = sql(
""" WITH v as (
@@ -1455,7 +1463,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}
- test("SPARK-35568: Fix UnsupportedOperationException when enabling both AQE and DPP") {
+ test("SPARK-35568: Fix UnsupportedOperationException when enabling both AQE and DPP",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #1737")) {
val df = sql(
"""
|SELECT s.store_id, f.product_id
@@ -1730,6 +1739,8 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat
case s: BatchScanExec =>
// we use f1 col for v2 tables due to schema pruning
s.output.exists(_.exists(_.argString(maxFields = 100).contains("f1")))
+ case s: CometScanExec =>
+ s.output.exists(_.exists(_.argString(maxFields = 100).contains("fid")))
case _ => false
}
assert(scanOption.isDefined)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
index 9c90e0105a4..fadf2f0f698 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
@@ -470,7 +470,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
}
}
- test("Explain formatted output for scan operator for datasource V2") {
+ test("Explain formatted output for scan operator for datasource V2",
+ IgnoreComet("Comet explain output is different")) {
withTempDir { dir =>
Seq("parquet", "orc", "csv", "json").foreach { fmt =>
val basePath = dir.getCanonicalPath + "/" + fmt
@@ -548,7 +549,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
}
}
-class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuite {
+// Ignored when Comet is enabled. Comet changes expected query plans.
+class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuite
+ with IgnoreCometSuite {
import testImplicits._
test("SPARK-35884: Explain Formatted") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index 9c529d14221..2f1bc3880fd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterTha
import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt}
import org.apache.spark.sql.catalyst.plans.logical.Filter
import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.comet.{CometBatchScanExec, CometScanExec, CometSortMergeJoinExec}
import org.apache.spark.sql.execution.{FileSourceScanLike, SimpleMode}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.FilePartition
@@ -967,6 +968,7 @@ class FileBasedDataSourceSuite extends QueryTest
assert(bJoinExec.isEmpty)
val smJoinExec = collect(joinedDF.queryExecution.executedPlan) {
case smJoin: SortMergeJoinExec => smJoin
+ case smJoin: CometSortMergeJoinExec => smJoin
}
assert(smJoinExec.nonEmpty)
}
@@ -1027,6 +1029,7 @@ class FileBasedDataSourceSuite extends QueryTest
val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: FileScan, _, _, _, _) => f
+ case CometBatchScanExec(BatchScanExec(_, f: FileScan, _, _, _, _), _, _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.nonEmpty)
@@ -1068,6 +1071,7 @@ class FileBasedDataSourceSuite extends QueryTest
val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: FileScan, _, _, _, _) => f
+ case CometBatchScanExec(BatchScanExec(_, f: FileScan, _, _, _, _), _, _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.isEmpty)
@@ -1252,6 +1256,8 @@ class FileBasedDataSourceSuite extends QueryTest
val filters = df.queryExecution.executedPlan.collect {
case f: FileSourceScanLike => f.dataFilters
case b: BatchScanExec => b.scan.asInstanceOf[FileScan].dataFilters
+ case b: CometScanExec => b.dataFilters
+ case b: CometBatchScanExec => b.scan.asInstanceOf[FileScan].dataFilters
}.flatten
assert(filters.contains(GreaterThan(scan.logicalPlan.output.head, Literal(5L))))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IgnoreComet.scala b/sql/core/src/test/scala/org/apache/spark/sql/IgnoreComet.scala
new file mode 100644
index 00000000000..5691536c114
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IgnoreComet.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.scalactic.source.Position
+import org.scalatest.Tag
+
+import org.apache.spark.sql.test.SQLTestUtils
+
+/**
+ * Tests with this tag will be ignored when Comet is enabled (e.g., via `ENABLE_COMET`).
+ */
+case class IgnoreComet(reason: String) extends Tag("DisableComet")
+case class IgnoreCometNativeIcebergCompat(reason: String) extends Tag("DisableComet")
+case class IgnoreCometNativeDataFusion(reason: String) extends Tag("DisableComet")
+case class IgnoreCometNativeScan(reason: String) extends Tag("DisableComet")
+
+/**
+ * Helper trait that disables Comet for all tests regardless of default config values.
+ */
+trait IgnoreCometSuite extends SQLTestUtils {
+ override protected def test(testName: String, testTags: Tag*)(testFun: => Any)
+ (implicit pos: Position): Unit = {
+ if (isCometEnabled) {
+ ignore(testName + " (disabled when Comet is on)", testTags: _*)(testFun)
+ } else {
+ super.test(testName, testTags: _*)(testFun)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
index 7d7185ae6c1..442a5bddeb8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
@@ -442,7 +442,8 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
}
test("Runtime bloom filter join: do not add bloom filter if dpp filter exists " +
- "on the same column") {
+ "on the same column",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
assertDidNotRewriteWithBloomFilter("select * from bf5part join bf2 on " +
@@ -451,7 +452,8 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
}
test("Runtime bloom filter join: add bloom filter if dpp filter exists on " +
- "a different column") {
+ "a different column",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
assertRewroteWithBloomFilter("select * from bf5part join bf2 on " +
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
index 53e47f428c3..a55d8f0c161 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.comet.{CometHashJoinExec, CometSortMergeJoinExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.internal.SQLConf
@@ -362,6 +363,7 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP
val executedPlan = df.queryExecution.executedPlan
val shuffleHashJoins = collect(executedPlan) {
case s: ShuffledHashJoinExec => s
+ case c: CometHashJoinExec => c.originalPlan.asInstanceOf[ShuffledHashJoinExec]
}
assert(shuffleHashJoins.size == 1)
assert(shuffleHashJoins.head.buildSide == buildSide)
@@ -371,6 +373,7 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP
val executedPlan = df.queryExecution.executedPlan
val shuffleMergeJoins = collect(executedPlan) {
case s: SortMergeJoinExec => s
+ case c: CometSortMergeJoinExec => c
}
assert(shuffleMergeJoins.size == 1)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index aaac0ebc9aa..fbef0774d46 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -29,7 +29,8 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrder}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, JoinSelectionHelper}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, Join, JoinHint, NO_BROADCAST_AND_REPLICATION}
-import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.comet._
+import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, FilterExec, InputAdapter, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins._
@@ -805,7 +806,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
}
}
- test("test SortMergeJoin (with spill)") {
+ test("test SortMergeJoin (with spill)",
+ IgnoreComet("TODO: Comet SMJ doesn't support spill yet")) {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1",
SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "0",
SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD.key -> "1") {
@@ -931,10 +933,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
val physical = df.queryExecution.sparkPlan
val physicalJoins = physical.collect {
case j: SortMergeJoinExec => j
+ case j: CometSortMergeJoinExec => j.originalPlan.asInstanceOf[SortMergeJoinExec]
}
val executed = df.queryExecution.executedPlan
val executedJoins = collect(executed) {
case j: SortMergeJoinExec => j
+ case j: CometSortMergeJoinExec => j.originalPlan.asInstanceOf[SortMergeJoinExec]
}
// This only applies to the above tested queries, in which a child SortMergeJoin always
// contains the SortOrder required by its parent SortMergeJoin. Thus, SortExec should never
@@ -1180,9 +1184,11 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
val plan = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", joinType)
.groupBy($"k1").count()
.queryExecution.executedPlan
- assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size === 1)
+ assert(collect(plan) {
+ case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size === 1)
// No extra shuffle before aggregate
- assert(collect(plan) { case _: ShuffleExchangeExec => true }.size === 2)
+ assert(collect(plan) {
+ case _: ShuffleExchangeLike => true }.size === 2)
})
}
@@ -1199,10 +1205,11 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
.join(df4.hint("SHUFFLE_MERGE"), $"k1" === $"k4", joinType)
.queryExecution
.executedPlan
- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 2)
+ assert(collect(plan) {
+ case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size === 2)
assert(collect(plan) { case _: BroadcastHashJoinExec => true }.size === 1)
// No extra sort before last sort merge join
- assert(collect(plan) { case _: SortExec => true }.size === 3)
+ assert(collect(plan) { case _: SortExec | _: CometSortExec => true }.size === 3)
})
// Test shuffled hash join
@@ -1212,10 +1219,13 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
.join(df4.hint("SHUFFLE_MERGE"), $"k1" === $"k4", joinType)
.queryExecution
.executedPlan
- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 2)
- assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size === 1)
+ assert(collect(plan) {
+ case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size === 2)
+ assert(collect(plan) {
+ case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size === 1)
// No extra sort before last sort merge join
- assert(collect(plan) { case _: SortExec => true }.size === 3)
+ assert(collect(plan) {
+ case _: SortExec | _: CometSortExec => true }.size === 3)
})
}
@@ -1306,12 +1316,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
inputDFs.foreach { case (df1, df2, joinExprs) =>
val smjDF = df1.join(df2.hint("SHUFFLE_MERGE"), joinExprs, "full")
assert(collect(smjDF.queryExecution.executedPlan) {
- case _: SortMergeJoinExec => true }.size === 1)
+ case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size === 1)
val smjResult = smjDF.collect()
val shjDF = df1.join(df2.hint("SHUFFLE_HASH"), joinExprs, "full")
assert(collect(shjDF.queryExecution.executedPlan) {
- case _: ShuffledHashJoinExec => true }.size === 1)
+ case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size === 1)
// Same result between shuffled hash join and sort merge join
checkAnswer(shjDF, smjResult)
}
@@ -1370,12 +1380,14 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
val smjDF = df1.hint("SHUFFLE_MERGE").join(df2, joinExprs, "leftouter")
assert(collect(smjDF.queryExecution.executedPlan) {
case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
}.size === 1)
val smjResult = smjDF.collect()
val shjDF = df1.hint("SHUFFLE_HASH").join(df2, joinExprs, "leftouter")
assert(collect(shjDF.queryExecution.executedPlan) {
case _: ShuffledHashJoinExec => true
+ case _: CometHashJoinExec => true
}.size === 1)
// Same result between shuffled hash join and sort merge join
checkAnswer(shjDF, smjResult)
@@ -1386,12 +1398,14 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
val smjDF = df2.join(df1.hint("SHUFFLE_MERGE"), joinExprs, "rightouter")
assert(collect(smjDF.queryExecution.executedPlan) {
case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
}.size === 1)
val smjResult = smjDF.collect()
val shjDF = df2.join(df1.hint("SHUFFLE_HASH"), joinExprs, "rightouter")
assert(collect(shjDF.queryExecution.executedPlan) {
case _: ShuffledHashJoinExec => true
+ case _: CometHashJoinExec => true
}.size === 1)
// Same result between shuffled hash join and sort merge join
checkAnswer(shjDF, smjResult)
@@ -1435,13 +1449,20 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
assert(shjCodegenDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
case WholeStageCodegenExec(ProjectExec(_, _ : ShuffledHashJoinExec)) => true
+ case WholeStageCodegenExec(ColumnarToRowExec(InputAdapter(_: CometHashJoinExec))) =>
+ true
+ case WholeStageCodegenExec(ColumnarToRowExec(
+ InputAdapter(CometProjectExec(_, _, _, _, _: CometHashJoinExec, _)))) => true
+ case _: CometHashJoinExec => true
}.size === 1)
checkAnswer(shjCodegenDF, Seq.empty)
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
val shjNonCodegenDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", joinType)
assert(shjNonCodegenDF.queryExecution.executedPlan.collect {
- case _: ShuffledHashJoinExec => true }.size === 1)
+ case _: ShuffledHashJoinExec => true
+ case _: CometHashJoinExec => true
+ }.size === 1)
checkAnswer(shjNonCodegenDF, Seq.empty)
}
}
@@ -1489,7 +1510,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
val plan = sql(getAggQuery(selectExpr, joinType)).queryExecution.executedPlan
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
// Have shuffle before aggregation
- assert(collect(plan) { case _: ShuffleExchangeExec => true }.size === 1)
+ assert(collect(plan) {
+ case _: ShuffleExchangeLike => true }.size === 1)
}
def getJoinQuery(selectExpr: String, joinType: String): String = {
@@ -1518,9 +1540,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
}
val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 3)
+ assert(collect(plan) {
+ case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
+ }.size === 3)
// No extra sort on left side before last sort merge join
- assert(collect(plan) { case _: SortExec => true }.size === 5)
+ assert(collect(plan) { case _: SortExec | _: CometSortExec => true }.size === 5)
}
// Test output ordering is not preserved
@@ -1529,9 +1554,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
val selectExpr = "/*+ BROADCAST(left_t) */ k1 as k0"
val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 3)
+ assert(collect(plan) {
+ case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
+ }.size === 3)
// Have sort on left side before last sort merge join
- assert(collect(plan) { case _: SortExec => true }.size === 6)
+ assert(collect(plan) { case _: SortExec | _: CometSortExec => true }.size === 6)
}
// Test singe partition
@@ -1541,7 +1569,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
|FROM range(0, 10, 1, 1) t1 FULL OUTER JOIN range(0, 10, 1, 1) t2
|""".stripMargin)
val plan = fullJoinDF.queryExecution.executedPlan
- assert(collect(plan) { case _: ShuffleExchangeExec => true}.size == 1)
+ assert(collect(plan) {
+ case _: ShuffleExchangeLike => true}.size == 1)
checkAnswer(fullJoinDF, Row(100))
}
}
@@ -1614,6 +1643,9 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
Seq(semiJoinDF, antiJoinDF).foreach { df =>
assert(collect(df.queryExecution.executedPlan) {
case j: ShuffledHashJoinExec if j.ignoreDuplicatedKey == ignoreDuplicatedKey => true
+ case j: CometHashJoinExec
+ if j.originalPlan.asInstanceOf[ShuffledHashJoinExec].ignoreDuplicatedKey ==
+ ignoreDuplicatedKey => true
}.size == 1)
}
}
@@ -1658,14 +1690,20 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
test("SPARK-43113: Full outer join with duplicate stream-side references in condition (SMJ)") {
def check(plan: SparkPlan): Unit = {
- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 1)
+ assert(collect(plan) {
+ case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
+ }.size === 1)
}
dupStreamSideColTest("MERGE", check)
}
test("SPARK-43113: Full outer join with duplicate stream-side references in condition (SHJ)") {
def check(plan: SparkPlan): Unit = {
- assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size === 1)
+ assert(collect(plan) {
+ case _: ShuffledHashJoinExec => true
+ case _: CometHashJoinExec => true
+ }.size === 1)
}
dupStreamSideColTest("SHUFFLE_HASH", check)
}
@@ -1801,7 +1839,8 @@ class ThreadLeakInSortMergeJoinSuite
sparkConf.set(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD, 20))
}
- test("SPARK-47146: thread leak when doing SortMergeJoin (with spill)") {
+ test("SPARK-47146: thread leak when doing SortMergeJoin (with spill)",
+ IgnoreComet("Comet SMJ doesn't spill yet")) {
withSQLConf(
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
index ad424b3a7cc..4ece0117a34 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
@@ -69,7 +69,7 @@ import org.apache.spark.tags.ExtendedSQLTest
* }}}
*/
// scalastyle:on line.size.limit
-trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite {
+trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite with IgnoreCometSuite {
protected val baseResourcePath = {
// use the same way as `SQLQueryTestSuite` to get the resource path
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index f294ff81021..8a3b818ee94 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1524,7 +1524,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
checkAnswer(sql("select -0.001"), Row(BigDecimal("-0.001")))
}
- test("external sorting updates peak execution memory") {
+ test("external sorting updates peak execution memory",
+ IgnoreComet("TODO: native CometSort does not update peak execution memory")) {
AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") {
sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect()
}
@@ -4449,7 +4450,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
test("SPARK-39166: Query context of binary arithmetic should be serialized to executors" +
- " when WSCG is off") {
+ " when WSCG is off",
+ IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/551")) {
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
SQLConf.ANSI_ENABLED.key -> "true") {
withTable("t") {
@@ -4470,7 +4472,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
test("SPARK-39175: Query context of Cast should be serialized to executors" +
- " when WSCG is off") {
+ " when WSCG is off",
+ IgnoreComet("https://github.com/apache/datafusion-comet/issues/2218")) {
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
SQLConf.ANSI_ENABLED.key -> "true") {
withTable("t") {
@@ -4490,14 +4493,20 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
assert(ex.isInstanceOf[SparkNumberFormatException] ||
ex.isInstanceOf[SparkDateTimeException] ||
ex.isInstanceOf[SparkRuntimeException])
- assert(ex.getMessage.contains(query))
+
+ if (!isCometEnabled) {
+ // Comet's error message does not include the original SQL query
+ // https://github.com/apache/datafusion-comet/issues/2215
+ assert(ex.getMessage.contains(query))
+ }
}
}
}
}
test("SPARK-39190,SPARK-39208,SPARK-39210: Query context of decimal overflow error should " +
- "be serialized to executors when WSCG is off") {
+ "be serialized to executors when WSCG is off",
+ IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/551")) {
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
SQLConf.ANSI_ENABLED.key -> "true") {
withTable("t") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index c1c041509c3..7d463e4b85e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -235,6 +235,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
withSession(extensions) { session =>
session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, true)
session.conf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1")
+ // https://github.com/apache/datafusion-comet/issues/1197
+ session.conf.set("spark.comet.enabled", false)
assert(session.sessionState.columnarRules.contains(
MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
import session.implicits._
@@ -293,6 +295,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
}
withSession(extensions) { session =>
session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, enableAQE)
+ // https://github.com/apache/datafusion-comet/issues/1197
+ session.conf.set("spark.comet.enabled", false)
assert(session.sessionState.columnarRules.contains(
MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
import session.implicits._
@@ -331,6 +335,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
val session = SparkSession.builder()
.master("local[1]")
.config(COLUMN_BATCH_SIZE.key, 2)
+ // https://github.com/apache/datafusion-comet/issues/1197
+ .config("spark.comet.enabled", false)
.withExtensions { extensions =>
extensions.injectColumnar(session =>
MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())) }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 0df7f806272..52d33d67328 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import org.apache.comet.CometConf
+
import org.apache.spark.{SPARK_DOC_ROOT, SparkIllegalArgumentException, SparkRuntimeException}
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.IsNotNull
@@ -179,29 +181,31 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
}
test("string regex_replace / regex_extract") {
- val df = Seq(
- ("100-200", "(\\d+)-(\\d+)", "300"),
- ("100-200", "(\\d+)-(\\d+)", "400"),
- ("100-200", "(\\d+)", "400")).toDF("a", "b", "c")
+ withSQLConf(CometConf.getExprAllowIncompatConfigKey("regexp") -> "true") {
+ val df = Seq(
+ ("100-200", "(\\d+)-(\\d+)", "300"),
+ ("100-200", "(\\d+)-(\\d+)", "400"),
+ ("100-200", "(\\d+)", "400")).toDF("a", "b", "c")
- checkAnswer(
- df.select(
- regexp_replace($"a", "(\\d+)", "num"),
- regexp_replace($"a", $"b", $"c"),
- regexp_extract($"a", "(\\d+)-(\\d+)", 1)),
- Row("num-num", "300", "100") :: Row("num-num", "400", "100") ::
- Row("num-num", "400-400", "100") :: Nil)
-
- // for testing the mutable state of the expression in code gen.
- // This is a hack way to enable the codegen, thus the codegen is enable by default,
- // it will still use the interpretProjection if projection followed by a LocalRelation,
- // hence we add a filter operator.
- // See the optimizer rule `ConvertToLocalRelation`
- checkAnswer(
- df.filter("isnotnull(a)").selectExpr(
- "regexp_replace(a, b, c)",
- "regexp_extract(a, b, 1)"),
- Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil)
+ checkAnswer(
+ df.select(
+ regexp_replace($"a", "(\\d+)", "num"),
+ regexp_replace($"a", $"b", $"c"),
+ regexp_extract($"a", "(\\d+)-(\\d+)", 1)),
+ Row("num-num", "300", "100") :: Row("num-num", "400", "100") ::
+ Row("num-num", "400-400", "100") :: Nil)
+
+ // for testing the mutable state of the expression in code gen.
+ // This is a hack way to enable the codegen, thus the codegen is enable by default,
+ // it will still use the interpretProjection if projection followed by a LocalRelation,
+ // hence we add a filter operator.
+ // See the optimizer rule `ConvertToLocalRelation`
+ checkAnswer(
+ df.filter("isnotnull(a)").selectExpr(
+ "regexp_replace(a, b, c)",
+ "regexp_extract(a, b, 1)"),
+ Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil)
+ }
}
test("non-matching optional group") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 2e33f6505ab..e1e93ab3bad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -23,10 +23,11 @@ import org.apache.spark.SparkRuntimeException
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LogicalPlan, Project, Sort, Union}
+import org.apache.spark.sql.comet.CometScanExec
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
import org.apache.spark.sql.execution.datasources.FileScanRDD
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -1529,6 +1530,12 @@ class SubquerySuite extends QueryTest
fs.inputRDDs().forall(
_.asInstanceOf[FileScanRDD].filePartitions.forall(
_.files.forall(_.urlEncodedPath.contains("p=0"))))
+ case WholeStageCodegenExec(ColumnarToRowExec(InputAdapter(
+ fs @ CometScanExec(_, _, _, _, partitionFilters, _, _, _, _, _, _)))) =>
+ partitionFilters.exists(ExecSubqueryExpression.hasSubquery) &&
+ fs.inputRDDs().forall(
+ _.asInstanceOf[FileScanRDD].filePartitions.forall(
+ _.files.forall(_.urlEncodedPath.contains("p=0"))))
case _ => false
})
}
@@ -2094,7 +2101,7 @@ class SubquerySuite extends QueryTest
df.collect()
val exchanges = collect(df.queryExecution.executedPlan) {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}
assert(exchanges.size === 1)
}
@@ -2674,22 +2681,32 @@ class SubquerySuite extends QueryTest
}
}
- test("SPARK-43402: FileSourceScanExec supports push down data filter with scalar subquery") {
+ test("SPARK-43402: FileSourceScanExec supports push down data filter with scalar subquery",
+ IgnoreComet("TODO: ignore for first stage of 4.0, " +
+ "https://github.com/apache/datafusion-comet/issues/1948")) {
def checkFileSourceScan(query: String, answer: Seq[Row]): Unit = {
val df = sql(query)
checkAnswer(df, answer)
- val fileSourceScanExec = collect(df.queryExecution.executedPlan) {
- case f: FileSourceScanExec => f
+ val dataSourceScanExec = collect(df.queryExecution.executedPlan) {
+ case f: FileSourceScanLike => f
+ case c: CometScanExec => c
}
sparkContext.listenerBus.waitUntilEmpty()
- assert(fileSourceScanExec.size === 1)
- val scalarSubquery = fileSourceScanExec.head.dataFilters.flatMap(_.collect {
- case s: ScalarSubquery => s
- })
+ assert(dataSourceScanExec.size === 1)
+ val scalarSubquery = dataSourceScanExec.head match {
+ case f: FileSourceScanLike =>
+ f.dataFilters.flatMap(_.collect {
+ case s: ScalarSubquery => s
+ })
+ case c: CometScanExec =>
+ c.dataFilters.flatMap(_.collect {
+ case s: ScalarSubquery => s
+ })
+ }
assert(scalarSubquery.length === 1)
assert(scalarSubquery.head.plan.isInstanceOf[ReusedSubqueryExec])
- assert(fileSourceScanExec.head.metrics("numFiles").value === 1)
- assert(fileSourceScanExec.head.metrics("numOutputRows").value === answer.size)
+ assert(dataSourceScanExec.head.metrics("numFiles").value === 1)
+ assert(dataSourceScanExec.head.metrics("numOutputRows").value === answer.size)
}
withTable("t1", "t2") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala
index fee375db10a..8c2c24e2c5f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala
@@ -33,7 +33,9 @@ import org.apache.spark.sql.types._
import org.apache.spark.types.variant._
import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
-class VariantShreddingSuite extends QueryTest with SharedSparkSession with ParquetTest {
+class VariantShreddingSuite extends QueryTest with SharedSparkSession with ParquetTest
+ // TODO enable tests once https://github.com/apache/datafusion-comet/issues/2209 is fixed
+ with IgnoreCometSuite {
def parseJson(s: String): VariantVal = {
val v = VariantBuilder.parseJson(s, false)
new VariantVal(v.getValue, v.getMetadata)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala
index 11e9547dfc5..be9ae40ab3d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql.collation
import scala.jdk.CollectionConverters.MapHasAsJava
import org.apache.spark.SparkException
-import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.{AnalysisException, IgnoreComet, Row}
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.CollationFactory
+import org.apache.spark.sql.comet.{CometBroadcastHashJoinExec, CometHashJoinExec, CometSortMergeJoinExec}
import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCustomSchema}
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper
@@ -55,7 +56,9 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
assert(
collectFirst(queryPlan) {
case _: SortMergeJoinExec => assert(isSortMergeForced)
+ case _: CometSortMergeJoinExec => assert(isSortMergeForced)
case _: HashJoin => assert(!isSortMergeForced)
+ case _: CometHashJoinExec | _: CometBroadcastHashJoinExec => assert(!isSortMergeForced)
}.nonEmpty
)
}
@@ -1505,7 +1508,9 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
}
}
- test("hash join should be used for collated strings if sort merge join is not forced") {
+ test("hash join should be used for collated strings if sort merge join is not forced",
+ IgnoreComet("TODO: ignore for first stage of 4.0 " +
+ "https://github.com/apache/datafusion-comet/issues/1948")) {
val t1 = "T_1"
val t2 = "T_2"
@@ -1611,6 +1616,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
} else {
assert(!collectFirst(queryPlan) {
case b: BroadcastHashJoinExec => b.leftKeys.head
+ case b: CometBroadcastHashJoinExec => b.leftKeys.head
}.head.isInstanceOf[ArrayTransform])
}
}
@@ -1676,6 +1682,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
} else {
assert(!collectFirst(queryPlan) {
case b: BroadcastHashJoinExec => b.leftKeys.head
+ case b: CometBroadcastHashJoinExec => b.leftKeys.head
}.head.isInstanceOf[ArrayTransform])
}
}
@@ -1815,7 +1822,9 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
}
}
- test("rewrite with collationkey shouldn't disrupt multiple join conditions") {
+ test("rewrite with collationkey shouldn't disrupt multiple join conditions",
+ IgnoreComet("TODO: ignore for first stage of 4.0 " +
+ "https://github.com/apache/datafusion-comet/issues/1948")) {
val t1 = "T_1"
val t2 = "T_2"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
index 3eeed2e4175..9f21d547c1c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
@@ -26,6 +26,7 @@ import test.org.apache.spark.sql.connector._
import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.comet.CometSortExec
import org.apache.spark.sql.connector.catalog.{PartitionInternalRow, SupportsRead, Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, NamedReference, NullOrdering, SortDirection, SortOrder, Transform}
@@ -36,7 +37,7 @@ import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning,
import org.apache.spark.sql.execution.SortExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation}
-import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
@@ -278,13 +279,13 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
val groupByColJ = df.groupBy($"j").agg(sum($"i"))
checkAnswer(groupByColJ, Seq(Row(2, 8), Row(4, 2), Row(6, 5)))
assert(collectFirst(groupByColJ.queryExecution.executedPlan) {
- case e: ShuffleExchangeExec => e
+ case e: ShuffleExchangeLike => e
}.isDefined)
val groupByIPlusJ = df.groupBy($"i" + $"j").agg(count("*"))
checkAnswer(groupByIPlusJ, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1)))
assert(collectFirst(groupByIPlusJ.queryExecution.executedPlan) {
- case e: ShuffleExchangeExec => e
+ case e: ShuffleExchangeLike => e
}.isDefined)
}
}
@@ -344,10 +345,11 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
val (shuffleExpected, sortExpected) = groupByExpects
assert(collectFirst(groupBy.queryExecution.executedPlan) {
- case e: ShuffleExchangeExec => e
+ case e: ShuffleExchangeLike => e
}.isDefined === shuffleExpected)
assert(collectFirst(groupBy.queryExecution.executedPlan) {
case e: SortExec => e
+ case c: CometSortExec => c
}.isDefined === sortExpected)
}
@@ -362,10 +364,11 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
val (shuffleExpected, sortExpected) = windowFuncExpects
assert(collectFirst(windowPartByColIOrderByColJ.queryExecution.executedPlan) {
- case e: ShuffleExchangeExec => e
+ case e: ShuffleExchangeLike => e
}.isDefined === shuffleExpected)
assert(collectFirst(windowPartByColIOrderByColJ.queryExecution.executedPlan) {
case e: SortExec => e
+ case c: CometSortExec => c
}.isDefined === sortExpected)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala
index 2a0ab21ddb0..6030e7c2b9b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec}
import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
@@ -188,7 +189,11 @@ class FileDataSourceV2FallBackSuite extends QueryTest with SharedSparkSession {
val df = spark.read.format(format).load(path.getCanonicalPath)
checkAnswer(df, inputData.toDF())
assert(
- df.queryExecution.executedPlan.exists(_.isInstanceOf[FileSourceScanExec]))
+ df.queryExecution.executedPlan.exists {
+ case _: FileSourceScanExec | _: CometScanExec | _: CometNativeScanExec => true
+ case _ => false
+ }
+ )
}
} finally {
spark.listenerManager.unregister(listener)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index c73e8e16fbb..88cd0d47da3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -20,10 +20,11 @@ import java.sql.Timestamp
import java.util.Collections
import org.apache.spark.SparkConf
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, IgnoreComet, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, TransformExpression}
import org.apache.spark.sql.catalyst.plans.physical
+import org.apache.spark.sql.comet.CometSortMergeJoinExec
import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryTableCatalog}
import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.connector.distributions.Distributions
@@ -32,7 +33,7 @@ import org.apache.spark.sql.connector.expressions.Expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf._
@@ -305,13 +306,14 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
}
}
- private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeExec] = {
+ private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeLike] = {
// here we skip collecting shuffle operators that are not associated with SMJ
collect(plan) {
case s: SortMergeJoinExec => s
+ case c: CometSortMergeJoinExec => c.originalPlan
}.flatMap(smj =>
collect(smj) {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
})
}
@@ -370,7 +372,9 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
checkAnswer(df.sort("res"), Seq(Row(10.0), Row(15.5), Row(41.0)))
}
- test("SPARK-48655: order by on partition keys should not introduce additional shuffle") {
+ test("SPARK-48655: order by on partition keys should not introduce additional shuffle",
+ IgnoreComet("TODO: ignore for first stage of 4.0 " +
+ "https://github.com/apache/datafusion-comet/issues/1948")) {
val items_partitions = Array(identity("price"), identity("id"))
createTable(items, itemsColumns, items_partitions)
sql(s"INSERT INTO testcat.ns.$items VALUES " +
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
index f62e092138a..c0404bfe85e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
@@ -21,7 +21,7 @@ package org.apache.spark.sql.connector
import java.sql.Date
import java.util.Collections
-import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, Row}
+import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, IgnoreCometSuite, Row}
import org.apache.spark.sql.catalyst.expressions.{ApplyFunctionExpression, Cast, Literal}
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.physical
@@ -45,7 +45,8 @@ import org.apache.spark.sql.util.QueryExecutionListener
import org.apache.spark.tags.SlowSQLTest
@SlowSQLTest
-class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase {
+class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
+ with IgnoreCometSuite {
import testImplicits._
before {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
index 04d33ecd3d5..450df347297 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
@@ -31,7 +31,7 @@ import org.mockito.Mockito.{mock, spy, when}
import org.scalatest.time.SpanSugar._
import org.apache.spark._
-import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, Encoder, KryoData, QueryTest, Row, SaveMode}
+import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, Encoder, IgnoreComet, KryoData, QueryTest, Row, SaveMode}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.{NamedParameter, UnresolvedGenerator}
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
@@ -267,7 +267,8 @@ class QueryExecutionErrorsSuite
}
test("INCONSISTENT_BEHAVIOR_CROSS_VERSION: " +
- "compatibility with Spark 2.4/3.2 in reading/writing dates") {
+ "compatibility with Spark 2.4/3.2 in reading/writing dates",
+ IgnoreComet("Comet doesn't completely support datetime rebase mode yet")) {
// Fail to read ancient datetime values.
withSQLConf(SQLConf.PARQUET_REBASE_MODE_IN_READ.key -> EXCEPTION.toString) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
index 418ca3430bb..eb8267192f8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
@@ -23,7 +23,7 @@ import scala.util.Random
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkConf
-import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.{DataFrame, IgnoreComet, QueryTest}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
import org.apache.spark.sql.internal.SQLConf
@@ -195,7 +195,7 @@ class DataSourceV2ScanExecRedactionSuite extends DataSourceScanRedactionTest {
}
}
- test("FileScan description") {
+ test("FileScan description", IgnoreComet("Comet doesn't use BatchScan")) {
Seq("json", "orc", "parquet").foreach { format =>
withTempPath { path =>
val dir = path.getCanonicalPath
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/InsertSortForLimitAndOffsetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/InsertSortForLimitAndOffsetSuite.scala
index d1b11a74cf3..08087c80201 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/InsertSortForLimitAndOffsetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/InsertSortForLimitAndOffsetSuite.scala
@@ -17,8 +17,9 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.{Dataset, QueryTest}
+import org.apache.spark.sql.{Dataset, IgnoreComet, QueryTest}
import org.apache.spark.sql.IntegratedUDFTestUtils._
+import org.apache.spark.sql.comet.CometCollectLimitExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions.rand
import org.apache.spark.sql.internal.SQLConf
@@ -39,7 +40,7 @@ class InsertSortForLimitAndOffsetSuite extends QueryTest
private def assertHasCollectLimitExec(plan: SparkPlan): Unit = {
assert(find(plan) {
- case _: CollectLimitExec => true
+ case _: CollectLimitExec | _: CometCollectLimitExec => true
case _ => false
}.isDefined)
}
@@ -77,7 +78,9 @@ class InsertSortForLimitAndOffsetSuite extends QueryTest
assert(!hasLocalSort(physicalPlan))
}
- test("root LIMIT preserves data ordering with CollectLimitExec") {
+ test("root LIMIT preserves data ordering with CollectLimitExec",
+ IgnoreComet("TODO: ignore for first stage of 4.0 " +
+ "https://github.com/apache/datafusion-comet/issues/1948")) {
withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "1") {
val df = spark.range(10).orderBy($"id" % 8).limit(2)
df.collect()
@@ -88,7 +91,9 @@ class InsertSortForLimitAndOffsetSuite extends QueryTest
}
}
- test("middle LIMIT preserves data ordering with the extra sort") {
+ test("middle LIMIT preserves data ordering with the extra sort",
+ IgnoreComet("TODO: ignore for first stage of 4.0 " +
+ "https://github.com/apache/datafusion-comet/issues/1948")) {
withSQLConf(
SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "1",
// To trigger the bug, we have to disable the coalescing optimization. Otherwise we use only
@@ -117,7 +122,9 @@ class InsertSortForLimitAndOffsetSuite extends QueryTest
assert(!hasLocalSort(physicalPlan))
}
- test("middle OFFSET preserves data ordering with the extra sort") {
+ test("middle OFFSET preserves data ordering with the extra sort",
+ IgnoreComet("TODO: ignore for first stage of 4.0 " +
+ "https://github.com/apache/datafusion-comet/issues/1948")) {
val df = 1.to(10).map(v => v -> v).toDF("c1", "c2").orderBy($"c1" % 8)
verifySortAdded(df.offset(2))
verifySortAdded(df.filter($"c2" > rand()).offset(2))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
index 743ec41dbe7..9f30d6c8e04 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
@@ -53,6 +53,10 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite with DisableAdaptiv
case ColumnarToRowExec(i: InputAdapter) => isScanPlanTree(i.child)
case p: ProjectExec => isScanPlanTree(p.child)
case f: FilterExec => isScanPlanTree(f.child)
+ // Comet produces scan plan tree like:
+ // ColumnarToRow
+ // +- ReusedExchange
+ case _: ReusedExchangeExec => false
case _: LeafExecNode => true
case _ => false
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 1400ee25f43..5b016c3f9c5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, DataFrame, Row}
+import org.apache.spark.sql.{execution, DataFrame, IgnoreCometSuite, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
@@ -36,7 +36,9 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
-class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
+// Ignore this suite when Comet is enabled. This suite tests the Spark planner and Comet planner
+// comes out with too many difference. Simply ignoring this suite for now.
+class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper with IgnoreCometSuite {
import testImplicits._
setupTestData()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
index 47d5ff67b84..8dc8f65d4b1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
@@ -20,7 +20,7 @@ import scala.collection.mutable
import scala.io.Source
import scala.util.Try
-import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, FastOperator}
+import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, FastOperator, IgnoreComet}
import org.apache.spark.sql.catalyst.{QueryPlanningTracker, QueryPlanningTrackerCallback, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{CurrentNamespace, UnresolvedFunction, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions.{Alias, UnsafeRow}
@@ -400,7 +400,7 @@ class QueryExecutionSuite extends SharedSparkSession {
}
}
- test("SPARK-47289: extended explain info") {
+ test("SPARK-47289: extended explain info", IgnoreComet("Comet plan extended info is different")) {
val concat = new PlanStringConcat()
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala
index b5bac8079c4..9420dbdb936 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala
@@ -17,7 +17,10 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.{DataFrame, QueryTest, Row}
+import org.apache.comet.CometConf
+
+import org.apache.spark.sql.{DataFrame, IgnoreComet, QueryTest, Row}
+import org.apache.spark.sql.comet.CometProjectExec
import org.apache.spark.sql.connector.SimpleWritableDataSource
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.internal.SQLConf
@@ -34,7 +37,10 @@ abstract class RemoveRedundantProjectsSuiteBase
private def assertProjectExecCount(df: DataFrame, expected: Int): Unit = {
withClue(df.queryExecution) {
val plan = df.queryExecution.executedPlan
- val actual = collectWithSubqueries(plan) { case p: ProjectExec => p }.size
+ val actual = collectWithSubqueries(plan) {
+ case p: ProjectExec => p
+ case p: CometProjectExec => p
+ }.size
assert(actual == expected)
}
}
@@ -112,7 +118,8 @@ abstract class RemoveRedundantProjectsSuiteBase
assertProjectExec(query, 1, 3)
}
- test("join with ordering requirement") {
+ test("join with ordering requirement",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
val query = "select * from (select key, a, c, b from testView) as t1 join " +
"(select key, a, b, c from testView) as t2 on t1.key = t2.key where t2.a > 50"
assertProjectExec(query, 2, 2)
@@ -134,12 +141,25 @@ abstract class RemoveRedundantProjectsSuiteBase
val df = data.selectExpr("a", "b", "key", "explode(array(key, a, b)) as d").filter("d > 0")
df.collect()
val plan = df.queryExecution.executedPlan
- val numProjects = collectWithSubqueries(plan) { case p: ProjectExec => p }.length
+ val numProjects = collectWithSubqueries(plan) {
+ case p: ProjectExec => p
+ case p: CometProjectExec => p
+ }.length
// Create a new plan that reverse the GenerateExec output and add a new ProjectExec between
// GenerateExec and its child. This is to test if the ProjectExec is removed, the output of
// the query will be incorrect.
- val newPlan = stripAQEPlan(plan) transform {
+
+ // Comet-specific change to get original Spark plan before applying
+ // a transformation to add a new ProjectExec
+ var sparkPlan: SparkPlan = null
+ withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") {
+ val df = data.selectExpr("a", "b", "key", "explode(array(key, a, b)) as d").filter("d > 0")
+ df.collect()
+ sparkPlan = df.queryExecution.executedPlan
+ }
+
+ val newPlan = stripAQEPlan(sparkPlan) transform {
case g @ GenerateExec(_, requiredChildOutput, _, _, child) =>
g.copy(requiredChildOutput = requiredChildOutput.reverse,
child = ProjectExec(requiredChildOutput.reverse, child))
@@ -151,6 +171,7 @@ abstract class RemoveRedundantProjectsSuiteBase
// The manually added ProjectExec node shouldn't be removed.
assert(collectWithSubqueries(newExecutedPlan) {
case p: ProjectExec => p
+ case p: CometProjectExec => p
}.size == numProjects + 1)
// Check the original plan's output and the new plan's output are the same.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
index 005e764cc30..92ec088efab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.sql.catalyst.plans.physical.{RangePartitioning, UnknownPartitioning}
+import org.apache.spark.sql.comet.CometSortExec
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.execution.joins.ShuffledJoin
import org.apache.spark.sql.internal.SQLConf
@@ -33,7 +34,7 @@ abstract class RemoveRedundantSortsSuiteBase
private def checkNumSorts(df: DataFrame, count: Int): Unit = {
val plan = df.queryExecution.executedPlan
- assert(collectWithSubqueries(plan) { case s: SortExec => s }.length == count)
+ assert(collectWithSubqueries(plan) { case _: SortExec | _: CometSortExec => 1 }.length == count)
}
private def checkSorts(query: String, enabledCount: Int, disabledCount: Int): Unit = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala
index 47679ed7865..9ffbaecb98e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.comet.CometHashAggregateExec
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.internal.SQLConf
@@ -31,7 +32,7 @@ abstract class ReplaceHashWithSortAggSuiteBase
private def checkNumAggs(df: DataFrame, hashAggCount: Int, sortAggCount: Int): Unit = {
val plan = df.queryExecution.executedPlan
assert(collectWithSubqueries(plan) {
- case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec) => s
+ case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec | _: CometHashAggregateExec ) => s
}.length == hashAggCount)
assert(collectWithSubqueries(plan) { case s: SortAggregateExec => s }.length == sortAggCount)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
index aed11badb71..1a365b5aacf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.Deduplicate
+import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometNativeColumnarToRowExec}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -134,7 +135,11 @@ class SparkPlanSuite extends QueryTest with SharedSparkSession {
spark.range(1).write.parquet(path.getAbsolutePath)
val df = spark.read.parquet(path.getAbsolutePath)
val columnarToRowExec =
- df.queryExecution.executedPlan.collectFirst { case p: ColumnarToRowExec => p }.get
+ df.queryExecution.executedPlan.collectFirst {
+ case p: ColumnarToRowExec => p
+ case p: CometColumnarToRowExec => p
+ case p: CometNativeColumnarToRowExec => p
+ }.get
try {
spark.range(1).foreach { _ =>
columnarToRowExec.canonicalized
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index a3cfdc5a240..3793b6191bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -19,9 +19,10 @@ package org.apache.spark.sql.execution
import org.apache.spark.SparkException
import org.apache.spark.rdd.MapPartitionsWithEvaluatorRDD
-import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode}
+import org.apache.spark.sql.{Dataset, IgnoreCometSuite, QueryTest, Row, SaveMode}
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAndComment, CodeGenerator}
+import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometHashJoinExec, CometSortExec, CometSortMergeJoinExec}
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
@@ -33,7 +34,7 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
// Disable AQE because the WholeStageCodegenExec is added when running QueryStageExec
class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
- with DisableAdaptiveExecutionSuite {
+ with DisableAdaptiveExecutionSuite with IgnoreCometSuite {
import testImplicits._
@@ -172,6 +173,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val oneJoinDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2")
assert(oneJoinDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
+ case _: CometHashJoinExec => true
}.size === 1)
checkAnswer(oneJoinDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4)))
@@ -180,6 +182,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
.join(df3.hint("SHUFFLE_HASH"), $"k1" === $"k3")
assert(twoJoinsDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
+ case _: CometHashJoinExec => true
}.size === 2)
checkAnswer(twoJoinsDF,
Seq(Row(0, 0, 0), Row(1, 1, 1), Row(2, 2, 2), Row(3, 3, 3), Row(4, 4, 4)))
@@ -206,6 +209,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(joinUniqueDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(joinUniqueDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4),
Row(null, 5), Row(null, 6), Row(null, 7), Row(null, 8), Row(null, 9)))
@@ -216,6 +221,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(joinNonUniqueDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(joinNonUniqueDF, Seq(Row(0, 0), Row(0, 3), Row(0, 6), Row(0, 9), Row(1, 1),
Row(1, 4), Row(1, 7), Row(2, 2), Row(2, 5), Row(2, 8), Row(3, null), Row(4, null)))
@@ -226,6 +233,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(joinWithNonEquiDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(joinWithNonEquiDF, Seq(Row(0, 0), Row(0, 6), Row(0, 9), Row(1, 1),
Row(1, 7), Row(2, 2), Row(2, 8), Row(3, null), Row(4, null), Row(null, 3), Row(null, 4),
@@ -237,6 +246,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(twoJoinsDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 2)
checkAnswer(twoJoinsDF,
Seq(Row(0, 0, 0), Row(1, 1, null), Row(2, 2, 2), Row(3, 3, null), Row(4, 4, null),
@@ -258,6 +269,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(rightJoinUniqueDf.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(rightJoinUniqueDf, Seq(Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4),
Row(null, 5), Row(null, 6), Row(null, 7), Row(null, 8), Row(null, 9),
@@ -269,6 +282,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(leftJoinUniqueDf.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(leftJoinUniqueDf, Seq(Row(0, null), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4)))
assert(leftJoinUniqueDf.count() === 5)
@@ -278,6 +293,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(rightJoinNonUniqueDf.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(rightJoinNonUniqueDf, Seq(Row(0, 3), Row(0, 6), Row(0, 9), Row(1, 1),
Row(1, 4), Row(1, 7), Row(1, 10), Row(2, 2), Row(2, 5), Row(2, 8)))
@@ -287,6 +304,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(leftJoinNonUniqueDf.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(leftJoinNonUniqueDf, Seq(Row(0, 3), Row(0, 6), Row(0, 9), Row(1, 1),
Row(1, 4), Row(1, 7), Row(1, 10), Row(2, 2), Row(2, 5), Row(2, 8), Row(3, null),
@@ -298,6 +317,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(rightJoinWithNonEquiDf.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(rightJoinWithNonEquiDf, Seq(Row(0, 6), Row(0, 9), Row(1, 1), Row(1, 7),
Row(1, 10), Row(2, 2), Row(2, 8), Row(null, 3), Row(null, 4), Row(null, 5)))
@@ -308,6 +329,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(leftJoinWithNonEquiDf.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(leftJoinWithNonEquiDf, Seq(Row(0, 6), Row(0, 9), Row(1, 1), Row(1, 7),
Row(1, 10), Row(2, 2), Row(2, 8), Row(3, null), Row(4, null)))
@@ -318,6 +341,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(twoRightJoinsDf.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 2)
checkAnswer(twoRightJoinsDf, Seq(Row(2, 2, 2), Row(3, 3, 3), Row(4, 4, 4)))
@@ -327,6 +352,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(twoLeftJoinsDf.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 2)
checkAnswer(twoLeftJoinsDf,
Seq(Row(0, null, null), Row(1, 1, null), Row(2, 2, 2), Row(3, 3, 3), Row(4, 4, 4)))
@@ -343,6 +370,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val oneLeftOuterJoinDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_outer")
assert(oneLeftOuterJoinDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : SortMergeJoinExec) => true
+ case _: CometSortMergeJoinExec => true
}.size === 1)
checkAnswer(oneLeftOuterJoinDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, null),
Row(5, null), Row(6, null), Row(7, null), Row(8, null), Row(9, null)))
@@ -351,6 +379,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val oneRightOuterJoinDF = df2.join(df3.hint("SHUFFLE_MERGE"), $"k2" === $"k3", "right_outer")
assert(oneRightOuterJoinDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : SortMergeJoinExec) => true
+ case _: CometSortMergeJoinExec => true
}.size === 1)
checkAnswer(oneRightOuterJoinDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(null, 4),
Row(null, 5)))
@@ -360,6 +389,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
.join(df1.hint("SHUFFLE_MERGE"), $"k3" === $"k1", "right_outer")
assert(twoJoinsDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : SortMergeJoinExec) => true
+ case _: CometSortMergeJoinExec => true
}.size === 2)
checkAnswer(twoJoinsDF,
Seq(Row(0, 0, 0), Row(1, 1, 1), Row(2, 2, 2), Row(3, 3, 3), Row(4, null, 4), Row(5, null, 5),
@@ -375,6 +405,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val oneJoinDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_semi")
assert(oneJoinDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) => true
+ case _: CometSortMergeJoinExec => true
}.size === 1)
checkAnswer(oneJoinDF, Seq(Row(0), Row(1), Row(2), Row(3)))
@@ -382,8 +413,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val twoJoinsDF = df3.join(df2.hint("SHUFFLE_MERGE"), $"k3" === $"k2", "left_semi")
.join(df1.hint("SHUFFLE_MERGE"), $"k3" === $"k1", "left_semi")
assert(twoJoinsDF.queryExecution.executedPlan.collect {
- case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) |
- WholeStageCodegenExec(_ : SortMergeJoinExec) => true
+ case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
}.size === 2)
checkAnswer(twoJoinsDF, Seq(Row(0), Row(1), Row(2), Row(3)))
}
@@ -397,6 +428,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val oneJoinDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_anti")
assert(oneJoinDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) => true
+ case _: CometSortMergeJoinExec => true
}.size === 1)
checkAnswer(oneJoinDF, Seq(Row(4), Row(5), Row(6), Row(7), Row(8), Row(9)))
@@ -404,8 +436,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val twoJoinsDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_anti")
.join(df3.hint("SHUFFLE_MERGE"), $"k1" === $"k3", "left_anti")
assert(twoJoinsDF.queryExecution.executedPlan.collect {
- case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) |
- WholeStageCodegenExec(_ : SortMergeJoinExec) => true
+ case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
}.size === 2)
checkAnswer(twoJoinsDF, Seq(Row(6), Row(7), Row(8), Row(9)))
}
@@ -538,7 +570,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val plan = df.queryExecution.executedPlan
assert(plan.exists(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec]))
+ p.asInstanceOf[WholeStageCodegenExec].collect {
+ case _: SortExec => true
+ case _: CometSortExec => true
+ }.nonEmpty))
assert(df.collect() === Array(Row(1), Row(2), Row(3)))
}
@@ -718,7 +753,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
.write.mode(SaveMode.Overwrite).parquet(path)
withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255",
- SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> "true") {
+ SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> "true",
+ // Disable Comet native execution because this checks wholestage codegen.
+ "spark.comet.exec.enabled" -> "false") {
val projection = Seq.tabulate(columnNum)(i => s"c$i + c$i as newC$i")
val df = spark.read.parquet(path).selectExpr(projection: _*)
@@ -815,6 +852,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(distinctWithId.queryExecution.executedPlan.exists {
case WholeStageCodegenExec(
ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _, _))) => true
+ case WholeStageCodegenExec(
+ ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: CometColumnarToRowExec, _, _))) =>
+ true
case _ => false
})
checkAnswer(distinctWithId, Seq(Row(1, 0), Row(1, 0)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 272be70f9fe..06957694002 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -28,12 +28,14 @@ import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.sql.{DataFrame, Dataset, QueryTest, Row, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset, IgnoreComet, QueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.classic.Strategy
+import org.apache.spark.sql.comet._
+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.columnar.{InMemoryTableScanExec, InMemoryTableScanLike}
@@ -119,6 +121,7 @@ class AdaptiveQueryExecSuite
private def findTopLevelBroadcastHashJoin(plan: SparkPlan): Seq[BroadcastHashJoinExec] = {
collect(plan) {
case j: BroadcastHashJoinExec => j
+ case j: CometBroadcastHashJoinExec => j.originalPlan.asInstanceOf[BroadcastHashJoinExec]
}
}
@@ -131,30 +134,39 @@ class AdaptiveQueryExecSuite
private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = {
collect(plan) {
case j: SortMergeJoinExec => j
+ case j: CometSortMergeJoinExec =>
+ assert(j.originalPlan.isInstanceOf[SortMergeJoinExec])
+ j.originalPlan.asInstanceOf[SortMergeJoinExec]
}
}
private def findTopLevelShuffledHashJoin(plan: SparkPlan): Seq[ShuffledHashJoinExec] = {
collect(plan) {
case j: ShuffledHashJoinExec => j
+ case j: CometHashJoinExec => j.originalPlan.asInstanceOf[ShuffledHashJoinExec]
}
}
private def findTopLevelBaseJoin(plan: SparkPlan): Seq[BaseJoinExec] = {
collect(plan) {
case j: BaseJoinExec => j
+ case c: CometHashJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec]
+ case c: CometSortMergeJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec]
+ case c: CometBroadcastHashJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec]
}
}
private def findTopLevelSort(plan: SparkPlan): Seq[SortExec] = {
collect(plan) {
case s: SortExec => s
+ case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec]
}
}
private def findTopLevelAggregate(plan: SparkPlan): Seq[BaseAggregateExec] = {
collect(plan) {
case agg: BaseAggregateExec => agg
+ case agg: CometHashAggregateExec => agg.originalPlan.asInstanceOf[BaseAggregateExec]
}
}
@@ -204,6 +216,7 @@ class AdaptiveQueryExecSuite
val parts = rdd.partitions
assert(parts.forall(rdd.preferredLocations(_).nonEmpty))
}
+
assert(numShuffles === (numLocalReads.length + numShufflesWithoutLocalRead))
}
@@ -212,7 +225,7 @@ class AdaptiveQueryExecSuite
val plan = df.queryExecution.executedPlan
assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}
assert(shuffle.size == 1)
assert(shuffle(0).outputPartitioning.numPartitions == numPartition)
@@ -228,7 +241,8 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
- checkNumLocalShuffleReads(adaptivePlan)
+ // Comet shuffle changes shuffle metrics
+ // checkNumLocalShuffleReads(adaptivePlan)
}
}
@@ -255,7 +269,8 @@ class AdaptiveQueryExecSuite
}
}
- test("Reuse the parallelism of coalesced shuffle in local shuffle read") {
+ test("Reuse the parallelism of coalesced shuffle in local shuffle read",
+ IgnoreComet("Comet shuffle changes shuffle partition size")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
@@ -287,7 +302,8 @@ class AdaptiveQueryExecSuite
}
}
- test("Reuse the default parallelism in local shuffle read") {
+ test("Reuse the default parallelism in local shuffle read",
+ IgnoreComet("Comet shuffle changes shuffle partition size")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
@@ -301,7 +317,8 @@ class AdaptiveQueryExecSuite
val localReads = collect(adaptivePlan) {
case read: AQEShuffleReadExec if read.isLocalRead => read
}
- assert(localReads.length == 2)
+ // Comet shuffle changes shuffle metrics
+ assert(localReads.length == 1)
val localShuffleRDD0 = localReads(0).execute().asInstanceOf[ShuffledRowRDD]
val localShuffleRDD1 = localReads(1).execute().asInstanceOf[ShuffledRowRDD]
// the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2
@@ -326,7 +343,9 @@ class AdaptiveQueryExecSuite
.groupBy($"a").count()
checkAnswer(testDf, Seq())
val plan = testDf.queryExecution.executedPlan
- assert(find(plan)(_.isInstanceOf[SortMergeJoinExec]).isDefined)
+ assert(find(plan) { case p =>
+ p.isInstanceOf[SortMergeJoinExec] || p.isInstanceOf[CometSortMergeJoinExec]
+ }.isDefined)
val coalescedReads = collect(plan) {
case r: AQEShuffleReadExec => r
}
@@ -340,7 +359,9 @@ class AdaptiveQueryExecSuite
.groupBy($"a").count()
checkAnswer(testDf, Seq())
val plan = testDf.queryExecution.executedPlan
- assert(find(plan)(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
+ assert(find(plan) { case p =>
+ p.isInstanceOf[BroadcastHashJoinExec] || p.isInstanceOf[CometBroadcastHashJoinExec]
+ }.isDefined)
val coalescedReads = collect(plan) {
case r: AQEShuffleReadExec => r
}
@@ -350,7 +371,7 @@ class AdaptiveQueryExecSuite
}
}
- test("Scalar subquery") {
+ test("Scalar subquery", IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
@@ -365,7 +386,7 @@ class AdaptiveQueryExecSuite
}
}
- test("Scalar subquery in later stages") {
+ test("Scalar subquery in later stages", IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
@@ -381,7 +402,7 @@ class AdaptiveQueryExecSuite
}
}
- test("multiple joins") {
+ test("multiple joins", IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
@@ -426,7 +447,7 @@ class AdaptiveQueryExecSuite
}
}
- test("multiple joins with aggregate") {
+ test("multiple joins with aggregate", IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
@@ -471,7 +492,7 @@ class AdaptiveQueryExecSuite
}
}
- test("multiple joins with aggregate 2") {
+ test("multiple joins with aggregate 2", IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
@@ -517,7 +538,7 @@ class AdaptiveQueryExecSuite
}
}
- test("Exchange reuse") {
+ test("Exchange reuse", IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
@@ -536,7 +557,7 @@ class AdaptiveQueryExecSuite
}
}
- test("Exchange reuse with subqueries") {
+ test("Exchange reuse with subqueries", IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
@@ -567,7 +588,9 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
- checkNumLocalShuffleReads(adaptivePlan)
+ // Comet shuffle changes shuffle metrics,
+ // so we can't check the number of local shuffle reads.
+ // checkNumLocalShuffleReads(adaptivePlan)
// Even with local shuffle read, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.nonEmpty)
@@ -588,7 +611,9 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
- checkNumLocalShuffleReads(adaptivePlan)
+ // Comet shuffle changes shuffle metrics,
+ // so we can't check the number of local shuffle reads.
+ // checkNumLocalShuffleReads(adaptivePlan)
// Even with local shuffle read, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.isEmpty)
@@ -597,7 +622,8 @@ class AdaptiveQueryExecSuite
}
}
- test("Broadcast exchange reuse across subqueries") {
+ test("Broadcast exchange reuse across subqueries",
+ IgnoreComet("Comet shuffle changes shuffle metrics")) {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000",
@@ -692,7 +718,8 @@ class AdaptiveQueryExecSuite
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
// There is still a SMJ, and its two shuffles can't apply local read.
- checkNumLocalShuffleReads(adaptivePlan, 2)
+ // Comet shuffle changes shuffle metrics
+ // checkNumLocalShuffleReads(adaptivePlan, 2)
}
}
@@ -814,7 +841,8 @@ class AdaptiveQueryExecSuite
}
}
- test("SPARK-29544: adaptive skew join with different join types") {
+ test("SPARK-29544: adaptive skew join with different join types",
+ IgnoreComet("Comet shuffle has different partition metrics")) {
Seq("SHUFFLE_MERGE", "SHUFFLE_HASH").foreach { joinHint =>
def getJoinNode(plan: SparkPlan): Seq[ShuffledJoin] = if (joinHint == "SHUFFLE_MERGE") {
findTopLevelSortMergeJoin(plan)
@@ -1087,7 +1115,8 @@ class AdaptiveQueryExecSuite
}
}
- test("metrics of the shuffle read") {
+ test("metrics of the shuffle read",
+ IgnoreComet("Comet shuffle changes the metrics")) {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT key FROM testData GROUP BY key")
@@ -1721,7 +1750,7 @@ class AdaptiveQueryExecSuite
val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT id FROM v1 GROUP BY id DISTRIBUTE BY id")
assert(collect(adaptivePlan) {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}.length == 1)
}
}
@@ -1801,7 +1830,8 @@ class AdaptiveQueryExecSuite
}
}
- test("SPARK-33551: Do not use AQE shuffle read for repartition") {
+ test("SPARK-33551: Do not use AQE shuffle read for repartition",
+ IgnoreComet("Comet shuffle changes partition size")) {
def hasRepartitionShuffle(plan: SparkPlan): Boolean = {
find(plan) {
case s: ShuffleExchangeLike =>
@@ -1986,6 +2016,9 @@ class AdaptiveQueryExecSuite
def checkNoCoalescePartitions(ds: Dataset[Row], origin: ShuffleOrigin): Unit = {
assert(collect(ds.queryExecution.executedPlan) {
case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s
+ case c: CometShuffleExchangeExec
+ if c.originalPlan.shuffleOrigin == origin &&
+ c.originalPlan.numPartitions == 2 => c
}.size == 1)
ds.collect()
val plan = ds.queryExecution.executedPlan
@@ -1994,6 +2027,9 @@ class AdaptiveQueryExecSuite
}.isEmpty)
assert(collect(plan) {
case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s
+ case c: CometShuffleExchangeExec
+ if c.originalPlan.shuffleOrigin == origin &&
+ c.originalPlan.numPartitions == 2 => c
}.size == 1)
checkAnswer(ds, testData)
}
@@ -2150,7 +2186,8 @@ class AdaptiveQueryExecSuite
}
}
- test("SPARK-35264: Support AQE side shuffled hash join formula") {
+ test("SPARK-35264: Support AQE side shuffled hash join formula",
+ IgnoreComet("Comet shuffle changes the partition size")) {
withTempView("t1", "t2") {
def checkJoinStrategy(shouldShuffleHashJoin: Boolean): Unit = {
Seq("100", "100000").foreach { size =>
@@ -2236,7 +2273,8 @@ class AdaptiveQueryExecSuite
}
}
- test("SPARK-35725: Support optimize skewed partitions in RebalancePartitions") {
+ test("SPARK-35725: Support optimize skewed partitions in RebalancePartitions",
+ IgnoreComet("Comet shuffle changes shuffle metrics")) {
withTempView("v") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
@@ -2335,7 +2373,7 @@ class AdaptiveQueryExecSuite
runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " +
s"JOIN skewData2 ON key1 = key2 GROUP BY key1")
val shuffles1 = collect(adaptive1) {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}
assert(shuffles1.size == 3)
// shuffles1.head is the top-level shuffle under the Aggregate operator
@@ -2348,7 +2386,7 @@ class AdaptiveQueryExecSuite
runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " +
s"JOIN skewData2 ON key1 = key2")
val shuffles2 = collect(adaptive2) {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}
if (hasRequiredDistribution) {
assert(shuffles2.size == 3)
@@ -2382,7 +2420,8 @@ class AdaptiveQueryExecSuite
}
}
- test("SPARK-35794: Allow custom plugin for cost evaluator") {
+ test("SPARK-35794: Allow custom plugin for cost evaluator",
+ IgnoreComet("Comet shuffle changes shuffle metrics")) {
CostEvaluator.instantiate(
classOf[SimpleShuffleSortCostEvaluator].getCanonicalName, spark.sparkContext.getConf)
intercept[IllegalArgumentException] {
@@ -2513,7 +2552,8 @@ class AdaptiveQueryExecSuite
}
test("SPARK-48037: Fix SortShuffleWriter lacks shuffle write related metrics " +
- "resulting in potentially inaccurate data") {
+ "resulting in potentially inaccurate data",
+ IgnoreComet("too many shuffle partitions causes Java heap OOM")) {
withTable("t3") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
@@ -2548,6 +2588,7 @@ class AdaptiveQueryExecSuite
val (_, adaptive) = runAdaptiveAndVerifyResult(query)
assert(adaptive.collect {
case sort: SortExec => sort
+ case sort: CometSortExec => sort
}.size == 1)
val read = collect(adaptive) {
case read: AQEShuffleReadExec => read
@@ -2565,7 +2606,8 @@ class AdaptiveQueryExecSuite
}
}
- test("SPARK-37357: Add small partition factor for rebalance partitions") {
+ test("SPARK-37357: Add small partition factor for rebalance partitions",
+ IgnoreComet("Comet shuffle changes shuffle metrics")) {
withTempView("v") {
withSQLConf(
SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true",
@@ -2677,7 +2719,7 @@ class AdaptiveQueryExecSuite
runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " +
"JOIN skewData3 ON value2 = value3")
val shuffles1 = collect(adaptive1) {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}
assert(shuffles1.size == 4)
val smj1 = findTopLevelSortMergeJoin(adaptive1)
@@ -2688,7 +2730,7 @@ class AdaptiveQueryExecSuite
runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " +
"JOIN skewData3 ON value1 = value3")
val shuffles2 = collect(adaptive2) {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}
assert(shuffles2.size == 4)
val smj2 = findTopLevelSortMergeJoin(adaptive2)
@@ -2946,6 +2988,7 @@ class AdaptiveQueryExecSuite
}.size == (if (firstAccess) 1 else 0))
assert(collect(initialExecutedPlan) {
case s: SortExec => s
+ case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec]
}.size == (if (firstAccess) 2 else 0))
assert(collect(initialExecutedPlan) {
case i: InMemoryTableScanLike => i
@@ -2958,6 +3001,7 @@ class AdaptiveQueryExecSuite
}.isEmpty)
assert(collect(finalExecutedPlan) {
case s: SortExec => s
+ case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec]
}.isEmpty)
assert(collect(initialExecutedPlan) {
case i: InMemoryTableScanLike => i
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index 0a0b23d1e60..5685926250f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Concat
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.Expand
import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.comet.CometScanExec
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions._
@@ -868,6 +869,7 @@ abstract class SchemaPruningSuite
val fileSourceScanSchemata =
collect(df.queryExecution.executedPlan) {
case scan: FileSourceScanExec => scan.requiredSchema
+ case scan: CometScanExec => scan.requiredSchema
}
assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size,
s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " +
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
index 80d771428d9..9327dca6c21 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
@@ -17,9 +17,10 @@
package org.apache.spark.sql.execution.datasources
-import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.{IgnoreComet, QueryTest, Row}
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, NullsFirst, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sort}
+import org.apache.spark.sql.comet.CometSortExec
import org.apache.spark.sql.execution.{QueryExecution, SortExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -226,6 +227,7 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write
// assert the outer most sort in the executed plan
assert(plan.collectFirst {
case s: SortExec => s
+ case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec]
}.exists {
case SortExec(Seq(
SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _),
@@ -273,6 +275,7 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write
// assert the outer most sort in the executed plan
assert(plan.collectFirst {
case s: SortExec => s
+ case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec]
}.exists {
case SortExec(Seq(
SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _),
@@ -306,7 +309,8 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write
}
}
- test("v1 write with AQE changing SMJ to BHJ") {
+ test("v1 write with AQE changing SMJ to BHJ",
+ IgnoreComet("TODO: Comet SMJ to BHJ by AQE")) {
withPlannedWrite { enabled =>
withTable("t") {
sql(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
index 62f2f2cb10a..feef4bb2928 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
@@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, GlobFilter, Path}
import org.mockito.Mockito.{mock, when}
import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
-import org.apache.spark.sql.{DataFrame, QueryTest, Row}
+import org.apache.spark.sql.{DataFrame, IgnoreCometSuite, QueryTest, Row}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.functions.col
@@ -38,7 +38,9 @@ import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
-class BinaryFileFormatSuite extends QueryTest with SharedSparkSession {
+// For some reason this suite is flaky w/ or w/o Comet when running in Github workflow.
+// Since it isn't related to Comet, we disable it for now.
+class BinaryFileFormatSuite extends QueryTest with SharedSparkSession with IgnoreCometSuite {
import BinaryFileFormat._
private var testDir: String = _
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
index cd6f41b4ef4..4b6a17344bc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
@@ -28,7 +28,7 @@ import org.apache.parquet.hadoop.ParquetOutputFormat
import org.apache.spark.TestUtils
import org.apache.spark.memory.MemoryMode
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{IgnoreComet, Row}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -201,7 +201,8 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSparkSess
}
}
- test("parquet v2 pages - rle encoding for boolean value columns") {
+ test("parquet v2 pages - rle encoding for boolean value columns",
+ IgnoreComet("Comet doesn't support RLE encoding yet")) {
val extraOptions = Map[String, String](
ParquetOutputFormat.WRITER_VERSION -> ParquetProperties.WriterVersion.PARQUET_2_0.toString
)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 6080a5e8e4b..9aa8f49a62b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -1102,7 +1102,11 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
// When a filter is pushed to Parquet, Parquet can apply it to every row.
// So, we can check the number of rows returned from the Parquet
// to make sure our filter pushdown work.
- assert(stripSparkFilter(df).count() == 1)
+ // Similar to Spark's vectorized reader, Comet doesn't do row-level filtering but relies
+ // on Spark to apply the data filters after columnar batches are returned
+ if (!isCometEnabled) {
+ assert(stripSparkFilter(df).count() == 1)
+ }
}
}
}
@@ -1505,7 +1509,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
}
}
- test("Filters should be pushed down for vectorized Parquet reader at row group level") {
+ test("Filters should be pushed down for vectorized Parquet reader at row group level",
+ IgnoreCometNativeScan("Native scans do not support the tested accumulator")) {
import testImplicits._
withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true",
@@ -1587,7 +1592,11 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
// than the total length but should not be a single record.
// Note that, if record level filtering is enabled, it should be a single record.
// If no filter is pushed down to Parquet, it should be the total length of data.
- assert(actual > 1 && actual < data.length)
+ // Only enable Comet test iff it's scan only, since with native execution
+ // `stripSparkFilter` can't remove the native filter
+ if (!isCometEnabled || isCometScanOnly) {
+ assert(actual > 1 && actual < data.length)
+ }
}
}
}
@@ -1614,7 +1623,11 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
// than the total length but should not be a single record.
// Note that, if record level filtering is enabled, it should be a single record.
// If no filter is pushed down to Parquet, it should be the total length of data.
- assert(actual > 1 && actual < data.length)
+ // Only enable Comet test iff it's scan only, since with native execution
+ // `stripSparkFilter` can't remove the native filter
+ if (!isCometEnabled || isCometScanOnly) {
+ assert(actual > 1 && actual < data.length)
+ }
}
}
}
@@ -1706,7 +1719,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
(attr, value) => sources.StringContains(attr, value))
}
- test("filter pushdown - StringPredicate") {
+ test("filter pushdown - StringPredicate", IgnoreCometNativeScan("cannot be pushed down")) {
import testImplicits._
// keep() should take effect on StartsWith/EndsWith/Contains
Seq(
@@ -1750,7 +1763,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
}
}
- test("SPARK-17091: Convert IN predicate to Parquet filter push-down") {
+ test("SPARK-17091: Convert IN predicate to Parquet filter push-down",
+ IgnoreCometNativeScan("Comet has different push-down behavior")) {
val schema = StructType(Seq(
StructField("a", IntegerType, nullable = false)
))
@@ -1993,7 +2007,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
}
}
- test("Support Parquet column index") {
+ test("Support Parquet column index",
+ IgnoreComet("Comet doesn't support Parquet column index yet")) {
// block 1:
// null count min max
// page-0 0 0 99
@@ -2053,7 +2068,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
}
}
- test("SPARK-34562: Bloom filter push down") {
+ test("SPARK-34562: Bloom filter push down",
+ IgnoreCometNativeScan("Native scans do not support the tested accumulator")) {
withTempPath { dir =>
val path = dir.getCanonicalPath
spark.range(100).selectExpr("id * 2 AS id")
@@ -2305,7 +2321,11 @@ class ParquetV1FilterSuite extends ParquetFilterSuite {
assert(pushedParquetFilters.exists(_.getClass === filterClass),
s"${pushedParquetFilters.map(_.getClass).toList} did not contain ${filterClass}.")
- checker(stripSparkFilter(query), expected)
+ // Similar to Spark's vectorized reader, Comet doesn't do row-level filtering but relies
+ // on Spark to apply the data filters after columnar batches are returned
+ if (!isCometEnabled) {
+ checker(stripSparkFilter(query), expected)
+ }
} else {
assert(selectedFilters.isEmpty, "There is filter pushed down")
}
@@ -2368,7 +2388,11 @@ class ParquetV2FilterSuite extends ParquetFilterSuite {
assert(pushedParquetFilters.exists(_.getClass === filterClass),
s"${pushedParquetFilters.map(_.getClass).toList} did not contain ${filterClass}.")
- checker(stripSparkFilter(query), expected)
+ // Similar to Spark's vectorized reader, Comet doesn't do row-level filtering but relies
+ // on Spark to apply the data filters after columnar batches are returned
+ if (!isCometEnabled) {
+ checker(stripSparkFilter(query), expected)
+ }
case _ => assert(false, "Can not match ParquetTable in the query.")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 4474ec1fd42..97910c4fc3a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -1344,7 +1344,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
}
}
- test("SPARK-40128 read DELTA_LENGTH_BYTE_ARRAY encoded strings") {
+ test("SPARK-40128 read DELTA_LENGTH_BYTE_ARRAY encoded strings",
+ IgnoreComet("Comet doesn't support DELTA encoding yet")) {
withAllParquetReaders {
checkAnswer(
// "fruit" column in this file is encoded using DELTA_LENGTH_BYTE_ARRAY.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index bba71f1c48d..38c60ee2584 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -996,7 +996,11 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
Seq(Some("A"), Some("A"), None).toDF().repartition(1)
.write.parquet(path.getAbsolutePath)
val df = spark.read.parquet(path.getAbsolutePath)
- checkAnswer(stripSparkFilter(df.where("NOT (value <=> 'A')")), df)
+ // Similar to Spark's vectorized reader, Comet doesn't do row-level filtering but relies
+ // on Spark to apply the data filters after columnar batches are returned
+ if (!isCometEnabled) {
+ checkAnswer(stripSparkFilter(df.where("NOT (value <=> 'A')")), df)
+ }
}
}
}
@@ -1060,7 +1064,8 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
checkAnswer(readParquet(schema2, path), df)
}
- withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false",
+ "spark.comet.enabled" -> "false") {
val schema1 = "a DECIMAL(3, 2), b DECIMAL(18, 3), c DECIMAL(37, 3)"
checkAnswer(readParquet(schema1, path), df)
val schema2 = "a DECIMAL(3, 0), b DECIMAL(18, 1), c DECIMAL(37, 1)"
@@ -1084,7 +1089,8 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
val df = sql(s"SELECT 1 a, 123456 b, ${Int.MaxValue.toLong * 10} c, CAST('1.2' AS BINARY) d")
df.write.parquet(path.toString)
- withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false",
+ "spark.comet.enabled" -> "false") {
checkAnswer(readParquet("a DECIMAL(3, 2)", path), sql("SELECT 1.00"))
checkAnswer(readParquet("a DECIMAL(11, 2)", path), sql("SELECT 1.00"))
checkAnswer(readParquet("b DECIMAL(3, 2)", path), Row(null))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala
index 30503af0fab..1491f4bc2d5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala
@@ -21,7 +21,7 @@ import java.nio.file.{Files, Paths, StandardCopyOption}
import java.sql.{Date, Timestamp}
import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkUpgradeException}
-import org.apache.spark.sql.{QueryTest, Row, SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY}
+import org.apache.spark.sql.{IgnoreCometSuite, QueryTest, Row, SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY}
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
import org.apache.spark.sql.internal.LegacyBehaviorPolicy.{CORRECTED, EXCEPTION, LEGACY}
@@ -30,9 +30,11 @@ import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType.{INT96,
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.tags.SlowSQLTest
+// Comet is disabled for this suite because it doesn't support datetime rebase mode
abstract class ParquetRebaseDatetimeSuite
extends QueryTest
with ParquetTest
+ with IgnoreCometSuite
with SharedSparkSession {
import testImplicits._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala
index 08fd8a9ecb5..24baf360234 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala
@@ -20,6 +20,7 @@ import java.io.File
import scala.jdk.CollectionConverters._
+import org.apache.comet.CometConf
import org.apache.hadoop.fs.Path
import org.apache.parquet.column.ParquetProperties._
import org.apache.parquet.hadoop.{ParquetFileReader, ParquetOutputFormat}
@@ -27,6 +28,7 @@ import org.apache.parquet.hadoop.ParquetWriter.DEFAULT_BLOCK_SIZE
import org.apache.spark.SparkException
import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.comet.{CometBatchScanExec, CometScanExec}
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
@@ -172,8 +174,31 @@ class ParquetRowIndexSuite extends QueryTest with SharedSparkSession {
testRowIndexGeneration("row index generation", conf)
}
+ private def shouldSkip(conf: RowIndexTestConf): Boolean = {
+ Set(
+ RowIndexTestConf(useFilter = true, useSmallPages = true, useSmallRowGroups = true,
+ useSmallSplits = true),
+ RowIndexTestConf(useFilter = true, useSmallPages = true, useSmallRowGroups = true),
+ RowIndexTestConf(useFilter = true, useSmallRowGroups = true, useSmallSplits = true),
+ RowIndexTestConf(useFilter = true, useSmallRowGroups = true),
+ RowIndexTestConf(useVectorizedReader = false, useFilter = true, useSmallPages = true,
+ useSmallRowGroups = true, useSmallSplits = true),
+ RowIndexTestConf(useVectorizedReader = false, useFilter = true, useSmallPages = true,
+ useSmallRowGroups = true),
+ RowIndexTestConf(useVectorizedReader = false, useFilter = true, useSmallRowGroups = true,
+ useSmallSplits = true),
+ RowIndexTestConf(useVectorizedReader = false, useFilter = true, useSmallRowGroups = true)
+ ).contains(conf)
+ }
+
private def testRowIndexGeneration(label: String, conf: RowIndexTestConf): Unit = {
test (s"$label - ${conf.desc}") {
+
+ assume(!shouldSkip(conf), s"TODO: https://github.com/apache/datafusion-comet/issues/1948 " +
+ s"Skipping failing config: ${conf.desc}")
+
+ // native_datafusion Parquet scan does not support row index generation.
+ assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION)
withSQLConf(conf.sqlConfs: _*) {
withTempPath { path =>
// Read row index using _metadata.row_index if that is supported by the file format.
@@ -245,6 +270,12 @@ class ParquetRowIndexSuite extends QueryTest with SharedSparkSession {
case f: FileSourceScanExec =>
numPartitions += f.inputRDD.partitions.length
numOutputRows += f.metrics("numOutputRows").value
+ case b: CometScanExec =>
+ numPartitions += b.inputRDD.partitions.length
+ numOutputRows += b.metrics("numOutputRows").value
+ case b: CometBatchScanExec =>
+ numPartitions += b.inputRDD.partitions.length
+ numOutputRows += b.metrics("numOutputRows").value
case _ =>
}
assert(numPartitions > 0)
@@ -303,6 +334,8 @@ class ParquetRowIndexSuite extends QueryTest with SharedSparkSession {
val conf = RowIndexTestConf(useDataSourceV2 = useDataSourceV2)
test(s"invalid row index column type - ${conf.desc}") {
+ // native_datafusion Parquet scan does not support row index generation.
+ assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION)
withSQLConf(conf.sqlConfs: _*) {
withTempPath{ path =>
val df = spark.range(0, 10, 1, 1).toDF("id")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala
index 5c0b7def039..151184bc98c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet
import org.apache.spark.SparkConf
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.comet.CometBatchScanExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.SchemaPruningSuite
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
@@ -56,6 +57,7 @@ class ParquetV2SchemaPruningSuite extends ParquetSchemaPruningSuite {
val fileSourceScanSchemata =
collect(df.queryExecution.executedPlan) {
case scan: BatchScanExec => scan.scan.asInstanceOf[ParquetScan].readDataSchema
+ case scan: CometBatchScanExec => scan.scan.asInstanceOf[ParquetScan].readDataSchema
}
assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size,
s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " +
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
index 0acb21f3e6f..3a7bb73f03c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
@@ -27,7 +27,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.apache.parquet.schema.Type._
import org.apache.spark.SparkException
-import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.{AnalysisException, IgnoreComet, Row}
import org.apache.spark.sql.catalyst.expressions.Cast.toSQLType
import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException
import org.apache.spark.sql.functions.desc
@@ -1037,7 +1037,8 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
e
}
- test("schema mismatch failure error message for parquet reader") {
+ test("schema mismatch failure error message for parquet reader",
+ IgnoreComet("Comet doesn't work with vectorizedReaderEnabled = false")) {
withTempPath { dir =>
val e = testSchemaMismatch(dir.getCanonicalPath, vectorizedReaderEnabled = false)
val expectedMessage = "Encountered error while reading file"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
index 09ed6955a51..236a4e99824 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
@@ -65,7 +65,9 @@ class ParquetTypeWideningSuite
withClue(
s"with dictionary encoding '$dictionaryEnabled' with timestamp rebase mode " +
s"'$timestampRebaseMode''") {
- withAllParquetWriters {
+ // TODO: Comet cannot read DELTA_BINARY_PACKED created by V2 writer
+ // https://github.com/apache/datafusion-comet/issues/574
+ // withAllParquetWriters {
withTempDir { dir =>
val expected =
writeParquetFiles(dir, values, fromType, dictionaryEnabled, timestampRebaseMode)
@@ -86,7 +88,7 @@ class ParquetTypeWideningSuite
}
}
}
- }
+ // }
}
}
@@ -190,7 +192,8 @@ class ParquetTypeWideningSuite
(Seq("1", "2", Short.MinValue.toString), ShortType, DoubleType),
(Seq("1", "2", Int.MinValue.toString), IntegerType, DoubleType),
(Seq("1.23", "10.34"), FloatType, DoubleType),
- (Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, TimestampNTZType)
+ // TODO: Comet cannot handle older than "1582-10-15"
+ (Seq("2020-01-01", "2020-01-02"/* , "1312-02-27" */), DateType, TimestampNTZType)
)
}
test(s"parquet widening conversion $fromType -> $toType") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala
index 458b5dfc0f4..d209f3c85bc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala
@@ -26,7 +26,7 @@ import org.apache.parquet.hadoop.util.HadoopInputFile
import org.apache.parquet.schema.{LogicalTypeAnnotation, PrimitiveType}
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
-import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.{IgnoreCometSuite, QueryTest, Row}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
import org.apache.spark.sql.test.SharedSparkSession
@@ -35,7 +35,9 @@ import org.apache.spark.unsafe.types.VariantVal
/**
* Test shredding Variant values in the Parquet reader/writer.
*/
-class ParquetVariantShreddingSuite extends QueryTest with ParquetTest with SharedSparkSession {
+class ParquetVariantShreddingSuite extends QueryTest with ParquetTest with SharedSparkSession
+ // TODO enable tests once https://github.com/apache/datafusion-comet/issues/2209 is fixed
+ with IgnoreCometSuite {
private def testWithTempDir(name: String)(block: File => Unit): Unit = test(name) {
withTempDir { dir =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
index b8f3ea3c6f3..bbd44221288 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.debug
import java.io.ByteArrayOutputStream
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.IgnoreComet
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
@@ -125,7 +126,8 @@ class DebuggingSuite extends DebuggingSuiteBase with DisableAdaptiveExecutionSui
| id LongType: {}""".stripMargin))
}
- test("SPARK-28537: DebugExec cannot debug columnar related queries") {
+ test("SPARK-28537: DebugExec cannot debug columnar related queries",
+ IgnoreComet("Comet does not use FileScan")) {
withTempPath { workDir =>
val workDirPath = workDir.getAbsolutePath
val input = spark.range(5).toDF("id")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 0dd90925d3c..7d53ec845ef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -46,8 +46,10 @@ import org.apache.spark.sql.util.QueryExecutionListener
import org.apache.spark.util.{AccumulatorContext, JsonProtocol}
// Disable AQE because metric info is different with AQE on/off
+// This test suite runs tests against the metrics of physical operators.
+// Disabling it for Comet because the metrics are different with Comet enabled.
class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
- with DisableAdaptiveExecutionSuite {
+ with DisableAdaptiveExecutionSuite with IgnoreCometSuite {
import testImplicits._
/**
@@ -765,7 +767,8 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
}
}
- test("SPARK-26327: FileSourceScanExec metrics") {
+ test("SPARK-26327: FileSourceScanExec metrics",
+ IgnoreComet("Spark uses row-based Parquet reader while Comet is vectorized")) {
withTable("testDataForScan") {
spark.range(10).selectExpr("id", "id % 3 as p")
.write.partitionBy("p").saveAsTable("testDataForScan")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
index 0ab8691801d..d9125f658ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.python
import org.apache.spark.sql.catalyst.plans.logical.{ArrowEvalPython, BatchEvalPython, Limit, LocalLimit}
+import org.apache.spark.sql.comet._
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
@@ -108,6 +109,7 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession {
val scanNodes = query.queryExecution.executedPlan.collect {
case scan: FileSourceScanExec => scan
+ case scan: CometScanExec => scan
}
assert(scanNodes.length == 1)
assert(scanNodes.head.output.map(_.name) == Seq("a"))
@@ -120,11 +122,16 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession {
val scanNodes = query.queryExecution.executedPlan.collect {
case scan: FileSourceScanExec => scan
+ case scan: CometScanExec => scan
}
assert(scanNodes.length == 1)
// $"a" is not null and $"a" > 1
- assert(scanNodes.head.dataFilters.length == 2)
- assert(scanNodes.head.dataFilters.flatMap(_.references.map(_.name)).distinct == Seq("a"))
+ val dataFilters = scanNodes.head match {
+ case scan: FileSourceScanExec => scan.dataFilters
+ case scan: CometScanExec => scan.dataFilters
+ }
+ assert(dataFilters.length == 2)
+ assert(dataFilters.flatMap(_.references.map(_.name)).distinct == Seq("a"))
}
}
}
@@ -145,6 +152,7 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession {
val scanNodes = query.queryExecution.executedPlan.collect {
case scan: BatchScanExec => scan
+ case scan: CometBatchScanExec => scan
}
assert(scanNodes.length == 1)
assert(scanNodes.head.output.map(_.name) == Seq("a"))
@@ -157,6 +165,7 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession {
val scanNodes = query.queryExecution.executedPlan.collect {
case scan: BatchScanExec => scan
+ case scan: CometBatchScanExec => scan
}
assert(scanNodes.length == 1)
// $"a" is not null and $"a" > 1
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala
index 7838e62013d..8fa09652921 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala
@@ -37,8 +37,10 @@ import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException,
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.util.Utils
+// For some reason this suite is flaky w/ or w/o Comet when running in Github workflow.
+// Since it isn't related to Comet, we disable it for now.
class AsyncProgressTrackingMicroBatchExecutionSuite
- extends StreamTest with BeforeAndAfter with Matchers {
+ extends StreamTest with BeforeAndAfter with Matchers with IgnoreCometSuite {
import testImplicits._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index c4b09c4b289..75c3437788e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -26,10 +26,11 @@ import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.execution.{FileSourceScanExec, SortExec, SparkPlan}
+import org.apache.spark.sql.comet._
+import org.apache.spark.sql.execution.{ColumnarToRowExec, FileSourceScanExec, SortExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper}
import org.apache.spark.sql.execution.datasources.BucketingUtils
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -103,12 +104,22 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
}
}
- private def getFileScan(plan: SparkPlan): FileSourceScanExec = {
- val fileScan = collect(plan) { case f: FileSourceScanExec => f }
+ private def getFileScan(plan: SparkPlan): SparkPlan = {
+ val fileScan = collect(plan) {
+ case f: FileSourceScanExec => f
+ case f: CometScanExec => f
+ case f: CometNativeScanExec => f
+ }
assert(fileScan.nonEmpty, plan)
fileScan.head
}
+ private def getBucketScan(plan: SparkPlan): Boolean = getFileScan(plan) match {
+ case fs: FileSourceScanExec => fs.bucketedScan
+ case bs: CometScanExec => bs.bucketedScan
+ case ns: CometNativeScanExec => ns.bucketedScan
+ }
+
// To verify if the bucket pruning works, this function checks two conditions:
// 1) Check if the pruned buckets (before filtering) are empty.
// 2) Verify the final result is the same as the expected one
@@ -157,7 +168,8 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
val planWithoutBucketedScan = bucketedDataFrame.filter(filterCondition)
.queryExecution.executedPlan
val fileScan = getFileScan(planWithoutBucketedScan)
- assert(!fileScan.bucketedScan, s"except no bucketed scan but found\n$fileScan")
+ val bucketedScan = getBucketScan(planWithoutBucketedScan)
+ assert(!bucketedScan, s"except no bucketed scan but found\n$fileScan")
val bucketColumnType = bucketedDataFrame.schema.apply(bucketColumnIndex).dataType
val rowsWithInvalidBuckets = fileScan.execute().filter(row => {
@@ -454,28 +466,54 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
val joinOperator = if (joined.sparkSession.sessionState.conf.adaptiveExecutionEnabled) {
val executedPlan =
joined.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
- assert(executedPlan.isInstanceOf[SortMergeJoinExec])
- executedPlan.asInstanceOf[SortMergeJoinExec]
+ executedPlan match {
+ case s: SortMergeJoinExec => s
+ case b: CometSortMergeJoinExec =>
+ b.originalPlan match {
+ case s: SortMergeJoinExec => s
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
} else {
val executedPlan = joined.queryExecution.executedPlan
- assert(executedPlan.isInstanceOf[SortMergeJoinExec])
- executedPlan.asInstanceOf[SortMergeJoinExec]
+ executedPlan match {
+ case s: SortMergeJoinExec => s
+ case ColumnarToRowExec(child) =>
+ child.asInstanceOf[CometSortMergeJoinExec].originalPlan match {
+ case s: SortMergeJoinExec => s
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
+ case CometColumnarToRowExec(child) =>
+ child.asInstanceOf[CometSortMergeJoinExec].originalPlan match {
+ case s: SortMergeJoinExec => s
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
+ case CometNativeColumnarToRowExec(child) =>
+ child.asInstanceOf[CometSortMergeJoinExec].originalPlan match {
+ case s: SortMergeJoinExec => s
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
}
// check existence of shuffle
assert(
- joinOperator.left.exists(_.isInstanceOf[ShuffleExchangeExec]) == shuffleLeft,
+ joinOperator.left.exists(op => op.isInstanceOf[ShuffleExchangeLike]) == shuffleLeft,
s"expected shuffle in plan to be $shuffleLeft but found\n${joinOperator.left}")
assert(
- joinOperator.right.exists(_.isInstanceOf[ShuffleExchangeExec]) == shuffleRight,
+ joinOperator.right.exists(op => op.isInstanceOf[ShuffleExchangeLike]) == shuffleRight,
s"expected shuffle in plan to be $shuffleRight but found\n${joinOperator.right}")
// check existence of sort
assert(
- joinOperator.left.exists(_.isInstanceOf[SortExec]) == sortLeft,
+ joinOperator.left.exists(op => op.isInstanceOf[SortExec] || op.isInstanceOf[CometExec] &&
+ op.asInstanceOf[CometExec].originalPlan.isInstanceOf[SortExec]) == sortLeft,
s"expected sort in the left child to be $sortLeft but found\n${joinOperator.left}")
assert(
- joinOperator.right.exists(_.isInstanceOf[SortExec]) == sortRight,
+ joinOperator.right.exists(op => op.isInstanceOf[SortExec] || op.isInstanceOf[CometExec] &&
+ op.asInstanceOf[CometExec].originalPlan.isInstanceOf[SortExec]) == sortRight,
s"expected sort in the right child to be $sortRight but found\n${joinOperator.right}")
// check the output partitioning
@@ -838,11 +876,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
val scanDF = spark.table("bucketed_table").select("j")
- assert(!getFileScan(scanDF.queryExecution.executedPlan).bucketedScan)
+ assert(!getBucketScan(scanDF.queryExecution.executedPlan))
checkAnswer(scanDF, df1.select("j"))
val aggDF = spark.table("bucketed_table").groupBy("j").agg(max("k"))
- assert(!getFileScan(aggDF.queryExecution.executedPlan).bucketedScan)
+ assert(!getBucketScan(aggDF.queryExecution.executedPlan))
checkAnswer(aggDF, df1.groupBy("j").agg(max("k")))
}
}
@@ -1031,15 +1069,24 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
Seq(true, false).foreach { aqeEnabled =>
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled.toString) {
val plan = sql(query).queryExecution.executedPlan
- val shuffles = collect(plan) { case s: ShuffleExchangeExec => s }
+ val shuffles = collect(plan) { case s: ShuffleExchangeLike => s }
assert(shuffles.length == expectedNumShuffles)
val scans = collect(plan) {
case f: FileSourceScanExec if f.optionalNumCoalescedBuckets.isDefined => f
+ case b: CometScanExec if b.optionalNumCoalescedBuckets.isDefined => b
+ case b: CometNativeScanExec if b.optionalNumCoalescedBuckets.isDefined => b
}
if (expectedCoalescedNumBuckets.isDefined) {
assert(scans.length == 1)
- assert(scans.head.optionalNumCoalescedBuckets == expectedCoalescedNumBuckets)
+ scans.head match {
+ case f: FileSourceScanExec =>
+ assert(f.optionalNumCoalescedBuckets == expectedCoalescedNumBuckets)
+ case b: CometScanExec =>
+ assert(b.optionalNumCoalescedBuckets == expectedCoalescedNumBuckets)
+ case b: CometNativeScanExec =>
+ assert(b.optionalNumCoalescedBuckets == expectedCoalescedNumBuckets)
+ }
} else {
assert(scans.isEmpty)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
index 95c2fcbd7b5..e2d4a20c5d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.sources
import java.io.File
import org.apache.spark.SparkException
+import org.apache.spark.sql.IgnoreCometSuite
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTableType}
import org.apache.spark.sql.catalyst.parser.ParseException
@@ -27,7 +28,10 @@ import org.apache.spark.sql.internal.SQLConf.BUCKETING_MAX_BUCKETS
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.Utils
-class CreateTableAsSelectSuite extends DataSourceTest with SharedSparkSession {
+// For some reason this suite is flaky w/ or w/o Comet when running in Github workflow.
+// Since it isn't related to Comet, we disable it for now.
+class CreateTableAsSelectSuite extends DataSourceTest with SharedSparkSession
+ with IgnoreCometSuite {
import testImplicits._
protected override lazy val sql = spark.sql _
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala
index c5c56f081d8..6cc51f93b4f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.sources
import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec}
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.internal.SQLConf
@@ -68,7 +69,11 @@ abstract class DisableUnnecessaryBucketedScanSuite
def checkNumBucketedScan(query: String, expectedNumBucketedScan: Int): Unit = {
val plan = sql(query).queryExecution.executedPlan
- val bucketedScan = collect(plan) { case s: FileSourceScanExec if s.bucketedScan => s }
+ val bucketedScan = collect(plan) {
+ case s: FileSourceScanExec if s.bucketedScan => s
+ case s: CometScanExec if s.bucketedScan => s
+ case s: CometNativeScanExec if s.bucketedScan => s
+ }
assert(bucketedScan.length == expectedNumBucketedScan)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
index 9742a004545..4e0417d730a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
@@ -34,6 +34,7 @@ import org.apache.spark.paths.SparkPath
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.{AnalysisException, DataFrame}
import org.apache.spark.sql.catalyst.util.stringToFile
+import org.apache.spark.sql.comet.CometBatchScanExec
import org.apache.spark.sql.execution.DataSourceScanExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -786,6 +787,8 @@ class FileStreamSinkV2Suite extends FileStreamSinkSuite {
val fileScan = df.queryExecution.executedPlan.collect {
case batch: BatchScanExec if batch.scan.isInstanceOf[FileScan] =>
batch.scan.asInstanceOf[FileScan]
+ case batch: CometBatchScanExec if batch.scan.isInstanceOf[FileScan] =>
+ batch.scan.asInstanceOf[FileScan]
}.headOption.getOrElse {
fail(s"No FileScan in query\n${df.queryExecution}")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index b0967d5ffdf..3d567f913de 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.classic.{DataFrame, Dataset}
import org.apache.spark.sql.classic.ClassicConversions._
+import org.apache.spark.sql.comet.CometLocalLimitExec
import org.apache.spark.sql.execution.{LocalLimitExec, SimpleMode, SparkPlan}
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.streaming._
@@ -1118,11 +1119,12 @@ class StreamSuite extends StreamTest {
val localLimits = execPlan.collect {
case l: LocalLimitExec => l
case l: StreamingLocalLimitExec => l
+ case l: CometLocalLimitExec => l
}
require(
localLimits.size == 1,
- s"Cant verify local limit optimization with this plan:\n$execPlan")
+ s"Cant verify local limit optimization ${localLimits.size} with this plan:\n$execPlan")
if (expectStreamingLimit) {
assert(
@@ -1130,7 +1132,8 @@ class StreamSuite extends StreamTest {
s"Local limit was not StreamingLocalLimitExec:\n$execPlan")
} else {
assert(
- localLimits.head.isInstanceOf[LocalLimitExec],
+ localLimits.head.isInstanceOf[LocalLimitExec] ||
+ localLimits.head.isInstanceOf[CometLocalLimitExec],
s"Local limit was not LocalLimitExec:\n$execPlan")
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala
index b4c4ec7acbf..20579284856 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala
@@ -23,6 +23,7 @@ import org.apache.commons.io.FileUtils
import org.scalatest.Assertions
import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
+import org.apache.spark.sql.comet.CometHashAggregateExec
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.streaming.{MemoryStream, StateStoreRestoreExec, StateStoreSaveExec}
import org.apache.spark.sql.functions.count
@@ -67,6 +68,7 @@ class StreamingAggregationDistributionSuite extends StreamTest
// verify aggregations in between, except partial aggregation
val allAggregateExecs = query.lastExecution.executedPlan.collect {
case a: BaseAggregateExec => a
+ case c: CometHashAggregateExec => c.originalPlan
}
val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter {
@@ -201,6 +203,7 @@ class StreamingAggregationDistributionSuite extends StreamTest
// verify aggregations in between, except partial aggregation
val allAggregateExecs = executedPlan.collect {
case a: BaseAggregateExec => a
+ case c: CometHashAggregateExec => c.originalPlan
}
val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
index d3c44dcead3..8096bce4436 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec, StreamingSymmetricHashJoinHelper}
import org.apache.spark.sql.execution.streaming.state.{RocksDBStateStoreProvider, StateStore, StateStoreProviderId}
import org.apache.spark.sql.functions._
@@ -642,14 +642,28 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite {
val numPartitions = spark.sessionState.conf.getConf(SQLConf.SHUFFLE_PARTITIONS)
- assert(query.lastExecution.executedPlan.collect {
- case j @ StreamingSymmetricHashJoinExec(_, _, _, _, _, _, _, _, _,
- ShuffleExchangeExec(opA: HashPartitioning, _, _, _),
- ShuffleExchangeExec(opB: HashPartitioning, _, _, _))
- if partitionExpressionsColumns(opA.expressions) === Seq("a", "b")
- && partitionExpressionsColumns(opB.expressions) === Seq("a", "b")
- && opA.numPartitions == numPartitions && opB.numPartitions == numPartitions => j
- }.size == 1)
+ val join = query.lastExecution.executedPlan.collect {
+ case j: StreamingSymmetricHashJoinExec => j
+ }.head
+ val opA = join.left.collect {
+ case s: ShuffleExchangeLike
+ if s.outputPartitioning.isInstanceOf[HashPartitioning] &&
+ partitionExpressionsColumns(
+ s.outputPartitioning
+ .asInstanceOf[HashPartitioning].expressions) === Seq("a", "b") =>
+ s.outputPartitioning
+ .asInstanceOf[HashPartitioning]
+ }.head
+ val opB = join.right.collect {
+ case s: ShuffleExchangeLike
+ if s.outputPartitioning.isInstanceOf[HashPartitioning] &&
+ partitionExpressionsColumns(
+ s.outputPartitioning
+ .asInstanceOf[HashPartitioning].expressions) === Seq("a", "b") =>
+ s.outputPartitioning
+ .asInstanceOf[HashPartitioning]
+ }.head
+ assert(opA.numPartitions == numPartitions && opB.numPartitions == numPartitions)
})
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index e33d4f1f6ab..ce0a21d1e9d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.classic.{DataFrame, Dataset}
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, ReadLimit}
-import org.apache.spark.sql.execution.exchange.{REQUIRED_BY_STATEFUL_OPERATOR, ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{REQUIRED_BY_STATEFUL_OPERATOR, ReusedExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.{MemorySink, TestForeachWriter}
import org.apache.spark.sql.functions._
@@ -1462,7 +1462,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
CheckAnswer((1, 2), (2, 2), (3, 2)),
Execute { qe =>
val shuffleOpt = qe.lastExecution.executedPlan.collect {
- case s: ShuffleExchangeExec => s
+ case s: ShuffleExchangeLike => s
}
assert(shuffleOpt.nonEmpty, "No shuffle exchange found in the query plan")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
index 86c4e49f6f6..2e639e5f38d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
@@ -22,7 +22,7 @@ import java.util
import org.scalatest.BeforeAndAfter
-import org.apache.spark.sql.{AnalysisException, Row, SaveMode}
+import org.apache.spark.sql.{AnalysisException, IgnoreComet, Row, SaveMode}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
@@ -359,7 +359,8 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter {
}
}
- test("explain with table on DSv1 data source") {
+ test("explain with table on DSv1 data source",
+ IgnoreComet("Comet explain output is different")) {
val tblSourceName = "tbl_src"
val tblTargetName = "tbl_target"
val tblSourceQualified = s"default.$tblSourceName"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index f0f3f94b811..d64e4e54e22 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -27,13 +27,14 @@ import scala.jdk.CollectionConverters._
import scala.language.implicitConversions
import scala.util.control.NonFatal
+import org.apache.comet.CometConf
import org.apache.hadoop.fs.Path
import org.scalactic.source.Position
import org.scalatest.{BeforeAndAfterAll, Suite, Tag}
import org.scalatest.concurrent.Eventually
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.{AnalysisException, IgnoreComet, IgnoreCometNativeDataFusion, IgnoreCometNativeIcebergCompat, IgnoreCometNativeScan, Row}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE
@@ -42,6 +43,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTestBase
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.classic.{ClassicConversions, ColumnConversions, ColumnNodeToExpressionConverter, DataFrame, Dataset, SparkSession, SQLImplicits}
+import org.apache.spark.sql.comet.{CometFilterExec, CometProjectExec}
import org.apache.spark.sql.execution.FilterExec
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution
import org.apache.spark.sql.execution.datasources.DataSourceUtils
@@ -128,7 +130,28 @@ private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with
}
}
} else {
- super.test(testName, testTags: _*)(testFun)
+ if (isCometEnabled && testTags.exists(_.isInstanceOf[IgnoreComet])) {
+ ignore(testName + " (disabled when Comet is on)", testTags: _*)(testFun)
+ } else {
+ val cometScanImpl = CometConf.COMET_NATIVE_SCAN_IMPL.get(conf)
+ val isNativeIcebergCompat = cometScanImpl == CometConf.SCAN_NATIVE_ICEBERG_COMPAT ||
+ cometScanImpl == CometConf.SCAN_AUTO
+ val isNativeDataFusion = cometScanImpl == CometConf.SCAN_NATIVE_DATAFUSION ||
+ cometScanImpl == CometConf.SCAN_AUTO
+ if (isCometEnabled && isNativeIcebergCompat &&
+ testTags.exists(_.isInstanceOf[IgnoreCometNativeIcebergCompat])) {
+ ignore(testName + " (disabled for NATIVE_ICEBERG_COMPAT)", testTags: _*)(testFun)
+ } else if (isCometEnabled && isNativeDataFusion &&
+ testTags.exists(_.isInstanceOf[IgnoreCometNativeDataFusion])) {
+ ignore(testName + " (disabled for NATIVE_DATAFUSION)", testTags: _*)(testFun)
+ } else if (isCometEnabled && (isNativeDataFusion || isNativeIcebergCompat) &&
+ testTags.exists(_.isInstanceOf[IgnoreCometNativeScan])) {
+ ignore(testName + " (disabled for NATIVE_DATAFUSION and NATIVE_ICEBERG_COMPAT)",
+ testTags: _*)(testFun)
+ } else {
+ super.test(testName, testTags: _*)(testFun)
+ }
+ }
}
}
@@ -248,8 +271,33 @@ private[sql] trait SQLTestUtilsBase
override protected def converter: ColumnNodeToExpressionConverter = self.spark.converter
}
+ /**
+ * Whether Comet extension is enabled
+ */
+ protected def isCometEnabled: Boolean = SparkSession.isCometEnabled
+
+ /**
+ * Whether to enable ansi mode This is only effective when
+ * [[isCometEnabled]] returns true.
+ */
+ protected def enableCometAnsiMode: Boolean = {
+ val v = System.getenv("ENABLE_COMET_ANSI_MODE")
+ v != null && v.toBoolean
+ }
+
+ /**
+ * Whether Spark should only apply Comet scan optimization. This is only effective when
+ * [[isCometEnabled]] returns true.
+ */
+ protected def isCometScanOnly: Boolean = {
+ val v = System.getenv("ENABLE_COMET_SCAN_ONLY")
+ v != null && v.toBoolean
+ }
+
protected override def withSQLConf[T](pairs: (String, String)*)(f: => T): T = {
SparkSession.setActiveSession(spark)
+
+
super.withSQLConf(pairs: _*)(f)
}
@@ -451,6 +499,8 @@ private[sql] trait SQLTestUtilsBase
val schema = df.schema
val withoutFilters = df.queryExecution.executedPlan.transform {
case FilterExec(_, child) => child
+ case CometFilterExec(_, _, _, _, child, _) => child
+ case CometProjectExec(_, _, _, _, CometFilterExec(_, _, _, _, child, _), _) => child
}
spark.internalCreateDataFrame(withoutFilters.execute(), schema)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
index 245219c1756..7d2ef1b9145 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
@@ -75,6 +75,31 @@ trait SharedSparkSessionBase
// this rule may potentially block testing of other optimization rules such as
// ConstantPropagation etc.
.set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)
+ // Enable Comet if `ENABLE_COMET` environment variable is set
+ if (isCometEnabled) {
+ conf
+ .set("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions")
+ .set("spark.comet.enabled", "true")
+ .set("spark.comet.parquet.respectFilterPushdown", "true")
+
+ if (!isCometScanOnly) {
+ conf
+ .set("spark.comet.exec.enabled", "true")
+ .set("spark.shuffle.manager",
+ "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
+ .set("spark.comet.exec.shuffle.enabled", "true")
+ .set("spark.comet.memoryOverhead", "10g")
+ } else {
+ conf
+ .set("spark.comet.exec.enabled", "false")
+ .set("spark.comet.exec.shuffle.enabled", "false")
+ }
+
+ if (enableCometAnsiMode) {
+ conf
+ .set("spark.sql.ansi.enabled", "true")
+ }
+ }
conf.set(
StaticSQLConf.WAREHOUSE_PATH,
conf.get(StaticSQLConf.WAREHOUSE_PATH) + "/" + getClass.getCanonicalName)
diff --git a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala
index 982d57fb287..6017f36c440 100644
--- a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala
@@ -46,7 +46,7 @@ class SqlResourceWithActualMetricsSuite
import testImplicits._
// Exclude nodes which may not have the metrics
- val excludedNodes = List("WholeStageCodegen", "Project", "SerializeFromObject")
+ val excludedNodes = List("WholeStageCodegen", "Project", "SerializeFromObject", "RowToColumnar")
implicit val formats: DefaultFormats = new DefaultFormats {
override def dateFormatter = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala
index 52abd248f3a..7a199931a08 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression}
+import org.apache.spark.sql.comet._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.hive.execution.HiveTableScanExec
@@ -35,6 +36,9 @@ abstract class DynamicPartitionPruningHiveScanSuiteBase
case s: FileSourceScanExec => s.partitionFilters.collect {
case d: DynamicPruningExpression => d.child
}
+ case s: CometScanExec => s.partitionFilters.collect {
+ case d: DynamicPruningExpression => d.child
+ }
case h: HiveTableScanExec => h.partitionPruningPred.collect {
case d: DynamicPruningExpression => d.child
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUDFDynamicLoadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUDFDynamicLoadSuite.scala
index 4b27082e188..09f591dfed3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUDFDynamicLoadSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUDFDynamicLoadSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.hive
-import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.{IgnoreComet, QueryTest, Row}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -147,11 +147,15 @@ class HiveUDFDynamicLoadSuite extends QueryTest with SQLTestUtils with TestHiveS
// This jar file should not be placed to the classpath.
val jarPath = "src/test/noclasspath/hive-test-udfs.jar"
- assume(new java.io.File(jarPath).exists)
+ // Comet: hive-test-udfs.jar files has been removed from Apache Spark repository
+ // comment out the following line for now
+ // assume(new java.io.File(jarPath).exists)
val jarUrl = s"file://${System.getProperty("user.dir")}/$jarPath"
test("Spark should be able to run Hive UDF using jar regardless of " +
- s"current thread context classloader (${udfInfo.identifier}") {
+ s"current thread context classloader (${udfInfo.identifier}",
+ IgnoreComet("TODO: ignore for first stage of 4.0 " +
+ "https://github.com/apache/datafusion-comet/issues/1948")) {
Utils.withContextClassLoader(Utils.getSparkClassLoader) {
withUserDefinedFunction(udfInfo.funcName -> false) {
val sparkClassLoader = Thread.currentThread().getContextClassLoader
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala
index cc7bb193731..06555d48da7 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala
@@ -818,7 +818,8 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter
}
}
- test("SPARK-30201 HiveOutputWriter standardOI should use ObjectInspectorCopyOption.DEFAULT") {
+ test("SPARK-30201 HiveOutputWriter standardOI should use ObjectInspectorCopyOption.DEFAULT",
+ IgnoreComet("Comet does not support reading non UTF-8 strings")) {
withTable("t1", "t2") {
withTempDir { dir =>
val file = new File(dir, "test.hex")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala
index b67370f6eb9..746b3974b29 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala
@@ -23,14 +23,15 @@ import java.util.concurrent.{Executors, TimeUnit}
import org.scalatest.BeforeAndAfterEach
import org.apache.spark.metrics.source.HiveCatalogMetrics
-import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.{IgnoreCometSuite, QueryTest}
import org.apache.spark.sql.execution.datasources.FileStatusCache
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
class PartitionedTablePerfStatsSuite
- extends QueryTest with TestHiveSingleton with SQLTestUtils with BeforeAndAfterEach {
+ extends QueryTest with TestHiveSingleton with SQLTestUtils with BeforeAndAfterEach
+ with IgnoreCometSuite {
override def beforeEach(): Unit = {
super.beforeEach()
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
index a394d0b7393..d29b3058897 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -53,24 +53,47 @@ object TestHive
new SparkContext(
System.getProperty("spark.sql.test.master", "local[1]"),
"TestSQLContext",
- new SparkConf()
- .set("spark.sql.test", "")
- .set(SQLConf.CODEGEN_FALLBACK.key, "false")
- .set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString)
- .set(HiveUtils.HIVE_METASTORE_BARRIER_PREFIXES.key,
- "org.apache.spark.sql.hive.execution.PairSerDe")
- .set(WAREHOUSE_PATH.key, TestHiveContext.makeWarehouseDir().toURI.getPath)
- // SPARK-8910
- .set(UI_ENABLED, false)
- .set(config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true)
- // Hive changed the default of hive.metastore.disallow.incompatible.col.type.changes
- // from false to true. For details, see the JIRA HIVE-12320 and HIVE-17764.
- .set("spark.hadoop.hive.metastore.disallow.incompatible.col.type.changes", "false")
- // Disable ConvertToLocalRelation for better test coverage. Test cases built on
- // LocalRelation will exercise the optimization rules better by disabling it as
- // this rule may potentially block testing of other optimization rules such as
- // ConstantPropagation etc.
- .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)
+ {
+ val conf = new SparkConf()
+ .set("spark.sql.test", "")
+ .set(SQLConf.CODEGEN_FALLBACK.key, "false")
+ .set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString)
+ .set(HiveUtils.HIVE_METASTORE_BARRIER_PREFIXES.key,
+ "org.apache.spark.sql.hive.execution.PairSerDe")
+ .set(WAREHOUSE_PATH.key, TestHiveContext.makeWarehouseDir().toURI.getPath)
+ .set(UI_ENABLED, false)
+ .set(config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true)
+ .set("spark.hadoop.hive.metastore.disallow.incompatible.col.type.changes", "false")
+ .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)
+
+ if (SparkSession.isCometEnabled) {
+ conf
+ .set("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions")
+ .set("spark.comet.enabled", "true")
+
+ val v = System.getenv("ENABLE_COMET_SCAN_ONLY")
+ if (v == null || !v.toBoolean) {
+ conf
+ .set("spark.comet.exec.enabled", "true")
+ .set("spark.shuffle.manager",
+ "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
+ .set("spark.comet.exec.shuffle.enabled", "true")
+ } else {
+ conf
+ .set("spark.comet.exec.enabled", "false")
+ .set("spark.comet.exec.shuffle.enabled", "false")
+ }
+
+ val a = System.getenv("ENABLE_COMET_ANSI_MODE")
+ if (a != null && a.toBoolean) {
+ conf
+ .set("spark.sql.ansi.enabled", "true")
+ }
+ }
+
+ conf
+ }
+
.set(SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD,
sys.env.getOrElse("SPARK_TEST_HIVE_SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD",
SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD.defaultValueString).toInt)