add neverConvertJoinsWithPostCondition strategy
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConvertStrategy.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConvertStrategy.scala
index 5bad830..4fbdedf 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConvertStrategy.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConvertStrategy.scala
@@ -29,8 +29,6 @@
 import org.apache.spark.sql.execution.exchange.Exchange
 import org.apache.spark.sql.execution.CodegenSupport
 import org.apache.spark.SparkEnv
-import org.apache.spark.sql.blaze.BlazeConvertStrategy.isAlwaysConvert
-import org.apache.spark.sql.blaze.BlazeConvertStrategy.isNeverConvert
 import org.apache.spark.sql.execution.FilterExec
 import org.apache.spark.sql.execution.SortExec
 import org.apache.spark.sql.execution.UnionExec
@@ -59,6 +57,34 @@
   val preferNativeAggr: Boolean =
     SparkEnv.get.conf.getBoolean("spark.blaze.prefer.native.aggr", defaultValue = false)
 
+  // Per-strategy switches read from the Spark conf; each one defaults to true.
+  // NOTE(review): no hunk in this change reads these flags, and the last two share
+  // their names with the strategy methods called from apply -- confirm how they are
+  // meant to gate those calls.
+  val neverConvertSkewJoinEnabled: Boolean =
+    SparkEnv.get.conf
+      .getBoolean("spark.blaze.strategy.enable.neverConvertSkewJoin", defaultValue = true)
+  val neverConvertJoinsWithPostConditionEnabled: Boolean =
+    SparkEnv.get.conf.getBoolean(
+      "spark.blaze.strategy.enable.neverConvertJoinsWithPostCondition",
+      defaultValue = true)
+  val alwaysConvertDirectSortMergeJoinEnabled: Boolean =
+    SparkEnv.get.conf.getBoolean(
+      "spark.blaze.strategy.enable.alwaysConvertDirectSortMergeJoin",
+      defaultValue = true)
+  val neverConvertContinuousCodegensEnabled: Boolean =
+    SparkEnv.get.conf.getBoolean(
+      "spark.blaze.strategy.enable.neverConvertContinuousCodegens",
+      defaultValue = true)
+  val neverConvertScanWithInconvertibleChildren: Boolean =
+    SparkEnv.get.conf.getBoolean(
+      "spark.blaze.strategy.enable.neverConvertScanWithInconvertibleChildren",
+      defaultValue = true)
+  val neverConvertAggregatesChildren: Boolean =
+    SparkEnv.get.conf.getBoolean(
+      "spark.blaze.strategy.enable.neverConvertAggregatesChildren",
+      defaultValue = true)
+
   val idTag: TreeNodeTag[UUID] = TreeNodeTag("blaze.id")
   val convertibleTag: TreeNodeTag[Boolean] = TreeNodeTag("blaze.convertible")
   val convertStrategyTag: TreeNodeTag[ConvertStrategy] = TreeNodeTag("blaze.convert.strategy")
@@ -80,10 +106,11 @@
 
     // execute some special strategies
     neverConvertSkewJoin(exec)
+    neverConvertJoinsWithPostCondition(exec)
     alwaysConvertDirectSortMergeJoin(exec)
     neverConvertContinuousCodegens(exec)
     neverConvertScanWithInconvertibleChildren(exec)
-    neverConvertAggregate(exec)
+    neverConvertAggregatesChildren(exec)
 
     def hasMoreInconvertibleChildren(e: SparkPlan) =
       e.children.count(isNeverConvert) > e.children.count(isAlwaysConvert)
@@ -139,6 +166,24 @@
     }
   }
 
+  /**
+   * Marks every sort-merge join and broadcast hash join that carries a join
+   * condition as NeverConvert by setting convertStrategyTag on the node.
+   *
+   * @param exec root of the plan to walk
+   */
+  private def neverConvertJoinsWithPostCondition(exec: SparkPlan): Unit = {
+    exec.foreach {
+      case e: SortMergeJoinExec if e.condition.nonEmpty =>
+        e.setTagValue(convertStrategyTag, NeverConvert)
+      case e: BroadcastHashJoinExec if e.condition.nonEmpty =>
+        e.setTagValue(convertStrategyTag, NeverConvert)
+      // foreach applies this function to every node, so all other nodes need a
+      // no-op arm; without it the first unmatched node throws scala.MatchError
+      case _ =>
+    }
+  }
+
   private def alwaysConvertDirectSortMergeJoin(exec: SparkPlan): Unit = {
     exec.foreach {
       case e: SortMergeJoinExec if !e.isSkewJoin && !isNeverConvert(e) =>
@@ -195,8 +240,8 @@
       case _ =>
     }
   }
-  private def neverConvertAggregate(exec: SparkPlan): Unit = {
-    val aggrAheadFlag = TreeNodeTag[Boolean](name = "blaze.aggregate")
+  private def neverConvertAggregatesChildren(exec: SparkPlan): Unit = {
+    val aggrAheadFlag = TreeNodeTag[Boolean](name = "blaze.aggregates.children")
     def needPutFlag(exec: SparkPlan): Boolean = {
       if (isNeverConvert(exec) || isAlwaysConvert(exec)) {
         return false