| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.spark.sql.execution.adaptive |
| |
| import java.io.File |
| import java.net.URI |
| |
| import org.apache.logging.log4j.Level |
| import org.scalatest.PrivateMethodTester |
| import org.scalatest.time.SpanSugar._ |
| |
| import org.apache.spark.SparkException |
| import org.apache.spark.rdd.RDD |
| import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} |
| import org.apache.spark.shuffle.sort.SortShuffleManager |
| import org.apache.spark.sql.{DataFrame, Dataset, QueryTest, Row, SparkSession, Strategy} |
| import org.apache.spark.sql.catalyst.InternalRow |
| import org.apache.spark.sql.catalyst.expressions.Attribute |
| import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} |
| import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} |
| import org.apache.spark.sql.execution.{CollectLimitExec, ColumnarToRowExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnaryExecNode, UnionExec} |
| import org.apache.spark.sql.execution.aggregate.BaseAggregateExec |
| import org.apache.spark.sql.execution.columnar.{InMemoryTableScanExec, InMemoryTableScanLike} |
| import org.apache.spark.sql.execution.command.DataWritingCommandExec |
| import org.apache.spark.sql.execution.datasources.noop.NoopDataSource |
| import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec |
| import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ENSURE_REQUIREMENTS, Exchange, REPARTITION_BY_COL, REPARTITION_BY_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} |
| import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec} |
| import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter |
| import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates, SparkListenerSQLExecutionStart} |
| import org.apache.spark.sql.functions._ |
| import org.apache.spark.sql.internal.SQLConf |
| import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode |
| import org.apache.spark.sql.test.SharedSparkSession |
| import org.apache.spark.sql.test.SQLTestData.TestData |
| import org.apache.spark.sql.types.{IntegerType, StructType} |
| import org.apache.spark.sql.util.QueryExecutionListener |
| import org.apache.spark.tags.SlowSQLTest |
| import org.apache.spark.util.ArrayImplicits._ |
| import org.apache.spark.util.Utils |
| |
| @SlowSQLTest |
| class AdaptiveQueryExecSuite |
| extends QueryTest |
| with SharedSparkSession |
| with AdaptiveSparkPlanHelper |
| with PrivateMethodTester { |
| |
| import testImplicits._ |
| |
| setupTestData() |
| |
| private def runAdaptiveAndVerifyResult(query: String, |
| skipCheckAnswer: Boolean = false): (SparkPlan, SparkPlan) = { |
| var finalPlanCnt = 0 |
| var hasMetricsEvent = false |
| val listener = new SparkListener { |
| override def onOtherEvent(event: SparkListenerEvent): Unit = { |
| event match { |
| case SparkListenerSQLAdaptiveExecutionUpdate(_, _, sparkPlanInfo) => |
| if (sparkPlanInfo.simpleString.startsWith( |
| "AdaptiveSparkPlan isFinalPlan=true")) { |
| finalPlanCnt += 1 |
| } |
| case _: SparkListenerSQLAdaptiveSQLMetricUpdates => |
| hasMetricsEvent = true |
| case _ => // ignore other events |
| } |
| } |
| } |
| spark.sparkContext.addSparkListener(listener) |
| |
| val dfAdaptive = sql(query) |
| val planBefore = dfAdaptive.queryExecution.executedPlan |
| assert(planBefore.toString.startsWith("AdaptiveSparkPlan isFinalPlan=false")) |
| val result = dfAdaptive.collect() |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { |
| if (!skipCheckAnswer) { |
| val df = sql(query) |
| checkAnswer(df, result.toImmutableArraySeq) |
| } |
| } |
| val planAfter = dfAdaptive.queryExecution.executedPlan |
| assert(planAfter.toString.startsWith("AdaptiveSparkPlan isFinalPlan=true")) |
| val adaptivePlan = planAfter.asInstanceOf[AdaptiveSparkPlanExec].executedPlan |
| |
| spark.sparkContext.listenerBus.waitUntilEmpty() |
| // AQE will post `SparkListenerSQLAdaptiveExecutionUpdate` twice in case of subqueries that |
| // exist out of query stages. |
| val expectedFinalPlanCnt = adaptivePlan.find(_.subqueries.nonEmpty).map(_ => 2).getOrElse(1) |
| assert(finalPlanCnt == expectedFinalPlanCnt) |
| spark.sparkContext.removeSparkListener(listener) |
| |
| val expectedMetrics = findInMemoryTable(planAfter).nonEmpty || |
| subqueriesAll(planAfter).nonEmpty |
| assert(hasMetricsEvent == expectedMetrics) |
| |
| val exchanges = adaptivePlan.collect { |
| case e: Exchange => e |
| } |
| assert(exchanges.isEmpty, "The final plan should not contain any Exchange node.") |
| (dfAdaptive.queryExecution.sparkPlan, adaptivePlan) |
| } |
| |
| private def findTopLevelBroadcastHashJoin(plan: SparkPlan): Seq[BroadcastHashJoinExec] = { |
| collect(plan) { |
| case j: BroadcastHashJoinExec => j |
| } |
| } |
| |
| def findTopLevelBroadcastNestedLoopJoin(plan: SparkPlan): Seq[BaseJoinExec] = { |
| collect(plan) { |
| case j: BroadcastNestedLoopJoinExec => j |
| } |
| } |
| |
| private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = { |
| collect(plan) { |
| case j: SortMergeJoinExec => j |
| } |
| } |
| |
| private def findTopLevelShuffledHashJoin(plan: SparkPlan): Seq[ShuffledHashJoinExec] = { |
| collect(plan) { |
| case j: ShuffledHashJoinExec => j |
| } |
| } |
| |
| private def findTopLevelBaseJoin(plan: SparkPlan): Seq[BaseJoinExec] = { |
| collect(plan) { |
| case j: BaseJoinExec => j |
| } |
| } |
| |
| private def findTopLevelSort(plan: SparkPlan): Seq[SortExec] = { |
| collect(plan) { |
| case s: SortExec => s |
| } |
| } |
| |
| private def findTopLevelAggregate(plan: SparkPlan): Seq[BaseAggregateExec] = { |
| collect(plan) { |
| case agg: BaseAggregateExec => agg |
| } |
| } |
| |
| private def findTopLevelLimit(plan: SparkPlan): Seq[CollectLimitExec] = { |
| collect(plan) { |
| case l: CollectLimitExec => l |
| } |
| } |
| |
| private def findReusedExchange(plan: SparkPlan): Seq[ReusedExchangeExec] = { |
| collectWithSubqueries(plan) { |
| case ShuffleQueryStageExec(_, e: ReusedExchangeExec, _) => e |
| case BroadcastQueryStageExec(_, e: ReusedExchangeExec, _) => e |
| } |
| } |
| |
| private def findReusedSubquery(plan: SparkPlan): Seq[ReusedSubqueryExec] = { |
| collectWithSubqueries(plan) { |
| case e: ReusedSubqueryExec => e |
| } |
| } |
| |
| private def findInMemoryTable(plan: SparkPlan): Seq[InMemoryTableScanExec] = { |
| collect(plan) { |
| case c: InMemoryTableScanExec |
| if c.relation.cachedPlan.isInstanceOf[AdaptiveSparkPlanExec] => c |
| } |
| } |
| |
| private def checkNumLocalShuffleReads( |
| plan: SparkPlan, numShufflesWithoutLocalRead: Int = 0): Unit = { |
| val numShuffles = collect(plan) { |
| case s: ShuffleQueryStageExec => s |
| }.length |
| |
| val numLocalReads = collect(plan) { |
| case read: AQEShuffleReadExec if read.isLocalRead => read |
| } |
| numLocalReads.foreach { r => |
| val rdd = r.execute() |
| val parts = rdd.partitions |
| assert(parts.forall(rdd.preferredLocations(_).nonEmpty)) |
| } |
| assert(numShuffles === (numLocalReads.length + numShufflesWithoutLocalRead)) |
| } |
| |
| private def checkInitialPartitionNum(df: Dataset[_], numPartition: Int): Unit = { |
| // repartition obeys initialPartitionNum when adaptiveExecutionEnabled |
| val plan = df.queryExecution.executedPlan |
| assert(plan.isInstanceOf[AdaptiveSparkPlanExec]) |
| val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect { |
| case s: ShuffleExchangeExec => s |
| } |
| assert(shuffle.size == 1) |
| assert(shuffle(0).outputPartitioning.numPartitions == numPartition) |
| } |
| |
| test("Change merge join to broadcast join") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT * FROM testData join testData2 ON key = a where value = '1'") |
| val smj = findTopLevelSortMergeJoin(plan) |
| assert(smj.size == 1) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 1) |
| checkNumLocalShuffleReads(adaptivePlan) |
| } |
| } |
| |
| test("Change broadcast join to merge join") { |
| withTable("t1", "t2") { |
| withSQLConf( |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10000", |
| SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", |
| SQLConf.SHUFFLE_PARTITIONS.key -> "1") { |
| sql("CREATE TABLE t1 USING PARQUET AS SELECT 1 c1") |
| sql("CREATE TABLE t2 USING PARQUET AS SELECT 1 c1") |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| """ |
| |SELECT * FROM ( |
| | SELECT distinct c1 from t1 |
| | ) tmp1 JOIN ( |
| | SELECT distinct c1 from t2 |
| | ) tmp2 ON tmp1.c1 = tmp2.c1 |
| |""".stripMargin) |
| assert(findTopLevelBroadcastHashJoin(plan).size == 1) |
| assert(findTopLevelBroadcastHashJoin(adaptivePlan).isEmpty) |
| assert(findTopLevelSortMergeJoin(adaptivePlan).size == 1) |
| } |
| } |
| } |
| |
| test("Reuse the parallelism of coalesced shuffle in local shuffle read") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", |
| SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT * FROM testData join testData2 ON key = a where value = '1'") |
| val smj = findTopLevelSortMergeJoin(plan) |
| assert(smj.size == 1) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 1) |
| val localReads = collect(adaptivePlan) { |
| case read: AQEShuffleReadExec if read.isLocalRead => read |
| } |
| assert(localReads.length == 2) |
| val localShuffleRDD0 = localReads(0).execute().asInstanceOf[ShuffledRowRDD] |
| val localShuffleRDD1 = localReads(1).execute().asInstanceOf[ShuffledRowRDD] |
| // The pre-shuffle partition size is [0, 0, 0, 72, 0] |
| // We exclude the 0-size partitions, so only one partition, advisoryParallelism = 1 |
| // the final parallelism is |
| // advisoryParallelism = 1 since advisoryParallelism < numMappers |
| // and the partitions length is 1 |
| assert(localShuffleRDD0.getPartitions.length == 1) |
| // The pre-shuffle partition size is [0, 72, 0, 72, 126] |
| // We exclude the 0-size partitions, so only 3 partition, advisoryParallelism = 3 |
| // the final parallelism is |
| // advisoryParallelism / numMappers: 3/2 = 1 since advisoryParallelism >= numMappers |
| // and the partitions length is 1 * numMappers = 2 |
| assert(localShuffleRDD1.getPartitions.length == 2) |
| } |
| } |
| |
| test("Reuse the default parallelism in local shuffle read") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", |
| SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false") { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT * FROM testData join testData2 ON key = a where value = '1'") |
| val smj = findTopLevelSortMergeJoin(plan) |
| assert(smj.size == 1) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 1) |
| val localReads = collect(adaptivePlan) { |
| case read: AQEShuffleReadExec if read.isLocalRead => read |
| } |
| assert(localReads.length == 2) |
| val localShuffleRDD0 = localReads(0).execute().asInstanceOf[ShuffledRowRDD] |
| val localShuffleRDD1 = localReads(1).execute().asInstanceOf[ShuffledRowRDD] |
| // the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2 |
| // and the partitions length is 2 * numMappers = 4 |
| assert(localShuffleRDD0.getPartitions.length == 4) |
| // the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2 |
| // and the partitions length is 2 * numMappers = 4 |
| assert(localShuffleRDD1.getPartitions.length == 4) |
| } |
| } |
| |
| test("Empty stage coalesced to 1-partition RDD") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", |
| SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { |
| val df1 = spark.range(10).withColumn("a", $"id") |
| val df2 = spark.range(10).withColumn("b", $"id") |
| withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { |
| val testDf = df1.where($"a" > 10) |
| .join(df2.where($"b" > 10), Seq("id"), "left_outer") |
| .groupBy($"a").count() |
| checkAnswer(testDf, Seq()) |
| val plan = testDf.queryExecution.executedPlan |
| assert(find(plan)(_.isInstanceOf[SortMergeJoinExec]).isDefined) |
| val coalescedReads = collect(plan) { |
| case r: AQEShuffleReadExec => r |
| } |
| assert(coalescedReads.length == 3) |
| coalescedReads.foreach(r => assert(r.partitionSpecs.length == 1)) |
| } |
| |
| withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") { |
| val testDf = df1.where($"a" > 10) |
| .join(df2.where($"b" > 10), Seq("id"), "left_outer") |
| .groupBy($"a").count() |
| checkAnswer(testDf, Seq()) |
| val plan = testDf.queryExecution.executedPlan |
| assert(find(plan)(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) |
| val coalescedReads = collect(plan) { |
| case r: AQEShuffleReadExec => r |
| } |
| assert(coalescedReads.length == 3, s"$plan") |
| coalescedReads.foreach(r => assert(r.isLocalRead || r.partitionSpecs.length == 1)) |
| } |
| } |
| } |
| |
| test("Scalar subquery") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT * FROM testData join testData2 ON key = a " + |
| "where value = (SELECT max(a) from testData3)") |
| val smj = findTopLevelSortMergeJoin(plan) |
| assert(smj.size == 1) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 1) |
| checkNumLocalShuffleReads(adaptivePlan) |
| } |
| } |
| |
| test("Scalar subquery in later stages") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT * FROM testData join testData2 ON key = a " + |
| "where (value + a) = (SELECT max(a) from testData3)") |
| val smj = findTopLevelSortMergeJoin(plan) |
| assert(smj.size == 1) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 1) |
| |
| checkNumLocalShuffleReads(adaptivePlan) |
| } |
| } |
| |
| test("multiple joins") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| """ |
| |WITH t4 AS ( |
| | SELECT * FROM lowercaseData t2 JOIN testData3 t3 ON t2.n = t3.a where t2.n = '1' |
| |) |
| |SELECT * FROM testData |
| |JOIN testData2 t2 ON key = t2.a |
| |JOIN t4 ON t2.b = t4.a |
| |WHERE value = 1 |
| """.stripMargin) |
| val smj = findTopLevelSortMergeJoin(plan) |
| assert(smj.size == 3) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 3) |
| |
| // A possible resulting query plan: |
| // BroadcastHashJoin |
| // +- BroadcastExchange |
| // +- LocalShuffleReader* |
| // +- ShuffleExchange |
| // +- BroadcastHashJoin |
| // +- BroadcastExchange |
| // +- LocalShuffleReader* |
| // +- ShuffleExchange |
| // +- LocalShuffleReader* |
| // +- ShuffleExchange |
| // +- BroadcastHashJoin |
| // +- LocalShuffleReader* |
| // +- ShuffleExchange |
| // +- BroadcastExchange |
| // +-LocalShuffleReader* |
| // +- ShuffleExchange |
| |
| // After applied the 'OptimizeShuffleWithLocalRead' rule, we can convert all the four |
| // shuffle read to local shuffle read in the bottom two 'BroadcastHashJoin'. |
| // For the top level 'BroadcastHashJoin', the probe side is not shuffle query stage |
| // and the build side shuffle query stage is also converted to local shuffle read. |
| checkNumLocalShuffleReads(adaptivePlan) |
| } |
| } |
| |
| test("multiple joins with aggregate") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| """ |
| |WITH t4 AS ( |
| | SELECT * FROM lowercaseData t2 JOIN ( |
| | select a, sum(b) from testData3 group by a |
| | ) t3 ON t2.n = t3.a where t2.n = '1' |
| |) |
| |SELECT * FROM testData |
| |JOIN testData2 t2 ON key = t2.a |
| |JOIN t4 ON t2.b = t4.a |
| |WHERE value = 1 |
| """.stripMargin) |
| val smj = findTopLevelSortMergeJoin(plan) |
| assert(smj.size == 3) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 3) |
| |
| // A possible resulting query plan: |
| // BroadcastHashJoin |
| // +- BroadcastExchange |
| // +- LocalShuffleReader* |
| // +- ShuffleExchange |
| // +- BroadcastHashJoin |
| // +- BroadcastExchange |
| // +- LocalShuffleReader* |
| // +- ShuffleExchange |
| // +- LocalShuffleReader* |
| // +- ShuffleExchange |
| // +- BroadcastHashJoin |
| // +- LocalShuffleReader* |
| // +- ShuffleExchange |
| // +- BroadcastExchange |
| // +-HashAggregate |
| // +- CoalescedShuffleReader |
| // +- ShuffleExchange |
| |
| // The shuffle added by Aggregate can't apply local read. |
| checkNumLocalShuffleReads(adaptivePlan, 1) |
| } |
| } |
| |
| test("multiple joins with aggregate 2") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| """ |
| |WITH t4 AS ( |
| | SELECT * FROM lowercaseData t2 JOIN ( |
| | select a, max(b) b from testData2 group by a |
| | ) t3 ON t2.n = t3.b |
| |) |
| |SELECT * FROM testData |
| |JOIN testData2 t2 ON key = t2.a |
| |JOIN t4 ON value = t4.a |
| |WHERE value = 1 |
| """.stripMargin) |
| val smj = findTopLevelSortMergeJoin(plan) |
| assert(smj.size == 3) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 3) |
| |
| // A possible resulting query plan: |
| // BroadcastHashJoin |
| // +- BroadcastExchange |
| // +- LocalShuffleReader* |
| // +- ShuffleExchange |
| // +- BroadcastHashJoin |
| // +- BroadcastExchange |
| // +- LocalShuffleReader* |
| // +- ShuffleExchange |
| // +- LocalShuffleReader* |
| // +- ShuffleExchange |
| // +- BroadcastHashJoin |
| // +- Filter |
| // +- HashAggregate |
| // +- CoalescedShuffleReader |
| // +- ShuffleExchange |
| // +- BroadcastExchange |
| // +-LocalShuffleReader* |
| // +- ShuffleExchange |
| |
| // The shuffle added by Aggregate can't apply local read. |
| checkNumLocalShuffleReads(adaptivePlan, 1) |
| } |
| } |
| |
| test("Exchange reuse") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT value FROM testData join testData2 ON key = a " + |
| "join (SELECT value v from testData join testData3 ON key = a) on value = v") |
| val smj = findTopLevelSortMergeJoin(plan) |
| assert(smj.size == 3) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 2) |
| // There is still a SMJ, and its two shuffles can't apply local read. |
| checkNumLocalShuffleReads(adaptivePlan, 2) |
| // Even with local shuffle read, the query stage reuse can also work. |
| val ex = findReusedExchange(adaptivePlan) |
| assert(ex.size == 1) |
| } |
| } |
| |
| test("Exchange reuse with subqueries") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT a FROM testData join testData2 ON key = a " + |
| "where value = (SELECT max(a) from testData join testData2 ON key = a)") |
| val smj = findTopLevelSortMergeJoin(plan) |
| assert(smj.size == 1) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 1) |
| checkNumLocalShuffleReads(adaptivePlan) |
| // Even with local shuffle read, the query stage reuse can also work. |
| val ex = findReusedExchange(adaptivePlan) |
| assert(ex.size == 1) |
| } |
| } |
| |
| test("Exchange reuse across subqueries") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", |
| SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false") { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT a FROM testData join testData2 ON key = a " + |
| "where value >= (SELECT max(a) from testData join testData2 ON key = a) " + |
| "and a <= (SELECT max(a) from testData join testData2 ON key = a)") |
| val smj = findTopLevelSortMergeJoin(plan) |
| assert(smj.size == 1) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 1) |
| checkNumLocalShuffleReads(adaptivePlan) |
| // Even with local shuffle read, the query stage reuse can also work. |
| val ex = findReusedExchange(adaptivePlan) |
| assert(ex.nonEmpty) |
| val sub = findReusedSubquery(adaptivePlan) |
| assert(sub.isEmpty) |
| } |
| } |
| |
| test("Subquery reuse") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT a FROM testData join testData2 ON key = a " + |
| "where value >= (SELECT max(a) from testData join testData2 ON key = a) " + |
| "and a <= (SELECT max(a) from testData join testData2 ON key = a)") |
| val smj = findTopLevelSortMergeJoin(plan) |
| assert(smj.size == 1) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 1) |
| checkNumLocalShuffleReads(adaptivePlan) |
| // Even with local shuffle read, the query stage reuse can also work. |
| val ex = findReusedExchange(adaptivePlan) |
| assert(ex.isEmpty) |
| val sub = findReusedSubquery(adaptivePlan) |
| assert(sub.nonEmpty) |
| } |
| } |
| |
| test("Broadcast exchange reuse across subqueries") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000", |
| SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false") { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT a FROM testData join testData2 ON key = a " + |
| "where value >= (" + |
| "SELECT /*+ broadcast(testData2) */ max(key) from testData join testData2 ON key = a) " + |
| "and a <= (" + |
| "SELECT /*+ broadcast(testData2) */ max(value) from testData join testData2 ON key = a)") |
| val smj = findTopLevelSortMergeJoin(plan) |
| assert(smj.size == 1) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 1) |
| checkNumLocalShuffleReads(adaptivePlan) |
| // Even with local shuffle read, the query stage reuse can also work. |
| val ex = findReusedExchange(adaptivePlan) |
| assert(ex.nonEmpty) |
| assert(ex.head.child.isInstanceOf[BroadcastExchangeExec]) |
| val sub = findReusedSubquery(adaptivePlan) |
| assert(sub.isEmpty) |
| } |
| } |
| |
| test("Union/Except/Intersect queries") { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| runAdaptiveAndVerifyResult( |
| """ |
| |SELECT * FROM testData |
| |EXCEPT |
| |SELECT * FROM testData2 |
| |UNION ALL |
| |SELECT * FROM testData |
| |INTERSECT ALL |
| |SELECT * FROM testData2 |
| """.stripMargin) |
| } |
| } |
| |
| test("Subquery de-correlation in Union queries") { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| withTempView("a", "b") { |
| Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("a") |
| Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("b") |
| |
| runAdaptiveAndVerifyResult( |
| """ |
| |SELECT id,num,source FROM ( |
| | SELECT id, num, 'a' as source FROM a |
| | UNION ALL |
| | SELECT id, num, 'b' as source FROM b |
| |) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2) |
| """.stripMargin) |
| } |
| } |
| } |
| |
| test("Avoid plan change if cost is greater") { |
| val origPlan = sql("SELECT * FROM testData " + |
| "join testData2 t2 ON key = t2.a " + |
| "join testData2 t3 on t2.a = t3.a where t2.b = 1").queryExecution.executedPlan |
| |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", |
| SQLConf.BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT.key -> "0") { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT * FROM testData " + |
| "join testData2 t2 ON key = t2.a " + |
| "join testData2 t3 on t2.a = t3.a where t2.b = 1") |
| val smj = findTopLevelSortMergeJoin(plan) |
| assert(smj.size == 2) |
| val smj2 = findTopLevelSortMergeJoin(adaptivePlan) |
| assert(smj2.size == 2, origPlan.toString) |
| } |
| } |
| |
| test("Change merge join to broadcast join without local shuffle read") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.LOCAL_SHUFFLE_READER_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "40") { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| """ |
| |SELECT * FROM testData t1 join testData2 t2 |
| |ON t1.key = t2.a join testData3 t3 on t2.a = t3.a |
| |where t1.value = 1 |
| """.stripMargin |
| ) |
| val smj = findTopLevelSortMergeJoin(plan) |
| assert(smj.size == 2) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 1) |
| // There is still a SMJ, and its two shuffles can't apply local read. |
| checkNumLocalShuffleReads(adaptivePlan, 2) |
| } |
| } |
| |
| test("Avoid changing merge join to broadcast join if too many empty partitions on build plan") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key -> "0.5") { |
| // `testData` is small enough to be broadcast but has empty partition ratio over the config. |
| withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT * FROM testData join testData2 ON key = a where value = '1'") |
| val smj = findTopLevelSortMergeJoin(plan) |
| assert(smj.size == 1) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.isEmpty) |
| } |
| // It is still possible to broadcast `testData2`. |
| withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT * FROM testData join testData2 ON key = a where value = '1'") |
| val smj = findTopLevelSortMergeJoin(plan) |
| assert(smj.size == 1) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 1) |
| assert(bhj.head.buildSide == BuildRight) |
| } |
| } |
| } |
| test("SPARK-37753: Allow changing outer join to broadcast join even if too many empty" + |
| " partitions on broadcast side") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key -> "0.5") { |
| // `testData` is small enough to be broadcast but has empty partition ratio over the config. |
| withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT * FROM (select * from testData where value = '1') td" + |
| " right outer join testData2 ON key = a") |
| val smj = findTopLevelSortMergeJoin(plan) |
| assert(smj.size == 1) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 1) |
| } |
| } |
| } |
| |
| test("SPARK-37753: Inhibit broadcast in left outer join when there are many empty" + |
| " partitions on outer/left side") { |
| // if the right side is completed first and the left side is still being executed, |
| // the right side does not know whether there are many empty partitions on the left side, |
| // so there is no demote, and then the right side is broadcast in the planning stage. |
| // so retry several times here to avoid unit test failure. |
| eventually(timeout(15.seconds), interval(500.milliseconds)) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key -> "0.5") { |
| // `testData` is small enough to be broadcast but has empty partition ratio over the config. |
| withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "200") { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT * FROM (select * from testData where value = '1') td" + |
| " left outer join testData2 ON key = a") |
| val smj = findTopLevelSortMergeJoin(plan) |
| assert(smj.size == 1) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.isEmpty) |
| } |
| } |
| } |
| } |
| |
| test("SPARK-29906: AQE should not introduce extra shuffle for outermost limit") { |
| var numStages = 0 |
| val listener = new SparkListener { |
| override def onJobStart(jobStart: SparkListenerJobStart): Unit = { |
| numStages = jobStart.stageInfos.length |
| } |
| } |
| try { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| spark.sparkContext.addSparkListener(listener) |
| spark.range(0, 100, 1, numPartitions = 10).take(1) |
| spark.sparkContext.listenerBus.waitUntilEmpty() |
| // Should be only one stage since there is no shuffle. |
| assert(numStages == 1) |
| } |
| } finally { |
| spark.sparkContext.removeSparkListener(listener) |
| } |
| } |
| |
| test("SPARK-30524: Do not optimize skew join if introduce additional shuffle") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", |
| SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100", |
| SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100") { |
| withTempView("skewData1", "skewData2") { |
| spark |
| .range(0, 1000, 1, 10) |
| .selectExpr("id % 3 as key1", "id as value1") |
| .createOrReplaceTempView("skewData1") |
| spark |
| .range(0, 1000, 1, 10) |
| .selectExpr("id % 1 as key2", "id as value2") |
| .createOrReplaceTempView("skewData2") |
| |
| def checkSkewJoin(query: String, optimizeSkewJoin: Boolean): Unit = { |
| val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(query) |
| val innerSmj = findTopLevelSortMergeJoin(innerAdaptivePlan) |
| assert(innerSmj.size == 1 && innerSmj.head.isSkewJoin == optimizeSkewJoin) |
| } |
| |
| checkSkewJoin( |
| "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2", true) |
| // Additional shuffle introduced, so disable the "OptimizeSkewedJoin" optimization |
| checkSkewJoin( |
| "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 GROUP BY key1", false) |
| } |
| } |
| } |
| |
| test("SPARK-29544: adaptive skew join with different join types") { |
| Seq("SHUFFLE_MERGE", "SHUFFLE_HASH").foreach { joinHint => |
| def getJoinNode(plan: SparkPlan): Seq[ShuffledJoin] = if (joinHint == "SHUFFLE_MERGE") { |
| findTopLevelSortMergeJoin(plan) |
| } else { |
| findTopLevelShuffledHashJoin(plan) |
| } |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", |
| SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", |
| SQLConf.SHUFFLE_PARTITIONS.key -> "100", |
| SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800", |
| SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") { |
| withTempView("skewData1", "skewData2") { |
| spark |
| .range(0, 1000, 1, 10) |
| .select( |
| when($"id" < 250, 249) |
| .when($"id" >= 750, 1000) |
| .otherwise($"id").as("key1"), |
| $"id" as "value1") |
| .createOrReplaceTempView("skewData1") |
| spark |
| .range(0, 1000, 1, 10) |
| .select( |
| when($"id" < 250, 249) |
| .otherwise($"id").as("key2"), |
| $"id" as "value2") |
| .createOrReplaceTempView("skewData2") |
| |
| def checkSkewJoin( |
| joins: Seq[ShuffledJoin], |
| leftSkewNum: Int, |
| rightSkewNum: Int): Unit = { |
| assert(joins.size == 1 && joins.head.isSkewJoin) |
| assert(joins.head.left.collect { |
| case r: AQEShuffleReadExec => r |
| }.head.partitionSpecs.collect { |
| case p: PartialReducerPartitionSpec => p.reducerIndex |
| }.distinct.length == leftSkewNum) |
| assert(joins.head.right.collect { |
| case r: AQEShuffleReadExec => r |
| }.head.partitionSpecs.collect { |
| case p: PartialReducerPartitionSpec => p.reducerIndex |
| }.distinct.length == rightSkewNum) |
| } |
| |
| // skewed inner join optimization |
| val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult( |
| s"SELECT /*+ $joinHint(skewData1) */ * FROM skewData1 " + |
| "JOIN skewData2 ON key1 = key2") |
| val inner = getJoinNode(innerAdaptivePlan) |
| checkSkewJoin(inner, 2, 1) |
| |
| // skewed left outer join optimization |
| val (_, leftAdaptivePlan) = runAdaptiveAndVerifyResult( |
| s"SELECT /*+ $joinHint(skewData2) */ * FROM skewData1 " + |
| "LEFT OUTER JOIN skewData2 ON key1 = key2") |
| val leftJoin = getJoinNode(leftAdaptivePlan) |
| checkSkewJoin(leftJoin, 2, 0) |
| |
| // skewed right outer join optimization |
| val (_, rightAdaptivePlan) = runAdaptiveAndVerifyResult( |
| s"SELECT /*+ $joinHint(skewData1) */ * FROM skewData1 " + |
| "RIGHT OUTER JOIN skewData2 ON key1 = key2") |
| val rightJoin = getJoinNode(rightAdaptivePlan) |
| checkSkewJoin(rightJoin, 0, 1) |
| } |
| } |
| } |
| } |
| |
| test("SPARK-30291: AQE should catch the exceptions when doing materialize") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| withTable("bucketed_table") { |
| val df1 = |
| (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1") |
| df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") |
| val warehouseFilePath = new URI(spark.sessionState.conf.warehousePath).getPath |
| val tableDir = new File(warehouseFilePath, "bucketed_table") |
| Utils.deleteRecursively(tableDir) |
| df1.write.parquet(tableDir.getAbsolutePath) |
| |
| val aggregated = spark.table("bucketed_table").groupBy("i").count() |
| val error = intercept[SparkException] { |
| aggregated.count() |
| } |
| assert(error.getErrorClass === "INVALID_BUCKET_FILE") |
| assert(error.getMessage contains "Invalid bucket file") |
| } |
| } |
| } |
| |
| test("SPARK-47148: AQE should avoid to materialize ShuffleQueryStage on the cancellation") { |
| def createJoinedDF(): DataFrame = { |
| val df = spark.range(5).toDF("col") |
| val df2 = spark.range(10).toDF("col").coalesce(2) |
| val df3 = spark.range(15).toDF("col").filter(Symbol("col") >= 2) |
| df.join(df2, Seq("col")).join(df3, Seq("col")) |
| } |
| |
| try { |
| spark.experimental.extraStrategies = TestProblematicCoalesceStrategy :: Nil |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { |
| val joinedDF = createJoinedDF() |
| |
| val error = intercept[SparkException] { |
| joinedDF.collect() |
| } |
| assert(error.getMessage() contains "ProblematicCoalesce execution is failed") |
| |
| val adaptivePlan = joinedDF.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec] |
| |
| // All QueryStages should be based on ShuffleQueryStageExec |
| val shuffleQueryStageExecs = collect(adaptivePlan) { |
| case sqse: ShuffleQueryStageExec => sqse |
| } |
| assert(shuffleQueryStageExecs.length == 3, s"Physical Plan should include " + |
| s"3 ShuffleQueryStages. Physical Plan: $adaptivePlan") |
| shuffleQueryStageExecs.foreach(sqse => assert(sqse.name.contains("ShuffleQueryStageExec-"))) |
| // First ShuffleQueryStage is materialized so it needs to be canceled. |
| assert(shuffleQueryStageExecs(0).shuffle.isMaterializationStarted(), |
| "Materialization should be started.") |
| // Second ShuffleQueryStage materialization is failed so |
| // it is excluded from the cancellation due to earlyFailedStage. |
| assert(shuffleQueryStageExecs(1).shuffle.isMaterializationStarted(), |
| "Materialization should be started but it is failed.") |
| // Last ShuffleQueryStage is not materialized yet so it does not require |
| // to be canceled and it is just skipped from the cancellation. |
| assert(!shuffleQueryStageExecs(2).shuffle.isMaterializationStarted(), |
| "Materialization should not be started.") |
| } |
| } finally { |
| spark.experimental.extraStrategies = Nil |
| } |
| } |
| |
| test("SPARK-47148: Check if BroadcastQueryStage materialization is started") { |
| def createJoinedDF(): DataFrame = { |
| spark.range(10).toDF("col1").createTempView("t1") |
| spark.range(5).coalesce(2).toDF("col2").createTempView("t2") |
| spark.range(15).toDF("col3").filter(Symbol("col3") >= 2).createTempView("t3") |
| sql("SELECT * FROM (SELECT /*+ BROADCAST(t2) */ * FROM t1 " + |
| "INNER JOIN t2 ON t1.col1 = t2.col2) t JOIN t3 ON t.col1 = t3.col3;") |
| } |
| withTempView("t1", "t2", "t3") { |
| try { |
| spark.experimental.extraStrategies = TestProblematicCoalesceStrategy :: Nil |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| val joinedDF = createJoinedDF() |
| |
| val error = intercept[SparkException] { |
| joinedDF.collect() |
| } |
| assert(error.getMessage() contains "ProblematicCoalesce execution is failed") |
| |
| val adaptivePlan = |
| joinedDF.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec] |
| |
| // All QueryStages should be based on BroadcastQueryStageExec |
| val broadcastQueryStageExecs = collect(adaptivePlan) { |
| case bqse: BroadcastQueryStageExec => bqse |
| } |
| assert(broadcastQueryStageExecs.length == 2, adaptivePlan) |
| broadcastQueryStageExecs.foreach { bqse => |
| assert(bqse.name.contains("BroadcastQueryStageExec-")) |
| // Both BroadcastQueryStages are materialized at the beginning. |
| assert(bqse.broadcast.isMaterializationStarted(), |
| s"${bqse.name}' s materialization should be started.") |
| } |
| } |
| } finally { |
| spark.experimental.extraStrategies = Nil |
| } |
| } |
| } |
| |
| test("SPARK-30403: AQE should handle InSubquery") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.DECORRELATE_PREDICATE_SUBQUERIES_IN_JOIN_CONDITION.key -> "false") { |
| runAdaptiveAndVerifyResult("SELECT * FROM testData LEFT OUTER join testData2" + |
| " ON key = a AND key NOT IN (select a from testData3) where value = '1'" |
| ) |
| } |
| } |
| |
| test("force apply AQE") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { |
| val plan = sql("SELECT * FROM testData").queryExecution.executedPlan |
| assert(plan.isInstanceOf[AdaptiveSparkPlanExec]) |
| } |
| } |
| |
| test("SPARK-30719: do not log warning if intentionally skip AQE") { |
| val testAppender = new LogAppender("aqe logging warning test when skip") |
| withLogAppender(testAppender) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| val plan = sql("SELECT * FROM testData").queryExecution.executedPlan |
| assert(!plan.isInstanceOf[AdaptiveSparkPlanExec]) |
| } |
| } |
| assert(!testAppender.loggingEvents |
| .exists(msg => msg.getMessage.getFormattedMessage.contains( |
| s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} is" + |
| s" enabled but is not supported for"))) |
| } |
| |
| test("test log level") { |
| def verifyLog(expectedLevel: Level): Unit = { |
| val logAppender = new LogAppender("adaptive execution") |
| logAppender.setThreshold(expectedLevel) |
| withLogAppender( |
| logAppender, |
| loggerNames = Seq(AdaptiveSparkPlanExec.getClass.getName.dropRight(1)), |
| level = Some(Level.TRACE)) { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| sql("SELECT * FROM testData join testData2 ON key = a where value = '1'").collect() |
| } |
| } |
| Seq("Plan changed", "Final plan").foreach { msg => |
| assert( |
| logAppender.loggingEvents.exists { event => |
| event.getMessage.getFormattedMessage.contains(msg) && event.getLevel == expectedLevel |
| }) |
| } |
| } |
| |
| // Verify default log level |
| verifyLog(Level.DEBUG) |
| |
| // Verify custom log level |
| val levels = Seq( |
| "TRACE" -> Level.TRACE, |
| "trace" -> Level.TRACE, |
| "DEBUG" -> Level.DEBUG, |
| "debug" -> Level.DEBUG, |
| "INFO" -> Level.INFO, |
| "info" -> Level.INFO, |
| "WARN" -> Level.WARN, |
| "warn" -> Level.WARN, |
| "ERROR" -> Level.ERROR, |
| "error" -> Level.ERROR, |
| "deBUG" -> Level.DEBUG) |
| |
| levels.foreach { level => |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_LOG_LEVEL.key -> level._1) { |
| verifyLog(level._2) |
| } |
| } |
| } |
| |
| test("tree string output") { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| val df = sql("SELECT * FROM testData join testData2 ON key = a where value = '1'") |
| val planBefore = df.queryExecution.executedPlan |
| assert(!planBefore.toString.contains("== Current Plan ==")) |
| assert(!planBefore.toString.contains("== Initial Plan ==")) |
| df.collect() |
| val planAfter = df.queryExecution.executedPlan |
| assert(planAfter.toString.contains("== Final Plan ==")) |
| assert(planAfter.toString.contains("== Initial Plan ==")) |
| } |
| } |
| |
| test("SPARK-31384: avoid NPE in OptimizeSkewedJoin when there's 0 partition plan") { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { |
| withTempView("t2") { |
| // create DataFrame with 0 partition |
| spark.createDataFrame(sparkContext.emptyRDD[Row], new StructType().add("b", IntegerType)) |
| .createOrReplaceTempView("t2") |
| // should run successfully without NPE |
| runAdaptiveAndVerifyResult("SELECT * FROM testData2 t1 left semi join t2 ON t1.a=t2.b") |
| } |
| } |
| } |
| |
| test("SPARK-34682: AQEShuffleReadExec operating on canonicalized plan") { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| val (_, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT key FROM testData GROUP BY key") |
| val reads = collect(adaptivePlan) { |
| case r: AQEShuffleReadExec => r |
| } |
| assert(reads.length == 1) |
| val read = reads.head |
| val c = read.canonicalized.asInstanceOf[AQEShuffleReadExec] |
| // we can't just call execute() because that has separate checks for canonicalized plans |
| checkError( |
| exception = intercept[SparkException] { |
| val doExecute = PrivateMethod[Unit](Symbol("doExecute")) |
| c.invokePrivate(doExecute()) |
| }, |
| errorClass = "INTERNAL_ERROR", |
| parameters = Map("message" -> "operating on canonicalized plan")) |
| } |
| } |
| |
| test("metrics of the shuffle read") { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| val (_, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT key FROM testData GROUP BY key") |
| val reads = collect(adaptivePlan) { |
| case r: AQEShuffleReadExec => r |
| } |
| assert(reads.length == 1) |
| val read = reads.head |
| assert(!read.isLocalRead) |
| assert(!read.hasSkewedPartition) |
| assert(read.hasCoalescedPartition) |
| assert(read.metrics.keys.toSeq.sorted == Seq( |
| "numCoalescedPartitions", "numPartitions", "partitionDataSize")) |
| assert(read.metrics("numCoalescedPartitions").value == 1) |
| assert(read.metrics("numPartitions").value == read.partitionSpecs.length) |
| assert(read.metrics("partitionDataSize").value > 0) |
| |
| withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| val (_, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT * FROM testData join testData2 ON key = a where value = '1'") |
| val join = collect(adaptivePlan) { |
| case j: BroadcastHashJoinExec => j |
| }.head |
| assert(join.buildSide == BuildLeft) |
| |
| val reads = collect(join.right) { |
| case r: AQEShuffleReadExec => r |
| } |
| assert(reads.length == 1) |
| val read = reads.head |
| assert(read.isLocalRead) |
| assert(read.metrics.keys.toSeq == Seq("numPartitions")) |
| assert(read.metrics("numPartitions").value == read.partitionSpecs.length) |
| } |
| |
| withSQLConf( |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", |
| SQLConf.SHUFFLE_PARTITIONS.key -> "100", |
| SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800", |
| SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "1000") { |
| withTempView("skewData1", "skewData2") { |
| spark |
| .range(0, 1000, 1, 10) |
| .select( |
| when($"id" < 250, 249) |
| .when($"id" >= 750, 1000) |
| .otherwise($"id").as("key1"), |
| $"id" as "value1") |
| .createOrReplaceTempView("skewData1") |
| spark |
| .range(0, 1000, 1, 10) |
| .select( |
| when($"id" < 250, 249) |
| .otherwise($"id").as("key2"), |
| $"id" as "value2") |
| .createOrReplaceTempView("skewData2") |
| val (_, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT * FROM skewData1 join skewData2 ON key1 = key2") |
| val reads = collect(adaptivePlan) { |
| case r: AQEShuffleReadExec => r |
| } |
| reads.foreach { read => |
| assert(!read.isLocalRead) |
| assert(read.hasCoalescedPartition) |
| assert(read.hasSkewedPartition) |
| assert(read.metrics.contains("numSkewedPartitions")) |
| } |
| assert(reads(0).metrics("numSkewedPartitions").value == 2) |
| assert(reads(0).metrics("numSkewedSplits").value == 11) |
| assert(reads(1).metrics("numSkewedPartitions").value == 1) |
| assert(reads(1).metrics("numSkewedSplits").value == 9) |
| } |
| } |
| } |
| } |
| |
| test("control a plan explain mode in listeners via SQLConf") { |
| |
| def checkPlanDescription(mode: String, expected: Seq[String]): Unit = { |
| var checkDone = false |
| val listener = new SparkListener { |
| override def onOtherEvent(event: SparkListenerEvent): Unit = { |
| event match { |
| case SparkListenerSQLAdaptiveExecutionUpdate(_, planDescription, _) => |
| assert(expected.forall(planDescription.contains)) |
| checkDone = true |
| case _ => // ignore other events |
| } |
| } |
| } |
| spark.sparkContext.addSparkListener(listener) |
| withSQLConf(SQLConf.UI_EXPLAIN_MODE.key -> mode, |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| val dfAdaptive = sql("SELECT * FROM testData JOIN testData2 ON key = a WHERE value = '1'") |
| try { |
| checkAnswer(dfAdaptive, Row(1, "1", 1, 1) :: Row(1, "1", 1, 2) :: Nil) |
| spark.sparkContext.listenerBus.waitUntilEmpty() |
| assert(checkDone) |
| } finally { |
| spark.sparkContext.removeSparkListener(listener) |
| } |
| } |
| } |
| |
| Seq(("simple", Seq("== Physical Plan ==")), |
| ("extended", Seq("== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", |
| "== Optimized Logical Plan ==", "== Physical Plan ==")), |
| ("codegen", Seq("WholeStageCodegen subtrees")), |
| ("cost", Seq("== Optimized Logical Plan ==", "Statistics(sizeInBytes")), |
| ("formatted", Seq("== Physical Plan ==", "Output", "Arguments"))).foreach { |
| case (mode, expected) => |
| checkPlanDescription(mode, expected) |
| } |
| } |
| |
| test("SPARK-30953: InsertAdaptiveSparkPlan should apply AQE on child plan of v2 write commands") { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { |
| var plan: SparkPlan = null |
| val listener = new QueryExecutionListener { |
| override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { |
| plan = qe.executedPlan |
| } |
| override def onFailure( |
| funcName: String, qe: QueryExecution, exception: Exception): Unit = {} |
| } |
| spark.listenerManager.register(listener) |
| withTable("t1") { |
| val format = classOf[NoopDataSource].getName |
| Seq((0, 1)).toDF("x", "y").write.format(format).mode("overwrite").save() |
| |
| sparkContext.listenerBus.waitUntilEmpty() |
| assert(plan.isInstanceOf[V2TableWriteExec]) |
| assert(plan.asInstanceOf[V2TableWriteExec].child.isInstanceOf[AdaptiveSparkPlanExec]) |
| |
| spark.listenerManager.unregister(listener) |
| } |
| } |
| } |
| |
| test("SPARK-37287: apply AQE on child plan of a v1 write command") { |
| Seq(true, false).foreach { enabled => |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true", |
| SQLConf.PLANNED_WRITE_ENABLED.key -> enabled.toString) { |
| withTable("t1") { |
| var checkDone = false |
| val listener = new SparkListener { |
| override def onOtherEvent(event: SparkListenerEvent): Unit = { |
| event match { |
| case SparkListenerSQLAdaptiveExecutionUpdate(_, _, planInfo) => |
| if (enabled) { |
| assert(planInfo.nodeName == "AdaptiveSparkPlan") |
| assert(planInfo.children.size == 1) |
| assert(planInfo.children.head.nodeName == |
| "Execute InsertIntoHadoopFsRelationCommand") |
| } else { |
| assert(planInfo.nodeName == "Execute InsertIntoHadoopFsRelationCommand") |
| } |
| checkDone = true |
| case _ => // ignore other events |
| } |
| } |
| } |
| spark.sparkContext.addSparkListener(listener) |
| try { |
| sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect() |
| spark.sparkContext.listenerBus.waitUntilEmpty() |
| assert(checkDone) |
| } finally { |
| spark.sparkContext.removeSparkListener(listener) |
| } |
| } |
| } |
| } |
| } |
| |
| test("AQE should set active session during execution") { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| val df = spark.range(10).select(sum($"id")) |
| assert(df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec]) |
| SparkSession.setActiveSession(null) |
| checkAnswer(df, Seq(Row(45))) |
| SparkSession.setActiveSession(spark) // recover the active session. |
| } |
| } |
| |
| test("No deadlock in UI update") { |
| object TestStrategy extends Strategy { |
| def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { |
| case _: Aggregate => |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { |
| spark.range(5).rdd |
| } |
| Nil |
| case _ => Nil |
| } |
| } |
| |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { |
| try { |
| spark.experimental.extraStrategies = TestStrategy :: Nil |
| val df = spark.range(10).groupBy($"id").count() |
| df.collect() |
| } finally { |
| spark.experimental.extraStrategies = Nil |
| } |
| } |
| } |
| |
| test("SPARK-31658: SQL UI should show write commands") { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { |
| withTable("t1") { |
| var commands: Seq[SparkPlanInfo] = Seq.empty |
| val listener = new SparkListener { |
| override def onOtherEvent(event: SparkListenerEvent): Unit = { |
| event match { |
| case start: SparkListenerSQLExecutionStart => |
| commands = commands ++ Seq(start.sparkPlanInfo) |
| case _ => // ignore other events |
| } |
| } |
| } |
| spark.sparkContext.addSparkListener(listener) |
| try { |
| sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect() |
| spark.sparkContext.listenerBus.waitUntilEmpty() |
| assert(commands.size == 3) |
| assert(commands.head.nodeName == "Execute CreateDataSourceTableAsSelectCommand") |
| assert(commands(1).nodeName == "AdaptiveSparkPlan") |
| assert(commands(1).children.size == 1) |
| assert(commands(1).children.head.nodeName == "Execute InsertIntoHadoopFsRelationCommand") |
| assert(commands(2).nodeName == "CommandResult") |
| } finally { |
| spark.sparkContext.removeSparkListener(listener) |
| } |
| } |
| } |
| } |
| |
| test("SPARK-31220, SPARK-32056: repartition by expression with AQE") { |
| Seq(true, false).foreach { enableAQE => |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString, |
| SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", |
| SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10", |
| SQLConf.SHUFFLE_PARTITIONS.key -> "10") { |
| |
| val df1 = spark.range(10).repartition($"id") |
| val df2 = spark.range(10).repartition($"id" + 1) |
| |
| val partitionsNum1 = df1.rdd.collectPartitions().length |
| val partitionsNum2 = df2.rdd.collectPartitions().length |
| |
| if (enableAQE) { |
| assert(partitionsNum1 < 10) |
| assert(partitionsNum2 < 10) |
| |
| checkInitialPartitionNum(df1, 10) |
| checkInitialPartitionNum(df2, 10) |
| } else { |
| assert(partitionsNum1 === 10) |
| assert(partitionsNum2 === 10) |
| } |
| |
| |
| // Don't coalesce partitions if the number of partitions is specified. |
| val df3 = spark.range(10).repartition(10, $"id") |
| val df4 = spark.range(10).repartition(10) |
| assert(df3.rdd.collectPartitions().length == 10) |
| assert(df4.rdd.collectPartitions().length == 10) |
| } |
| } |
| } |
| |
| test("SPARK-31220, SPARK-32056: repartition by range with AQE") { |
| Seq(true, false).foreach { enableAQE => |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString, |
| SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", |
| SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10", |
| SQLConf.SHUFFLE_PARTITIONS.key -> "10") { |
| |
| val df1 = spark.range(10).toDF().repartitionByRange($"id".asc) |
| val df2 = spark.range(10).toDF().repartitionByRange(($"id" + 1).asc) |
| |
| val partitionsNum1 = df1.rdd.collectPartitions().length |
| val partitionsNum2 = df2.rdd.collectPartitions().length |
| |
| if (enableAQE) { |
| assert(partitionsNum1 < 10) |
| assert(partitionsNum2 < 10) |
| |
| checkInitialPartitionNum(df1, 10) |
| checkInitialPartitionNum(df2, 10) |
| } else { |
| assert(partitionsNum1 === 10) |
| assert(partitionsNum2 === 10) |
| } |
| |
| // Don't coalesce partitions if the number of partitions is specified. |
| val df3 = spark.range(10).repartitionByRange(10, $"id".asc) |
| assert(df3.rdd.collectPartitions().length == 10) |
| } |
| } |
| } |
| |
| test("SPARK-31220, SPARK-32056: repartition using sql and hint with AQE") { |
| Seq(true, false).foreach { enableAQE => |
| withTempView("test") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString, |
| SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", |
| SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10", |
| SQLConf.SHUFFLE_PARTITIONS.key -> "10") { |
| |
| spark.range(10).toDF().createTempView("test") |
| |
| val df1 = spark.sql("SELECT /*+ REPARTITION(id) */ * from test") |
| val df2 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(id) */ * from test") |
| val df3 = spark.sql("SELECT * from test DISTRIBUTE BY id") |
| val df4 = spark.sql("SELECT * from test CLUSTER BY id") |
| |
| val partitionsNum1 = df1.rdd.collectPartitions().length |
| val partitionsNum2 = df2.rdd.collectPartitions().length |
| val partitionsNum3 = df3.rdd.collectPartitions().length |
| val partitionsNum4 = df4.rdd.collectPartitions().length |
| |
| if (enableAQE) { |
| assert(partitionsNum1 < 10) |
| assert(partitionsNum2 < 10) |
| assert(partitionsNum3 < 10) |
| assert(partitionsNum4 < 10) |
| |
| checkInitialPartitionNum(df1, 10) |
| checkInitialPartitionNum(df2, 10) |
| checkInitialPartitionNum(df3, 10) |
| checkInitialPartitionNum(df4, 10) |
| } else { |
| assert(partitionsNum1 === 10) |
| assert(partitionsNum2 === 10) |
| assert(partitionsNum3 === 10) |
| assert(partitionsNum4 === 10) |
| } |
| |
| // Don't coalesce partitions if the number of partitions is specified. |
| val df5 = spark.sql("SELECT /*+ REPARTITION(10, id) */ * from test") |
| val df6 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(10, id) */ * from test") |
| assert(df5.rdd.collectPartitions().length == 10) |
| assert(df6.rdd.collectPartitions().length == 10) |
| } |
| } |
| } |
| } |
| |
| test("SPARK-32573: Eliminate NAAJ when BuildSide is HashedRelationWithAllNullKeys") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString) { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT * FROM testData2 t1 WHERE t1.b NOT IN (SELECT b FROM testData3)") |
| val bhj = findTopLevelBroadcastHashJoin(plan) |
| assert(bhj.size == 1) |
| val join = findTopLevelBaseJoin(adaptivePlan) |
| assert(join.isEmpty) |
| checkNumLocalShuffleReads(adaptivePlan) |
| } |
| } |
| |
| test("SPARK-32717: AQEOptimizer should respect excludedRules configuration") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString, |
| // This test is a copy of test(SPARK-32573), in order to test the configuration |
| // `spark.sql.adaptive.optimizer.excludedRules` works as expect. |
| SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT * FROM testData2 t1 WHERE t1.b NOT IN (SELECT b FROM testData3)") |
| val bhj = findTopLevelBroadcastHashJoin(plan) |
| assert(bhj.size == 1) |
| val join = findTopLevelBaseJoin(adaptivePlan) |
| // this is different compares to test(SPARK-32573) due to the rule |
| // `EliminateUnnecessaryJoin` has been excluded. |
| assert(join.nonEmpty) |
| checkNumLocalShuffleReads(adaptivePlan) |
| } |
| } |
| |
| test("SPARK-32649: Eliminate inner and semi join to empty relation") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| Seq( |
| // inner join (small table at right side) |
| "SELECT * FROM testData t1 join testData3 t2 ON t1.key = t2.a WHERE t2.b = 1", |
| // inner join (small table at left side) |
| "SELECT * FROM testData3 t1 join testData t2 ON t1.a = t2.key WHERE t1.b = 1", |
| // left semi join |
| "SELECT * FROM testData t1 left semi join testData3 t2 ON t1.key = t2.a AND t2.b = 1" |
| ).foreach(query => { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query) |
| val smj = findTopLevelSortMergeJoin(plan) |
| assert(smj.size == 1) |
| val join = findTopLevelBaseJoin(adaptivePlan) |
| assert(join.isEmpty) |
| checkNumLocalShuffleReads(adaptivePlan) |
| }) |
| } |
| } |
| |
| test("SPARK-34533: Eliminate left anti join to empty relation") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| Seq( |
| // broadcast non-empty right side |
| ("SELECT /*+ broadcast(testData3) */ * FROM testData LEFT ANTI JOIN testData3", true), |
| // broadcast empty right side |
| ("SELECT /*+ broadcast(emptyTestData) */ * FROM testData LEFT ANTI JOIN emptyTestData", |
| true), |
| // broadcast left side |
| ("SELECT /*+ broadcast(testData) */ * FROM testData LEFT ANTI JOIN testData3", false) |
| ).foreach { case (query, isEliminated) => |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query) |
| assert(findTopLevelBaseJoin(plan).size == 1) |
| assert(findTopLevelBaseJoin(adaptivePlan).isEmpty == isEliminated) |
| } |
| } |
| } |
| |
| test("SPARK-34781: Eliminate left semi/anti join to its left side") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| Seq( |
| // left semi join and non-empty right side |
| ("SELECT * FROM testData LEFT SEMI JOIN testData3", true), |
| // left semi join, non-empty right side and non-empty join condition |
| ("SELECT * FROM testData t1 LEFT SEMI JOIN testData3 t2 ON t1.key = t2.a", false), |
| // left anti join and empty right side |
| ("SELECT * FROM testData LEFT ANTI JOIN emptyTestData", true), |
| // left anti join, empty right side and non-empty join condition |
| ("SELECT * FROM testData t1 LEFT ANTI JOIN emptyTestData t2 ON t1.key = t2.key", true) |
| ).foreach { case (query, isEliminated) => |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query) |
| assert(findTopLevelBaseJoin(plan).size == 1) |
| assert(findTopLevelBaseJoin(adaptivePlan).isEmpty == isEliminated) |
| } |
| } |
| } |
| |
| test("SPARK-35455: Unify empty relation optimization between normal and AQE optimizer " + |
| "- single join") { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { |
| Seq( |
| // left semi join and empty left side |
| ("SELECT * FROM (SELECT * FROM testData WHERE value = '0')t1 LEFT SEMI JOIN " + |
| "testData2 t2 ON t1.key = t2.a", true), |
| // left anti join and empty left side |
| ("SELECT * FROM (SELECT * FROM testData WHERE value = '0')t1 LEFT ANTI JOIN " + |
| "testData2 t2 ON t1.key = t2.a", true), |
| // left outer join and empty left side |
| ("SELECT * FROM (SELECT * FROM testData WHERE key = 0)t1 LEFT JOIN testData2 t2 ON " + |
| "t1.key = t2.a", true), |
| // left outer join and non-empty left side |
| ("SELECT * FROM testData t1 LEFT JOIN testData2 t2 ON " + |
| "t1.key = t2.a", false), |
| // right outer join and empty right side |
| ("SELECT * FROM testData t1 RIGHT JOIN (SELECT * FROM testData2 WHERE b = 0)t2 ON " + |
| "t1.key = t2.a", true), |
| // right outer join and non-empty right side |
| ("SELECT * FROM testData t1 RIGHT JOIN testData2 t2 ON " + |
| "t1.key = t2.a", false), |
| // full outer join and both side empty |
| ("SELECT * FROM (SELECT * FROM testData WHERE key = 0)t1 FULL JOIN " + |
| "(SELECT * FROM testData2 WHERE b = 0)t2 ON t1.key = t2.a", true), |
| // full outer join and left side empty right side non-empty |
| ("SELECT * FROM (SELECT * FROM testData WHERE key = 0)t1 FULL JOIN " + |
| "testData2 t2 ON t1.key = t2.a", true) |
| ).foreach { case (query, isEliminated) => |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query) |
| assert(findTopLevelBaseJoin(plan).size == 1) |
| assert(findTopLevelBaseJoin(adaptivePlan).isEmpty == isEliminated, adaptivePlan) |
| } |
| } |
| } |
| |
| test("SPARK-35455: Unify empty relation optimization between normal and AQE optimizer " + |
| "- multi join") { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { |
| Seq( |
| """ |
| |SELECT * FROM testData t1 |
| | JOIN (SELECT * FROM testData2 WHERE b = 0) t2 ON t1.key = t2.a |
| | LEFT JOIN testData2 t3 ON t1.key = t3.a |
| |""".stripMargin, |
| """ |
| |SELECT * FROM (SELECT * FROM testData WHERE key = 0) t1 |
| | LEFT ANTI JOIN testData2 t2 |
| | FULL JOIN (SELECT * FROM testData2 WHERE b = 0) t3 ON t1.key = t3.a |
| |""".stripMargin, |
| """ |
| |SELECT * FROM testData t1 |
| | LEFT SEMI JOIN (SELECT * FROM testData2 WHERE b = 0) |
| | RIGHT JOIN testData2 t3 on t1.key = t3.a |
| |""".stripMargin |
| ).foreach { query => |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query) |
| assert(findTopLevelBaseJoin(plan).size == 2) |
| assert(findTopLevelBaseJoin(adaptivePlan).isEmpty) |
| } |
| } |
| } |
| |
| test("SPARK-35585: Support propagate empty relation through project/filter") { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { |
| val (plan1, adaptivePlan1) = runAdaptiveAndVerifyResult( |
| "SELECT key FROM testData WHERE key = 0 ORDER BY key, value") |
| assert(findTopLevelSort(plan1).size == 1) |
| assert(stripAQEPlan(adaptivePlan1).isInstanceOf[LocalTableScanExec]) |
| |
| val (plan2, adaptivePlan2) = runAdaptiveAndVerifyResult( |
| "SELECT key FROM (SELECT * FROM testData WHERE value = 'no_match' ORDER BY key)" + |
| " WHERE key > rand()") |
| assert(findTopLevelSort(plan2).size == 1) |
| assert(stripAQEPlan(adaptivePlan2).isInstanceOf[LocalTableScanExec]) |
| } |
| } |
| |
| test("SPARK-35442: Support propagate empty relation through aggregate") { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| val (plan1, adaptivePlan1) = runAdaptiveAndVerifyResult( |
| "SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key") |
| assert(!plan1.isInstanceOf[LocalTableScanExec]) |
| assert(stripAQEPlan(adaptivePlan1).isInstanceOf[LocalTableScanExec]) |
| |
| val (plan2, adaptivePlan2) = runAdaptiveAndVerifyResult( |
| "SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key limit 1") |
| assert(!plan2.isInstanceOf[LocalTableScanExec]) |
| assert(stripAQEPlan(adaptivePlan2).isInstanceOf[LocalTableScanExec]) |
| |
| val (plan3, adaptivePlan3) = runAdaptiveAndVerifyResult( |
| "SELECT count(*) FROM testData WHERE value = 'no_match'") |
| assert(!plan3.isInstanceOf[LocalTableScanExec]) |
| assert(!stripAQEPlan(adaptivePlan3).isInstanceOf[LocalTableScanExec]) |
| } |
| } |
| |
| test("SPARK-35442: Support propagate empty relation through union") { |
| def checkNumUnion(plan: SparkPlan, numUnion: Int): Unit = { |
| assert( |
| collect(plan) { |
| case u: UnionExec => u |
| }.size == numUnion) |
| } |
| |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| val (plan1, adaptivePlan1) = runAdaptiveAndVerifyResult( |
| """ |
| |SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key |
| |UNION ALL |
| |SELECT key, 1 FROM testData |
| |""".stripMargin) |
| checkNumUnion(plan1, 1) |
| checkNumUnion(adaptivePlan1, 0) |
| assert(!stripAQEPlan(adaptivePlan1).isInstanceOf[LocalTableScanExec]) |
| |
| val (plan2, adaptivePlan2) = runAdaptiveAndVerifyResult( |
| """ |
| |SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key |
| |UNION ALL |
| |SELECT /*+ REPARTITION */ key, 1 FROM testData WHERE value = 'no_match' |
| |""".stripMargin) |
| checkNumUnion(plan2, 1) |
| checkNumUnion(adaptivePlan2, 0) |
| assert(stripAQEPlan(adaptivePlan2).isInstanceOf[LocalTableScanExec]) |
| } |
| } |
| |
| test("SPARK-32753: Only copy tags to node with no tags") { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| withTempView("v1") { |
| spark.range(10).union(spark.range(10)).createOrReplaceTempView("v1") |
| |
| val (_, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT id FROM v1 GROUP BY id DISTRIBUTE BY id") |
| assert(collect(adaptivePlan) { |
| case s: ShuffleExchangeExec => s |
| }.length == 1) |
| } |
| } |
| } |
| |
| test("Logging plan changes for AQE") { |
| val testAppender = new LogAppender("plan changes") |
| withLogAppender(testAppender) { |
| withSQLConf( |
| SQLConf.PLAN_CHANGE_LOG_LEVEL.key -> "INFO", |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| sql("SELECT * FROM testData JOIN testData2 ON key = a " + |
| "WHERE value = (SELECT max(a) FROM testData3)").collect() |
| } |
| Seq("=== Result of Batch AQE Preparations ===", |
| "=== Result of Batch AQE Post Stage Creation ===", |
| "=== Result of Batch AQE Replanning ===", |
| "=== Result of Batch AQE Query Stage Optimization ===").foreach { expectedMsg => |
| assert(testAppender.loggingEvents.exists( |
| _.getMessage.getFormattedMessage.contains(expectedMsg))) |
| } |
| } |
| } |
| |
| test("SPARK-32932: Do not use local shuffle read at final stage on write command") { |
| withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString, |
| SQLConf.SHUFFLE_PARTITIONS.key -> "5", |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| val data = for ( |
| i <- 1L to 10L; |
| j <- 1L to 3L |
| ) yield (i, j) |
| |
| val df = data.toDF("i", "j").repartition($"j") |
| var noLocalread: Boolean = false |
| val listener = new QueryExecutionListener { |
| override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { |
| stripAQEPlan(qe.executedPlan) match { |
| case plan @ (_: DataWritingCommandExec | _: V2TableWriteExec) => |
| noLocalread = collect(plan) { |
| case exec: AQEShuffleReadExec if exec.isLocalRead => exec |
| }.isEmpty |
| case _ => // ignore other events |
| } |
| } |
| override def onFailure(funcName: String, qe: QueryExecution, |
| exception: Exception): Unit = {} |
| } |
| spark.listenerManager.register(listener) |
| |
| withTable("t") { |
| df.write.partitionBy("j").saveAsTable("t") |
| sparkContext.listenerBus.waitUntilEmpty() |
| assert(noLocalread) |
| noLocalread = false |
| } |
| |
| // Test DataSource v2 |
| val format = classOf[NoopDataSource].getName |
| df.write.format(format).mode("overwrite").save() |
| sparkContext.listenerBus.waitUntilEmpty() |
| assert(noLocalread) |
| noLocalread = false |
| |
| spark.listenerManager.unregister(listener) |
| } |
| } |
| |
| test("SPARK-33494: Do not use local shuffle read for repartition") { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| val df = spark.table("testData").repartition($"key") |
| df.collect() |
| // local shuffle read breaks partitioning and shouldn't be used for repartition operation |
| // which is specified by users. |
| checkNumLocalShuffleReads(df.queryExecution.executedPlan, numShufflesWithoutLocalRead = 1) |
| } |
| } |
| |
| test("SPARK-33551: Do not use AQE shuffle read for repartition") { |
| def hasRepartitionShuffle(plan: SparkPlan): Boolean = { |
| find(plan) { |
| case s: ShuffleExchangeLike => |
| s.shuffleOrigin == REPARTITION_BY_COL || s.shuffleOrigin == REPARTITION_BY_NUM |
| case _ => false |
| }.isDefined |
| } |
| |
| def checkBHJ( |
| df: Dataset[Row], |
| optimizeOutRepartition: Boolean, |
| probeSideLocalRead: Boolean, |
| probeSideCoalescedRead: Boolean): Unit = { |
| df.collect() |
| val plan = df.queryExecution.executedPlan |
| // There should be only one shuffle that can't do local read, which is either the top shuffle |
| // from repartition, or BHJ probe side shuffle. |
| checkNumLocalShuffleReads(plan, 1) |
| assert(hasRepartitionShuffle(plan) == !optimizeOutRepartition) |
| val bhj = findTopLevelBroadcastHashJoin(plan) |
| assert(bhj.length == 1) |
| |
| // Build side should do local read. |
| val buildSide = find(bhj.head.left)(_.isInstanceOf[AQEShuffleReadExec]) |
| assert(buildSide.isDefined) |
| assert(buildSide.get.asInstanceOf[AQEShuffleReadExec].isLocalRead) |
| |
| val probeSide = find(bhj.head.right)(_.isInstanceOf[AQEShuffleReadExec]) |
| if (probeSideLocalRead || probeSideCoalescedRead) { |
| assert(probeSide.isDefined) |
| if (probeSideLocalRead) { |
| assert(probeSide.get.asInstanceOf[AQEShuffleReadExec].isLocalRead) |
| } else { |
| assert(probeSide.get.asInstanceOf[AQEShuffleReadExec].hasCoalescedPartition) |
| } |
| } else { |
| assert(probeSide.isEmpty) |
| } |
| } |
| |
| def checkSMJ( |
| df: Dataset[Row], |
| optimizeOutRepartition: Boolean, |
| optimizeSkewJoin: Boolean, |
| coalescedRead: Boolean): Unit = { |
| df.collect() |
| val plan = df.queryExecution.executedPlan |
| assert(hasRepartitionShuffle(plan) == !optimizeOutRepartition) |
| val smj = findTopLevelSortMergeJoin(plan) |
| assert(smj.length == 1) |
| assert(smj.head.isSkewJoin == optimizeSkewJoin) |
| val aqeReads = collect(smj.head) { |
| case c: AQEShuffleReadExec => c |
| } |
| if (coalescedRead || optimizeSkewJoin) { |
| assert(aqeReads.length == 2) |
| if (coalescedRead) assert(aqeReads.forall(_.hasCoalescedPartition)) |
| } else { |
| assert(aqeReads.isEmpty) |
| } |
| } |
| |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.SHUFFLE_PARTITIONS.key -> "5") { |
| val df = sql( |
| """ |
| |SELECT * FROM ( |
| | SELECT * FROM testData WHERE key = 1 |
| |) |
| |RIGHT OUTER JOIN testData2 |
| |ON CAST(value AS INT) = b |
| """.stripMargin) |
| |
| withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| // Repartition with no partition num specified. |
| checkBHJ(df.repartition($"b"), |
| // The top shuffle from repartition is optimized out. |
| optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = true) |
| |
| // Repartition with default partition num (5 in test env) specified. |
| checkBHJ(df.repartition(5, $"b"), |
| // The top shuffle from repartition is optimized out |
| // The final plan must have 5 partitions, no optimization can be made to the probe side. |
| optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = false) |
| |
| // Repartition with non-default partition num specified. |
| checkBHJ(df.repartition(4, $"b"), |
| // The top shuffle from repartition is not optimized out |
| optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true) |
| |
| // Repartition by col and project away the partition cols |
| checkBHJ(df.repartition($"b").select($"key"), |
| // The top shuffle from repartition is not optimized out |
| optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true) |
| } |
| |
| // Force skew join |
| withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", |
| SQLConf.SKEW_JOIN_ENABLED.key -> "true", |
| SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "1", |
| SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key -> "0", |
| SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") { |
| // Repartition with no partition num specified. |
| checkSMJ(df.repartition($"b"), |
| // The top shuffle from repartition is optimized out. |
| optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = true) |
| |
| // Repartition with default partition num (5 in test env) specified. |
| checkSMJ(df.repartition(5, $"b"), |
| // The top shuffle from repartition is optimized out. |
| // The final plan must have 5 partitions, can't do coalesced read. |
| optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = false) |
| |
| // Repartition with non-default partition num specified. |
| checkSMJ(df.repartition(4, $"b"), |
| // The top shuffle from repartition is not optimized out. |
| optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false) |
| |
| // Repartition by col and project away the partition cols |
| checkSMJ(df.repartition($"b").select($"key"), |
| // The top shuffle from repartition is not optimized out. |
| optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false) |
| } |
| } |
| } |
| |
| test("SPARK-34091: Batch shuffle fetch in AQE partition coalescing") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.SHUFFLE_PARTITIONS.key -> "10", |
| SQLConf.FETCH_SHUFFLE_BLOCKS_IN_BATCH.key -> "true") { |
| withTable("t1") { |
| spark.range(100).selectExpr("id + 1 as a").write.format("parquet").saveAsTable("t1") |
| val query = "SELECT SUM(a) FROM t1 GROUP BY a" |
| val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query) |
| val metricName = SQLShuffleReadMetricsReporter.LOCAL_BLOCKS_FETCHED |
| val blocksFetchedMetric = collectFirst(adaptivePlan) { |
| case p if p.metrics.contains(metricName) => p.metrics(metricName) |
| } |
| assert(blocksFetchedMetric.isDefined) |
| val blocksFetched = blocksFetchedMetric.get.value |
| withSQLConf(SQLConf.FETCH_SHUFFLE_BLOCKS_IN_BATCH.key -> "false") { |
| val (_, adaptivePlan2) = runAdaptiveAndVerifyResult(query) |
| val blocksFetchedMetric2 = collectFirst(adaptivePlan2) { |
| case p if p.metrics.contains(metricName) => p.metrics(metricName) |
| } |
| assert(blocksFetchedMetric2.isDefined) |
| val blocksFetched2 = blocksFetchedMetric2.get.value |
| assert(blocksFetched < blocksFetched2) |
| } |
| } |
| } |
| } |
| |
| test("SPARK-33933: Materialize BroadcastQueryStage first in AQE") { |
| val testAppender = new LogAppender("aqe query stage materialization order test") |
| testAppender.setThreshold(Level.DEBUG) |
| val df = spark.range(1000).select($"id" % 26, $"id" % 10) |
| .toDF("index", "pv") |
| val dim = Range(0, 26).map(x => (x, ('a' + x).toChar.toString)) |
| .toDF("index", "name") |
| val testDf = df.groupBy("index") |
| .agg(sum($"pv").alias("pv")) |
| .join(dim, Seq("index")) |
| val loggerNames = |
| Seq(classOf[BroadcastQueryStageExec].getName, classOf[ShuffleQueryStageExec].getName) |
| withLogAppender(testAppender, loggerNames, level = Some(Level.DEBUG)) { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| val result = testDf.collect() |
| assert(result.length == 26) |
| } |
| } |
| val materializeLogs = testAppender.loggingEvents |
| .map(_.getMessage.getFormattedMessage) |
| .filter(_.startsWith("Materialize query stage")) |
| .toArray |
| assert(materializeLogs(0).startsWith("Materialize query stage: BroadcastQueryStageExec-1")) |
| assert(materializeLogs(1).startsWith("Materialize query stage: ShuffleQueryStageExec-0")) |
| } |
| |
| test("SPARK-34899: Use origin plan if we can not coalesce shuffle partition") { |
| def checkNoCoalescePartitions(ds: Dataset[Row], origin: ShuffleOrigin): Unit = { |
| assert(collect(ds.queryExecution.executedPlan) { |
| case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s |
| }.size == 1) |
| ds.collect() |
| val plan = ds.queryExecution.executedPlan |
| assert(collect(plan) { |
| case c: AQEShuffleReadExec => c |
| }.isEmpty) |
| assert(collect(plan) { |
| case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s |
| }.size == 1) |
| checkAnswer(ds, testData) |
| } |
| |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", |
| // Pick a small value so that no coalesce can happen. |
| SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100", |
| SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", |
| SQLConf.SHUFFLE_PARTITIONS.key -> "2") { |
| val df = spark.sparkContext.parallelize( |
| (1 to 100).map(i => TestData(i, i.toString)), 10).toDF() |
| |
| // partition size [1420, 1420] |
| checkNoCoalescePartitions(df.repartition($"key"), REPARTITION_BY_COL) |
| // partition size [1140, 1119] |
| checkNoCoalescePartitions(df.sort($"key"), ENSURE_REQUIREMENTS) |
| } |
| } |
| |
| test("SPARK-34980: Support coalesce partition through union") { |
| def checkResultPartition( |
| df: Dataset[Row], |
| numUnion: Int, |
| numShuffleReader: Int, |
| numPartition: Int): Unit = { |
| df.collect() |
| assert(collect(df.queryExecution.executedPlan) { |
| case u: UnionExec => u |
| }.size == numUnion) |
| assert(collect(df.queryExecution.executedPlan) { |
| case r: AQEShuffleReadExec => r |
| }.size === numShuffleReader) |
| assert(df.rdd.partitions.length === numPartition) |
| } |
| |
| Seq(true, false).foreach { combineUnionEnabled => |
| val combineUnionConfig = if (combineUnionEnabled) { |
| SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> "" |
| } else { |
| SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> |
| "org.apache.spark.sql.catalyst.optimizer.CombineUnions" |
| } |
| // advisory partition size 1048576 has no special meaning, just a big enough value |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", |
| SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "1048576", |
| SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", |
| SQLConf.SHUFFLE_PARTITIONS.key -> "10", |
| combineUnionConfig) { |
| withTempView("t1", "t2") { |
| spark.sparkContext.parallelize((1 to 10).map(i => TestData(i, i.toString)), 2) |
| .toDF().createOrReplaceTempView("t1") |
| spark.sparkContext.parallelize((1 to 10).map(i => TestData(i, i.toString)), 4) |
| .toDF().createOrReplaceTempView("t2") |
| |
| // positive test that could be coalesced |
| checkResultPartition( |
| sql(""" |
| |SELECT key, count(*) FROM t1 GROUP BY key |
| |UNION ALL |
| |SELECT * FROM t2 |
| """.stripMargin), |
| numUnion = 1, |
| numShuffleReader = 1, |
| numPartition = 1 + 4) |
| |
| checkResultPartition( |
| sql(""" |
| |SELECT key, count(*) FROM t1 GROUP BY key |
| |UNION ALL |
| |SELECT * FROM t2 |
| |UNION ALL |
| |SELECT * FROM t1 |
| """.stripMargin), |
| numUnion = if (combineUnionEnabled) 1 else 2, |
| numShuffleReader = 1, |
| numPartition = 1 + 4 + 2) |
| |
| checkResultPartition( |
| sql(""" |
| |SELECT /*+ merge(t2) */ t1.key, t2.key FROM t1 JOIN t2 ON t1.key = t2.key |
| |UNION ALL |
| |SELECT key, count(*) FROM t2 GROUP BY key |
| |UNION ALL |
| |SELECT * FROM t1 |
| """.stripMargin), |
| numUnion = if (combineUnionEnabled) 1 else 2, |
| numShuffleReader = 3, |
| numPartition = 1 + 1 + 2) |
| |
| // negative test |
| checkResultPartition( |
| sql("SELECT * FROM t1 UNION ALL SELECT * FROM t2"), |
| numUnion = if (combineUnionEnabled) 1 else 1, |
| numShuffleReader = 0, |
| numPartition = 2 + 4 |
| ) |
| } |
| } |
| } |
| } |
| |
| test("SPARK-35239: Coalesce shuffle partition should handle empty input RDD") { |
| withTable("t") { |
| withSQLConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", |
| SQLConf.SHUFFLE_PARTITIONS.key -> "2", |
| SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { |
| spark.sql("CREATE TABLE t (c1 int) USING PARQUET") |
| val (_, adaptive) = runAdaptiveAndVerifyResult("SELECT c1, count(*) FROM t GROUP BY c1") |
| assert( |
| collect(adaptive) { |
| case c @ AQEShuffleReadExec(_, partitionSpecs) if partitionSpecs.length == 1 => |
| assert(c.hasCoalescedPartition) |
| c |
| }.length == 1 |
| ) |
| } |
| } |
| } |
| |
| test("SPARK-35264: Support AQE side broadcastJoin threshold") { |
| withTempView("t1", "t2") { |
| def checkJoinStrategy(shouldBroadcast: Boolean): Unit = { |
| withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { |
| val (origin, adaptive) = runAdaptiveAndVerifyResult( |
| "SELECT t1.c1, t2.c1 FROM t1 JOIN t2 ON t1.c1 = t2.c1") |
| assert(findTopLevelSortMergeJoin(origin).size == 1) |
| if (shouldBroadcast) { |
| assert(findTopLevelBroadcastHashJoin(adaptive).size == 1) |
| } else { |
| assert(findTopLevelSortMergeJoin(adaptive).size == 1) |
| } |
| } |
| } |
| |
| // t1: 1600 bytes |
| // t2: 160 bytes |
| spark.sparkContext.parallelize( |
| (1 to 100).map(i => TestData(i, i.toString)), 10) |
| .toDF("c1", "c2").createOrReplaceTempView("t1") |
| spark.sparkContext.parallelize( |
| (1 to 10).map(i => TestData(i, i.toString)), 5) |
| .toDF("c1", "c2").createOrReplaceTempView("t2") |
| |
| checkJoinStrategy(false) |
| withSQLConf(SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { |
| checkJoinStrategy(false) |
| } |
| |
| withSQLConf(SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "160") { |
| checkJoinStrategy(true) |
| } |
| } |
| } |
| |
| test("SPARK-35264: Support AQE side shuffled hash join formula") { |
| withTempView("t1", "t2") { |
| def checkJoinStrategy(shouldShuffleHashJoin: Boolean): Unit = { |
| Seq("100", "100000").foreach { size => |
| withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> size) { |
| val (origin1, adaptive1) = runAdaptiveAndVerifyResult( |
| "SELECT t1.c1, t2.c1 FROM t1 JOIN t2 ON t1.c1 = t2.c1") |
| assert(findTopLevelSortMergeJoin(origin1).size === 1) |
| if (shouldShuffleHashJoin && size.toInt < 100000) { |
| val shj = findTopLevelShuffledHashJoin(adaptive1) |
| assert(shj.size === 1) |
| assert(shj.head.buildSide == BuildRight) |
| } else { |
| assert(findTopLevelSortMergeJoin(adaptive1).size === 1) |
| } |
| } |
| } |
| // respect user specified join hint |
| val (origin2, adaptive2) = runAdaptiveAndVerifyResult( |
| "SELECT /*+ MERGE(t1) */ t1.c1, t2.c1 FROM t1 JOIN t2 ON t1.c1 = t2.c1") |
| assert(findTopLevelSortMergeJoin(origin2).size === 1) |
| assert(findTopLevelSortMergeJoin(adaptive2).size === 1) |
| } |
| |
| spark.sparkContext.parallelize( |
| (1 to 100).map(i => TestData(i, i.toString)), 10) |
| .toDF("c1", "c2").createOrReplaceTempView("t1") |
| spark.sparkContext.parallelize( |
| (1 to 10).map(i => TestData(i, i.toString)), 5) |
| .toDF("c1", "c2").createOrReplaceTempView("t2") |
| |
| // t1 partition size: [926, 729, 731] |
| // t2 partition size: [318, 120, 0] |
| withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", |
| SQLConf.PREFER_SORTMERGEJOIN.key -> "true") { |
| // check default value |
| checkJoinStrategy(false) |
| withSQLConf(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key -> "400") { |
| checkJoinStrategy(true) |
| } |
| withSQLConf(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key -> "300") { |
| checkJoinStrategy(false) |
| } |
| withSQLConf(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key -> "1000") { |
| checkJoinStrategy(true) |
| } |
| } |
| } |
| } |
| |
| test("SPARK-35650: Coalesce number of partitions by AEQ") { |
| withSQLConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1") { |
| Seq("REPARTITION", "REBALANCE(key)") |
| .foreach {repartition => |
| val query = s"SELECT /*+ $repartition */ * FROM testData" |
| val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query) |
| collect(adaptivePlan) { |
| case r: AQEShuffleReadExec => r |
| } match { |
| case Seq(aqeShuffleRead) => |
| assert(aqeShuffleRead.partitionSpecs.size === 1) |
| assert(!aqeShuffleRead.isLocalRead) |
| case _ => |
| fail("There should be a AQEShuffleReadExec") |
| } |
| } |
| } |
| } |
| |
| test("SPARK-35650: Use local shuffle read if can not coalesce number of partitions") { |
| withSQLConf(SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false") { |
| val query = "SELECT /*+ REPARTITION */ * FROM testData" |
| val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query) |
| collect(adaptivePlan) { |
| case r: AQEShuffleReadExec => r |
| } match { |
| case Seq(aqeShuffleRead) => |
| assert(aqeShuffleRead.partitionSpecs.size === 4) |
| assert(aqeShuffleRead.isLocalRead) |
| case _ => |
| fail("There should be a AQEShuffleReadExec") |
| } |
| } |
| } |
| |
| test("SPARK-35725: Support optimize skewed partitions in RebalancePartitions") { |
| withTempView("v") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", |
| SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", |
| SQLConf.SHUFFLE_PARTITIONS.key -> "5", |
| SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1") { |
| |
| spark.sparkContext.parallelize( |
| (1 to 10).map(i => TestData(if (i > 4) 5 else i, i.toString)), 3) |
| .toDF("c1", "c2").createOrReplaceTempView("v") |
| |
| def checkPartitionNumber( |
| query: String, skewedPartitionNumber: Int, totalNumber: Int): Unit = { |
| val (_, adaptive) = runAdaptiveAndVerifyResult(query) |
| val read = collect(adaptive) { |
| case read: AQEShuffleReadExec => read |
| } |
| assert(read.size == 1) |
| assert(read.head.partitionSpecs.count(_.isInstanceOf[PartialReducerPartitionSpec]) == |
| skewedPartitionNumber) |
| assert(read.head.partitionSpecs.size == totalNumber) |
| } |
| |
| withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "150") { |
| // partition size [0,258,72,72,72] |
| checkPartitionNumber("SELECT /*+ REBALANCE(c1) */ * FROM v", 2, 4) |
| // partition size [144,72,144,72,72,144,72] |
| checkPartitionNumber("SELECT /*+ REBALANCE */ * FROM v", 6, 7) |
| } |
| |
| // no skewed partition should be optimized |
| withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10000") { |
| checkPartitionNumber("SELECT /*+ REBALANCE(c1) */ * FROM v", 0, 1) |
| } |
| } |
| } |
| } |
| |
| test("SPARK-35888: join with a 0-partition table") { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", |
| SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { |
| withTempView("t2") { |
| // create a temp view with 0 partition |
| spark.createDataFrame(sparkContext.emptyRDD[Row], new StructType().add("b", IntegerType)) |
| .createOrReplaceTempView("t2") |
| val (_, adaptive) = |
| runAdaptiveAndVerifyResult("SELECT * FROM testData2 t1 left semi join t2 ON t1.a=t2.b") |
| val aqeReads = collect(adaptive) { |
| case c: AQEShuffleReadExec => c |
| } |
| assert(aqeReads.length == 2) |
| aqeReads.foreach { c => |
| val stats = c.child.asInstanceOf[QueryStageExec].getRuntimeStatistics |
| assert(stats.sizeInBytes >= 0) |
| assert(stats.rowCount.get >= 0) |
| } |
| } |
| } |
| } |
| |
| test("SPARK-33832: Support optimize skew join even if introduce extra shuffle") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "false", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", |
| SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100", |
| SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100", |
| SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", |
| SQLConf.SHUFFLE_PARTITIONS.key -> "10", |
| SQLConf.ADAPTIVE_FORCE_OPTIMIZE_SKEWED_JOIN.key -> "true") { |
| withTempView("skewData1", "skewData2") { |
| spark |
| .range(0, 1000, 1, 10) |
| .selectExpr("id % 3 as key1", "id as value1") |
| .createOrReplaceTempView("skewData1") |
| spark |
| .range(0, 1000, 1, 10) |
| .selectExpr("id % 1 as key2", "id as value2") |
| .createOrReplaceTempView("skewData2") |
| |
| // check if optimized skewed join does not satisfy the required distribution |
| Seq(true, false).foreach { hasRequiredDistribution => |
| Seq(true, false).foreach { hasPartitionNumber => |
| val repartition = if (hasRequiredDistribution) { |
| s"/*+ repartition(${ if (hasPartitionNumber) "10," else ""}key1) */" |
| } else { |
| "" |
| } |
| |
| // check required distribution and extra shuffle |
| val (_, adaptive1) = |
| runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " + |
| s"JOIN skewData2 ON key1 = key2 GROUP BY key1") |
| val shuffles1 = collect(adaptive1) { |
| case s: ShuffleExchangeExec => s |
| } |
| assert(shuffles1.size == 3) |
| // shuffles1.head is the top-level shuffle under the Aggregate operator |
| assert(shuffles1.head.shuffleOrigin == ENSURE_REQUIREMENTS) |
| val smj1 = findTopLevelSortMergeJoin(adaptive1) |
| assert(smj1.size == 1 && smj1.head.isSkewJoin) |
| |
| // only check required distribution |
| val (_, adaptive2) = |
| runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " + |
| s"JOIN skewData2 ON key1 = key2") |
| val shuffles2 = collect(adaptive2) { |
| case s: ShuffleExchangeExec => s |
| } |
| if (hasRequiredDistribution) { |
| assert(shuffles2.size == 3) |
| val finalShuffle = shuffles2.head |
| if (hasPartitionNumber) { |
| assert(finalShuffle.shuffleOrigin == REPARTITION_BY_NUM) |
| } else { |
| assert(finalShuffle.shuffleOrigin == REPARTITION_BY_COL) |
| } |
| } else { |
| assert(shuffles2.size == 2) |
| } |
| val smj2 = findTopLevelSortMergeJoin(adaptive2) |
| assert(smj2.size == 1 && smj2.head.isSkewJoin) |
| } |
| } |
| } |
| } |
| } |
| |
| test("SPARK-35968: AQE coalescing should not produce too small partitions by default") { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| val (_, adaptive) = |
| runAdaptiveAndVerifyResult("SELECT sum(id) FROM RANGE(10) GROUP BY id % 3") |
| val coalesceRead = collect(adaptive) { |
| case r: AQEShuffleReadExec if r.hasCoalescedPartition => r |
| } |
| assert(coalesceRead.length == 1) |
| // RANGE(10) is a very small dataset and AQE coalescing should produce one partition. |
| assert(coalesceRead.head.partitionSpecs.length == 1) |
| } |
| } |
| |
| test("SPARK-35794: Allow custom plugin for cost evaluator") { |
| CostEvaluator.instantiate( |
| classOf[SimpleShuffleSortCostEvaluator].getCanonicalName, spark.sparkContext.getConf) |
| intercept[IllegalArgumentException] { |
| CostEvaluator.instantiate( |
| classOf[InvalidCostEvaluator].getCanonicalName, spark.sparkContext.getConf) |
| } |
| |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { |
| val query = "SELECT * FROM testData join testData2 ON key = a where value = '1'" |
| |
| withSQLConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS.key -> |
| "org.apache.spark.sql.execution.adaptive.SimpleShuffleSortCostEvaluator") { |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query) |
| val smj = findTopLevelSortMergeJoin(plan) |
| assert(smj.size == 1) |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.size == 1) |
| checkNumLocalShuffleReads(adaptivePlan) |
| } |
| |
| withSQLConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS.key -> |
| "org.apache.spark.sql.execution.adaptive.InvalidCostEvaluator") { |
| intercept[IllegalArgumentException] { |
| runAdaptiveAndVerifyResult(query) |
| } |
| } |
| } |
| } |
| |
| test("SPARK-36020: Check logical link in remove redundant projects") { |
| withTempView("t") { |
| spark.range(10).selectExpr("id % 10 as key", "cast(id * 2 as int) as a", |
| "cast(id * 3 as int) as b", "array(id, id + 1, id + 3) as c").createOrReplaceTempView("t") |
| withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", |
| SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "800") { |
| val query = |
| """ |
| |WITH tt AS ( |
| | SELECT key, a, b, explode(c) AS c FROM t |
| |) |
| |SELECT t1.key, t1.c, t2.key, t2.c |
| |FROM (SELECT a, b, c, key FROM tt WHERE a > 1) t1 |
| |JOIN (SELECT a, b, c, key FROM tt) t2 |
| | ON t1.key = t2.key |
| |""".stripMargin |
| val (origin, adaptive) = runAdaptiveAndVerifyResult(query) |
| assert(findTopLevelSortMergeJoin(origin).size == 1) |
| assert(findTopLevelBroadcastHashJoin(adaptive).size == 1) |
| } |
| } |
| } |
| |
| test("SPARK-35874: AQE Shuffle should wait for its subqueries to finish before materializing") { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| val query = "SELECT b FROM testData2 DISTRIBUTE BY (b, (SELECT max(key) FROM testData))" |
| runAdaptiveAndVerifyResult(query) |
| } |
| } |
| |
| test("SPARK-36032: Use inputPlan instead of currentPhysicalPlan to initialize logical link") { |
| withTempView("v") { |
| spark.sparkContext.parallelize( |
| (1 to 10).map(i => TestData(i, i.toString)), 2) |
| .toDF("c1", "c2").createOrReplaceTempView("v") |
| |
| Seq("-1", "10000").foreach { aqeBhj => |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", |
| SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> aqeBhj, |
| SQLConf.SHUFFLE_PARTITIONS.key -> "1") { |
| val (origin, adaptive) = runAdaptiveAndVerifyResult( |
| """ |
| |SELECT * FROM v t1 JOIN ( |
| | SELECT c1 + 1 as c3 FROM v |
| |)t2 ON t1.c1 = t2.c3 |
| |SORT BY c1 |
| """.stripMargin) |
| if (aqeBhj.toInt < 0) { |
| // 1 sort since spark plan has no shuffle for SMJ |
| assert(findTopLevelSort(origin).size == 1) |
| // 2 sorts in SMJ |
| assert(findTopLevelSort(adaptive).size == 2) |
| } else { |
| assert(findTopLevelSort(origin).size == 1) |
| // 1 sort at top node and BHJ has no sort |
| assert(findTopLevelSort(adaptive).size == 1) |
| } |
| } |
| } |
| } |
| } |
| |
| test("SPARK-36424: Support eliminate limits in AQE Optimizer") { |
| withTempView("v") { |
| spark.sparkContext.parallelize( |
| (1 to 10).map(i => TestData(i, if (i > 2) "2" else i.toString)), 2) |
| .toDF("c1", "c2").createOrReplaceTempView("v") |
| |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.SHUFFLE_PARTITIONS.key -> "3") { |
| val (origin1, adaptive1) = runAdaptiveAndVerifyResult( |
| """ |
| |SELECT c2, sum(c1) FROM v GROUP BY c2 LIMIT 5 |
| """.stripMargin) |
| assert(findTopLevelLimit(origin1).size == 1) |
| assert(findTopLevelLimit(adaptive1).isEmpty) |
| |
| // eliminate limit through filter |
| val (origin2, adaptive2) = runAdaptiveAndVerifyResult( |
| """ |
| |SELECT c2, sum(c1) FROM v GROUP BY c2 HAVING sum(c1) > 1 LIMIT 5 |
| """.stripMargin) |
| assert(findTopLevelLimit(origin2).size == 1) |
| assert(findTopLevelLimit(adaptive2).isEmpty) |
| |
| // The strategy of Eliminate Limits batch should be fixedPoint |
| val (origin3, adaptive3) = runAdaptiveAndVerifyResult( |
| """ |
| |SELECT * FROM (SELECT c1 + c2 FROM (SELECT DISTINCT * FROM v LIMIT 10086)) LIMIT 20 |
| """.stripMargin |
| ) |
| assert(findTopLevelLimit(origin3).size == 1) |
| assert(findTopLevelLimit(adaptive3).isEmpty) |
| } |
| } |
| } |
| |
| test("SPARK-48037: Fix SortShuffleWriter lacks shuffle write related metrics " + |
| "resulting in potentially inaccurate data") { |
| withTable("t3") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.SHUFFLE_PARTITIONS.key -> (SortShuffleManager |
| .MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE + 1).toString) { |
| sql("CREATE TABLE t3 USING PARQUET AS SELECT id FROM range(2)") |
| val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( |
| """ |
| |SELECT id, count(*) |
| |FROM t3 |
| |GROUP BY id |
| |LIMIT 1 |
| |""".stripMargin, skipCheckAnswer = true) |
| // The shuffle stage produces two rows and the limit operator should not been optimized out. |
| assert(findTopLevelLimit(plan).size == 1) |
| assert(findTopLevelLimit(adaptivePlan).size == 1) |
| } |
| } |
| } |
| |
| test("SPARK-37063: OptimizeSkewInRebalancePartitions support optimize non-root node") { |
| withTempView("v") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true", |
| SQLConf.SHUFFLE_PARTITIONS.key -> "1", |
| SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1") { |
| spark.sparkContext.parallelize( |
| (1 to 10).map(i => TestData(if (i > 2) 2 else i, i.toString)), 2) |
| .toDF("c1", "c2").createOrReplaceTempView("v") |
| |
| def checkRebalance(query: String, numShufflePartitions: Int): Unit = { |
| val (_, adaptive) = runAdaptiveAndVerifyResult(query) |
| assert(adaptive.collect { |
| case sort: SortExec => sort |
| }.size == 1) |
| val read = collect(adaptive) { |
| case read: AQEShuffleReadExec => read |
| } |
| assert(read.size == 1) |
| assert(read.head.partitionSpecs.forall(_.isInstanceOf[PartialReducerPartitionSpec])) |
| assert(read.head.partitionSpecs.size == numShufflePartitions) |
| } |
| |
| withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "50") { |
| checkRebalance("SELECT /*+ REBALANCE(c1) */ * FROM v SORT BY c1", 2) |
| checkRebalance("SELECT /*+ REBALANCE */ * FROM v SORT BY c1", 2) |
| } |
| } |
| } |
| } |
| |
| test("SPARK-37357: Add small partition factor for rebalance partitions") { |
| withTempView("v") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true", |
| SQLConf.SHUFFLE_PARTITIONS.key -> "1") { |
| spark.sparkContext.parallelize( |
| (1 to 8).map(i => TestData(if (i > 2) 2 else i, i.toString)), 3) |
| .toDF("c1", "c2").createOrReplaceTempView("v") |
| |
| def checkAQEShuffleReadExists(query: String, exists: Boolean): Unit = { |
| val (_, adaptive) = runAdaptiveAndVerifyResult(query) |
| assert( |
| collect(adaptive) { |
| case read: AQEShuffleReadExec => read |
| }.nonEmpty == exists) |
| } |
| |
| withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "200") { |
| withSQLConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR.key -> "0.5") { |
| // block size: [88, 97, 97] |
| checkAQEShuffleReadExists("SELECT /*+ REBALANCE(c1) */ * FROM v", false) |
| } |
| withSQLConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR.key -> "0.2") { |
| // block size: [88, 97, 97] |
| checkAQEShuffleReadExists("SELECT /*+ REBALANCE(c1) */ * FROM v", true) |
| } |
| } |
| } |
| } |
| } |
| |
| test("SPARK-37742: AQE reads invalid InMemoryRelation stats and mistakenly plans BHJ") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1048584", |
| SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { |
| // Spark estimates a string column as 20 bytes so with 60k rows, these relations should be |
| // estimated at ~120m bytes which is greater than the broadcast join threshold. |
| val joinKeyOne = "00112233445566778899" |
| val joinKeyTwo = "11223344556677889900" |
| Seq.fill(60000)(joinKeyOne).toDF("key") |
| .createOrReplaceTempView("temp") |
| Seq.fill(60000)(joinKeyTwo).toDF("key") |
| .createOrReplaceTempView("temp2") |
| |
| Seq(joinKeyOne).toDF("key").createOrReplaceTempView("smallTemp") |
| spark.sql("SELECT key as newKey FROM temp").persist() |
| |
| // This query is trying to set up a situation where there are three joins. |
| // The first join will join the cached relation with a smaller relation. |
| // The first join is expected to be a broadcast join since the smaller relation will |
| // fit under the broadcast join threshold. |
| // The second join will join the first join with another relation and is expected |
| // to remain as a sort-merge join. |
| // The third join will join the cached relation with another relation and is expected |
| // to remain as a sort-merge join. |
| val query = |
| s""" |
| |SELECT t3.newKey |
| |FROM |
| | (SELECT t1.newKey |
| | FROM (SELECT key as newKey FROM temp) as t1 |
| | JOIN |
| | (SELECT key FROM smallTemp) as t2 |
| | ON t1.newKey = t2.key |
| | ) as t3 |
| | JOIN |
| | (SELECT key FROM temp2) as t4 |
| | ON t3.newKey = t4.key |
| |UNION |
| |SELECT t1.newKey |
| |FROM |
| | (SELECT key as newKey FROM temp) as t1 |
| | JOIN |
| | (SELECT key FROM temp2) as t2 |
| | ON t1.newKey = t2.key |
| |""".stripMargin |
| val df = spark.sql(query) |
| df.collect() |
| val adaptivePlan = df.queryExecution.executedPlan |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) |
| assert(bhj.length == 1) |
| } |
| } |
| |
| test("SPARK-37328: skew join with 3 tables") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", |
| SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100", |
| SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100", |
| SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", |
| SQLConf.SHUFFLE_PARTITIONS.key -> "10") { |
| withTempView("skewData1", "skewData2", "skewData3") { |
| spark |
| .range(0, 1000, 1, 10) |
| .selectExpr("id % 3 as key1", "id % 3 as value1") |
| .createOrReplaceTempView("skewData1") |
| spark |
| .range(0, 1000, 1, 10) |
| .selectExpr("id % 1 as key2", "id as value2") |
| .createOrReplaceTempView("skewData2") |
| spark |
| .range(0, 1000, 1, 10) |
| .selectExpr("id % 1 as key3", "id as value3") |
| .createOrReplaceTempView("skewData3") |
| |
| // skewedJoin doesn't happen in last stage |
| val (_, adaptive1) = |
| runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + |
| "JOIN skewData3 ON value2 = value3") |
| val shuffles1 = collect(adaptive1) { |
| case s: ShuffleExchangeExec => s |
| } |
| assert(shuffles1.size == 4) |
| val smj1 = findTopLevelSortMergeJoin(adaptive1) |
| assert(smj1.size == 2 && smj1.last.isSkewJoin && !smj1.head.isSkewJoin) |
| |
| // Query has two skewJoin in two continuous stages. |
| val (_, adaptive2) = |
| runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + |
| "JOIN skewData3 ON value1 = value3") |
| val shuffles2 = collect(adaptive2) { |
| case s: ShuffleExchangeExec => s |
| } |
| assert(shuffles2.size == 4) |
| val smj2 = findTopLevelSortMergeJoin(adaptive2) |
| assert(smj2.size == 2 && smj2.forall(_.isSkewJoin)) |
| } |
| } |
| } |
| |
| test("SPARK-37652: optimize skewed join through union") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", |
| SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100", |
| SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100") { |
| withTempView("skewData1", "skewData2") { |
| spark |
| .range(0, 1000, 1, 10) |
| .selectExpr("id % 3 as key1", "id as value1") |
| .createOrReplaceTempView("skewData1") |
| spark |
| .range(0, 1000, 1, 10) |
| .selectExpr("id % 1 as key2", "id as value2") |
| .createOrReplaceTempView("skewData2") |
| |
| def checkSkewJoin(query: String, joinNums: Int, optimizeSkewJoinNums: Int): Unit = { |
| val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(query) |
| val joins = findTopLevelSortMergeJoin(innerAdaptivePlan) |
| val optimizeSkewJoins = joins.filter(_.isSkewJoin) |
| assert(joins.size == joinNums && optimizeSkewJoins.size == optimizeSkewJoinNums) |
| } |
| |
| // skewJoin union skewJoin |
| checkSkewJoin( |
| "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + |
| "UNION ALL SELECT key2 FROM skewData1 JOIN skewData2 ON key1 = key2", 2, 2) |
| |
| // skewJoin union aggregate |
| checkSkewJoin( |
| "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + |
| "UNION ALL SELECT key2 FROM skewData2 GROUP BY key2", 1, 1) |
| |
| // skewJoin1 union (skewJoin2 join aggregate) |
| // skewJoin2 will lead to extra shuffles, but skew1 cannot be optimized |
| checkSkewJoin( |
| "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 UNION ALL " + |
| "SELECT key1 from (SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2) tmp1 " + |
| "JOIN (SELECT key2 FROM skewData2 GROUP BY key2) tmp2 ON key1 = key2", 3, 0) |
| } |
| } |
| } |
| |
| test("SPARK-38162: Optimize one row plan in AQE Optimizer") { |
| withTempView("v") { |
| spark.sparkContext.parallelize( |
| (1 to 4).map(i => TestData(i, i.toString)), 2) |
| .toDF("c1", "c2").createOrReplaceTempView("v") |
| |
| // remove sort |
| val (origin1, adaptive1) = runAdaptiveAndVerifyResult( |
| """ |
| |SELECT * FROM v where c1 = 1 order by c1, c2 |
| |""".stripMargin) |
| assert(findTopLevelSort(origin1).size == 1) |
| assert(findTopLevelSort(adaptive1).isEmpty) |
| |
| // convert group only aggregate to project |
| val (origin2, adaptive2) = runAdaptiveAndVerifyResult( |
| """ |
| |SELECT distinct c1 FROM (SELECT /*+ repartition(c1) */ * FROM v where c1 = 1) |
| |""".stripMargin) |
| assert(findTopLevelAggregate(origin2).size == 2) |
| assert(findTopLevelAggregate(adaptive2).isEmpty) |
| |
| // remove distinct in aggregate |
| val (origin3, adaptive3) = runAdaptiveAndVerifyResult( |
| """ |
| |SELECT sum(distinct c1) FROM (SELECT /*+ repartition(c1) */ * FROM v where c1 = 1) |
| |""".stripMargin) |
| assert(findTopLevelAggregate(origin3).size == 4) |
| assert(findTopLevelAggregate(adaptive3).size == 2) |
| |
| // do not optimize if the aggregate is inside query stage |
| val (origin4, adaptive4) = runAdaptiveAndVerifyResult( |
| """ |
| |SELECT distinct c1 FROM v where c1 = 1 |
| |""".stripMargin) |
| assert(findTopLevelAggregate(origin4).size == 2) |
| assert(findTopLevelAggregate(adaptive4).size == 2) |
| |
| val (origin5, adaptive5) = runAdaptiveAndVerifyResult( |
| """ |
| |SELECT sum(distinct c1) FROM v where c1 = 1 |
| |""".stripMargin) |
| assert(findTopLevelAggregate(origin5).size == 4) |
| assert(findTopLevelAggregate(adaptive5).size == 4) |
| } |
| } |
| |
| test("SPARK-39551: Invalid plan check - invalid broadcast query stage") { |
| withSQLConf( |
| SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { |
| val (_, adaptivePlan) = runAdaptiveAndVerifyResult( |
| """ |
| |SELECT /*+ BROADCAST(t3) */ t3.b, count(t3.a) FROM testData2 t1 |
| |INNER JOIN testData2 t2 |
| |ON t1.b = t2.b AND t1.a = 0 |
| |RIGHT OUTER JOIN testData2 t3 |
| |ON t1.a > t3.a |
| |GROUP BY t3.b |
| """.stripMargin |
| ) |
| assert(findTopLevelBroadcastNestedLoopJoin(adaptivePlan).size == 1) |
| } |
| } |
| |
| test("SPARK-39915: Dataset.repartition(N) may not create N partitions") { |
| withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "6") { |
| // partitioning: HashPartitioning |
| // shuffleOrigin: REPARTITION_BY_NUM |
| assert(spark.range(0).repartition(5, $"id").rdd.getNumPartitions == 5) |
| // shuffleOrigin: REPARTITION_BY_COL |
| // The minimum partition number after AQE coalesce is 1 |
| assert(spark.range(0).repartition($"id").rdd.getNumPartitions == 1) |
| // through project |
| assert(spark.range(0).selectExpr("id % 3 as c1", "id % 7 as c2") |
| .repartition(5, $"c1").select($"c2").rdd.getNumPartitions == 5) |
| |
| // partitioning: RangePartitioning |
| // shuffleOrigin: REPARTITION_BY_NUM |
| // The minimum partition number of RangePartitioner is 1 |
| assert(spark.range(0).repartitionByRange(5, $"id").rdd.getNumPartitions == 1) |
| // shuffleOrigin: REPARTITION_BY_COL |
| assert(spark.range(0).repartitionByRange($"id").rdd.getNumPartitions == 1) |
| |
| // partitioning: RoundRobinPartitioning |
| // shuffleOrigin: REPARTITION_BY_NUM |
| assert(spark.range(0).repartition(5).rdd.getNumPartitions == 5) |
| // shuffleOrigin: REBALANCE_PARTITIONS_BY_NONE |
| assert(spark.range(0).repartition().rdd.getNumPartitions == 0) |
| // through project |
| assert(spark.range(0).selectExpr("id % 3 as c1", "id % 7 as c2") |
| .repartition(5).select($"c2").rdd.getNumPartitions == 5) |
| |
| // partitioning: SinglePartition |
| assert(spark.range(0).repartition(1).rdd.getNumPartitions == 1) |
| } |
| } |
| |
| test("SPARK-39915: Ensure the output partitioning is user-specified") { |
| withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { |
| val df1 = spark.range(1).selectExpr("id as c1") |
| val df2 = spark.range(1).selectExpr("id as c2") |
| val df = df1.join(df2, col("c1") === col("c2")).repartition(3, col("c1")) |
| assert(df.rdd.getNumPartitions == 3) |
| } |
| } |
| |
| test("SPARK-42778: QueryStageExec should respect supportsRowBased") { |
| withSQLConf(SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { |
| withTempView("t") { |
| Seq(1).toDF("c1").createOrReplaceTempView("t") |
| spark.catalog.cacheTable("t") |
| val df = spark.table("t") |
| df.collect() |
| assert(collect(df.queryExecution.executedPlan) { |
| case c: ColumnarToRowExec => c |
| }.isEmpty) |
| } |
| } |
| } |
| |
| test("SPARK-42101: Apply AQE if contains nested AdaptiveSparkPlanExec") { |
| withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true") { |
| val df = spark.range(3).repartition().cache() |
| assert(df.sortWithinPartitions("id") |
| .queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec]) |
| } |
| } |
| |
| test("SPARK-42101: Make AQE support InMemoryTableScanExec") { |
| withSQLConf( |
| SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { |
| val df1 = spark.range(10).selectExpr("cast(id as string) c1") |
| val df2 = spark.range(10).selectExpr("cast(id as string) c2") |
| val cached = df1.join(df2, $"c1" === $"c2").cache() |
| |
| def checkShuffleAndSort(firstAccess: Boolean): Unit = { |
| val df = cached.groupBy("c1").agg(max($"c2")) |
| val initialExecutedPlan = df.queryExecution.executedPlan |
| assert(collect(initialExecutedPlan) { |
| case s: ShuffleExchangeLike => s |
| }.size == (if (firstAccess) 1 else 0)) |
| assert(collect(initialExecutedPlan) { |
| case s: SortExec => s |
| }.size == (if (firstAccess) 2 else 0)) |
| assert(collect(initialExecutedPlan) { |
| case i: InMemoryTableScanLike => i |
| }.head.isMaterialized != firstAccess) |
| |
| df.collect() |
| val finalExecutedPlan = df.queryExecution.executedPlan |
| assert(collect(finalExecutedPlan) { |
| case s: ShuffleExchangeLike => s |
| }.isEmpty) |
| assert(collect(finalExecutedPlan) { |
| case s: SortExec => s |
| }.isEmpty) |
| assert(collect(initialExecutedPlan) { |
| case i: InMemoryTableScanLike => i |
| }.head.isMaterialized) |
| } |
| |
| // first access cache |
| checkShuffleAndSort(firstAccess = true) |
| |
| // access a materialized cache |
| checkShuffleAndSort(firstAccess = false) |
| } |
| } |
| |
| test("SPARK-42101: Do not coalesce shuffle partition if other side is TableCacheQueryStage") { |
| withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3", |
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", |
| SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true", |
| SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1") { |
| withTempView("v1", "v2") { |
| Seq(1, 2).toDF("c1").repartition(3, $"c1").cache().createOrReplaceTempView("v1") |
| Seq(1, 2).toDF("c2").createOrReplaceTempView("v2") |
| |
| val df = spark.sql("SELECT * FROM v1 JOIN v2 ON v1.c1 = v2.c2") |
| df.collect() |
| val finalPlan = df.queryExecution.executedPlan |
| assert(collect(finalPlan) { |
| case q: ShuffleQueryStageExec => q |
| }.size == 1) |
| assert(collect(finalPlan) { |
| case r: AQEShuffleReadExec => r |
| }.isEmpty) |
| } |
| } |
| } |
| |
| test("SPARK-42101: Coalesce shuffle partition with union even if exists TableCacheQueryStage") { |
| withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true", |
| SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1") { |
| val cached = Seq(1).toDF("c").cache() |
| val df = Seq(2).toDF("c").repartition($"c").unionAll(cached) |
| df.collect() |
| assert(collect(df.queryExecution.executedPlan) { |
| case r @ AQEShuffleReadExec(_: ShuffleQueryStageExec, _) => r |
| }.size == 1) |
| assert(collect(df.queryExecution.executedPlan) { |
| case c: TableCacheQueryStageExec => c |
| }.size == 1) |
| } |
| } |
| |
| test("SPARK-43026: Apply AQE with non-exchange table cache") { |
| Seq(true, false).foreach { canChangeOP => |
| withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> canChangeOP.toString) { |
| // No exchange, no need for AQE |
| val df = spark.range(0).cache() |
| df.collect() |
| assert(!df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec]) |
| // Has exchange, apply AQE |
| val df2 = spark.range(0).repartition(1).cache() |
| df2.collect() |
| assert(df2.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec]) |
| } |
| } |
| } |
| |
| test("SPARK-43376: Improve reuse subquery with table cache") { |
| withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true") { |
| withTable("t1", "t2") { |
| withCache("t1") { |
| Seq(1).toDF("c1").cache().createOrReplaceTempView("t1") |
| Seq(2).toDF("c2").createOrReplaceTempView("t2") |
| |
| val (_, adaptive) = runAdaptiveAndVerifyResult( |
| "SELECT * FROM t1 WHERE c1 < (SELECT c2 FROM t2)") |
| assert(findReusedSubquery(adaptive).size == 1) |
| } |
| } |
| } |
| } |
| |
| test("SPARK-44040: Fix compute stats when AggregateExec nodes above QueryStageExec") { |
| val emptyDf = spark.range(1).where("false") |
| val aggDf1 = emptyDf.agg(sum("id").as("id")).withColumn("name", lit("df1")) |
| val aggDf2 = emptyDf.agg(sum("id").as("id")).withColumn("name", lit("df2")) |
| val unionDF = aggDf1.union(aggDf2) |
| checkAnswer(unionDF.select("id").distinct(), Seq(Row(null))) |
| } |
| |
| test("SPARK-47247: coalesce differently for BNLJ") { |
| Seq(true, false).foreach { expectCoalesce => |
| val minPartitionSize = if (expectCoalesce) "64MB" else "1B" |
| withSQLConf( |
| SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "64MB", |
| SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_SIZE.key -> minPartitionSize) { |
| val (_, adaptivePlan) = runAdaptiveAndVerifyResult( |
| "SELECT /*+ broadcast(testData2) */ * " + |
| "FROM (SELECT value v, max(key) k from testData group by value) " + |
| "JOIN testData2 ON k + a > 0") |
| val bnlj = findTopLevelBroadcastNestedLoopJoin(adaptivePlan) |
| assert(bnlj.size == 1) |
| val coalescedReads = collect(adaptivePlan) { |
| case read: AQEShuffleReadExec if read.isCoalescedRead => read |
| } |
| assert(coalescedReads.nonEmpty == expectCoalesce) |
| } |
| } |
| } |
| } |
| |
| /** |
| * Invalid implementation class for [[CostEvaluator]]. |
| */ |
| private class InvalidCostEvaluator() {} |
| |
| /** |
| * A simple [[CostEvaluator]] to count number of [[ShuffleExchangeLike]] and [[SortExec]]. |
| */ |
| private case class SimpleShuffleSortCostEvaluator() extends CostEvaluator { |
| override def evaluateCost(plan: SparkPlan): Cost = { |
| val cost = plan.collect { |
| case s: ShuffleExchangeLike => s |
| case s: SortExec => s |
| }.size |
| SimpleCost(cost) |
| } |
| } |
| |
| /** |
| * Helps to simulate ExchangeQueryStageExec materialization failure. |
| */ |
| private object TestProblematicCoalesceStrategy extends Strategy { |
| private case class TestProblematicCoalesceExec(numPartitions: Int, child: SparkPlan) |
| extends UnaryExecNode { |
| override protected def doExecute(): RDD[InternalRow] = |
| throw new SparkException("ProblematicCoalesce execution is failed") |
| override def output: Seq[Attribute] = child.output |
| override protected def withNewChildInternal(newChild: SparkPlan): TestProblematicCoalesceExec = |
| copy(child = newChild) |
| } |
| |
| override def apply(plan: LogicalPlan): Seq[SparkPlan] = { |
| plan match { |
| case org.apache.spark.sql.catalyst.plans.logical.Repartition( |
| numPartitions, false, child) => |
| TestProblematicCoalesceExec(numPartitions, planLater(child)) :: Nil |
| case _ => Nil |
| } |
| } |
| } |