| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.spark.sql.catalyst.optimizer |
| |
| import org.apache.spark.sql.catalyst.expressions._ |
| import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate |
| import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys |
| import org.apache.spark.sql.catalyst.plans.logical._ |
| import org.apache.spark.sql.catalyst.rules.Rule |
| import org.apache.spark.sql.catalyst.trees.TreePattern.{INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY, PYTHON_UDF, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, SCALA_UDF} |
| import org.apache.spark.sql.internal.SQLConf |
| import org.apache.spark.sql.types._ |
| |
| /** |
| * Insert a filter on one side of the join if the other side has a selective predicate. |
| * The filter could be an IN subquery (converted to a semi join), a bloom filter, or something |
| * else in the future. |
| */ |
| object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with JoinSelectionHelper { |
| |
/**
 * Returns `expr` itself when the default size of its data type does not exceed that of
 * `IntegerType`; otherwise returns `expr` wrapped in a `Murmur3Hash`.
 */
private def mayWrapWithHash(expr: Expression): Expression = {
  val fitsInInteger = expr.dataType.defaultSize <= IntegerType.defaultSize
  if (fitsInInteger) expr else new Murmur3Hash(Seq(expr))
}
| |
/**
 * Dispatches to the enabled runtime filter flavour. The Bloom filter wins when it is enabled;
 * the IN-subquery filter is used only when the Bloom filter is disabled.
 *
 * Fails with an `IllegalArgumentException` (via `require`) when neither flavour is enabled.
 *
 * @return the filter application side plan, possibly wrapped in a new `Filter`
 */
private def injectFilter(
    filterApplicationSideExp: Expression,
    filterApplicationSidePlan: LogicalPlan,
    filterCreationSideExp: Expression,
    filterCreationSidePlan: LogicalPlan): LogicalPlan = {
  require(conf.runtimeFilterBloomFilterEnabled || conf.runtimeFilterSemiJoinReductionEnabled)
  // Both injectors share the same signature, so pick one and call it once.
  val injector: (Expression, LogicalPlan, Expression, LogicalPlan) => LogicalPlan =
    if (conf.runtimeFilterBloomFilterEnabled) {
      injectBloomFilter _
    } else {
      injectInSubqueryFilter _
    }
  injector(
    filterApplicationSideExp,
    filterApplicationSidePlan,
    filterCreationSideExp,
    filterCreationSidePlan)
}
| |
/**
 * Puts a `BloomFilterMightContain` filter on top of the filter application side. The Bloom
 * filter is produced by a scalar subquery that aggregates the XxHash64 of
 * `filterCreationSideExp` over the filter creation side.
 *
 * The application side plan is returned unchanged when the estimated size of the creation side
 * is larger than `conf.runtimeFilterCreationSideThreshold`.
 */
private def injectBloomFilter(
    filterApplicationSideExp: Expression,
    filterApplicationSidePlan: LogicalPlan,
    filterCreationSideExp: Expression,
    filterCreationSidePlan: LogicalPlan): LogicalPlan = {
  val creationSideTooBig =
    filterCreationSidePlan.stats.sizeInBytes > conf.runtimeFilterCreationSideThreshold
  if (creationSideTooBig) {
    // Skip if the filter creation side is too big
    filterApplicationSidePlan
  } else {
    val creationSideHash = new XxHash64(Seq(filterCreationSideExp))
    // Hand the row count estimate to the aggregate only when a positive one is available.
    val positiveRowCount =
      filterCreationSidePlan.stats.rowCount.map(_.longValue).filter(_ > 0L)
    val bloomFilterAgg = positiveRowCount match {
      case Some(estimatedRows) =>
        new BloomFilterAggregate(creationSideHash, Literal(estimatedRows))
      case None =>
        new BloomFilterAggregate(creationSideHash)
    }

    val bloomFilterAlias = Alias(bloomFilterAgg.toAggregateExpression(), "bloomFilter")()
    val prunedAggregate =
      ColumnPruning(Aggregate(Nil, Seq(bloomFilterAlias), filterCreationSidePlan))
    val bloomFilterSubquery = ScalarSubquery(ConstantFolding(prunedAggregate), Nil)
    val applicationSideHash = new XxHash64(Seq(filterApplicationSideExp))
    Filter(
      BloomFilterMightContain(bloomFilterSubquery, applicationSideHash),
      filterApplicationSidePlan)
  }
}
| |
/**
 * Puts an `InSubquery` filter on top of the filter application side. The subquery groups the
 * filter creation side by its (possibly hashed) key.
 *
 * Both key expressions must have the same data type. The application side plan is returned
 * unchanged when the grouping aggregate cannot be broadcast by size.
 */
private def injectInSubqueryFilter(
    filterApplicationSideExp: Expression,
    filterApplicationSidePlan: LogicalPlan,
    filterCreationSideExp: Expression,
    filterCreationSidePlan: LogicalPlan): LogicalPlan = {
  require(filterApplicationSideExp.dataType == filterCreationSideExp.dataType)
  val creationSideKey = mayWrapWithHash(filterCreationSideExp)
  val keyAlias = Alias(creationSideKey, creationSideKey.toString)()
  val groupedKeys = Aggregate(Seq(keyAlias), Seq(keyAlias), filterCreationSidePlan)
  if (canBroadcastBySize(groupedKeys, conf)) {
    val applicationSideKey = mayWrapWithHash(filterApplicationSideExp)
    val membership = InSubquery(
      Seq(applicationSideKey),
      ListQuery(groupedKeys, childOutputs = groupedKeys.output))
    Filter(membership, filterApplicationSidePlan)
  } else {
    // Skip the InSubquery filter if the size of the aggregate is beyond the broadcast join
    // threshold, i.e., the semi-join would be a shuffled join, which is not worthwhile.
    filterApplicationSidePlan
  }
}
| |
/**
 * Returns whether the (non-streaming) plan is a chain of `Project`/`Filter` nodes over a leaf
 * node in which at least one filter condition is likely selective.
 *
 * Every filter condition, and every projected expression a filter refers to, must be a simple
 * expression (see `isSimpleExpression`) so that we do not add a subquery that might have an
 * expensive computation. Any other kind of node in the chain makes the result false.
 */
private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = {
  // `filterRefs` holds the attributes referenced by the filters seen so far on the way down.
  def walk(
      node: LogicalPlan,
      filterRefs: AttributeSet,
      seenFilter: Boolean,
      seenSelectiveFilter: Boolean): Boolean = node match {
    case Project(projectList, child) if seenFilter =>
      // We need to make sure all expressions referenced by filter predicates are simple
      // expressions.
      val referencedExprs = projectList.filter(filterRefs.contains)
      if (referencedExprs.forall(isSimpleExpression)) {
        val refsBelowProject =
          referencedExprs.map(_.references).foldLeft(AttributeSet.empty)(_ ++ _)
        walk(child, refsBelowProject, seenFilter, seenSelectiveFilter)
      } else {
        false
      }
    case Project(_, child) =>
      // No filter above this projection yet, so nothing can have been recorded.
      assert(filterRefs.isEmpty && !seenSelectiveFilter)
      walk(child, filterRefs, seenFilter, seenSelectiveFilter)
    case Filter(condition, child) =>
      if (isSimpleExpression(condition)) {
        walk(
          child,
          filterRefs ++ condition.references,
          seenFilter = true,
          seenSelectiveFilter = seenSelectiveFilter || isLikelySelective(condition))
      } else {
        false
      }
    case _: LeafNode => seenSelectiveFilter
    case _ => false
  }

  if (plan.isStreaming) {
    false
  } else {
    walk(plan, AttributeSet.empty, seenFilter = false, seenSelectiveFilter = false)
  }
}
| |
/**
 * An expression counts as simple when its tree contains none of the patterns listed below
 * (Python/Scala UDFs, `Invoke`, JSON-to-struct, and the LIKE / regexp families).
 */
private def isSimpleExpression(e: Expression): Boolean = {
  val containsExcludedPattern = e.containsAnyPattern(
    PYTHON_UDF,
    SCALA_UDF,
    INVOKE,
    JSON_TO_STRUCT,
    LIKE_FAMLIY,
    REGEXP_EXTRACT_FAMILY,
    REGEXP_REPLACE)
  !containsExcludedPattern
}
| |
/**
 * A join is assumed to be a shuffle join when no hint asks for either side to be broadcast
 * and neither side can be broadcast by size.
 */
private def isProbablyShuffleJoin(
    left: LogicalPlan,
    right: LogicalPlan,
    hint: JoinHint): Boolean = {
  val broadcastLikely =
    hintToBroadcastLeft(hint) || hintToBroadcastRight(hint) ||
      canBroadcastBySize(left, conf) || canBroadcastBySize(right, conf)
  !broadcastLikely
}
| |
/**
 * Returns true when the plan contains an `Aggregate`, a `Window`, or a join that
 * `isProbablyShuffleJoin` classifies as a shuffle join.
 */
private def probablyHasShuffle(plan: LogicalPlan): Boolean = {
  plan.exists {
    case _: Aggregate | _: Window => true
    case Join(joinLeft, joinRight, _, _, joinHint) =>
      isProbablyShuffleJoin(joinLeft, joinRight, joinHint)
    case _ => false
  }
}
| |
/**
 * Returns the largest estimated byte size among the leaf nodes in the subtree rooted at
 * `filterApplicationSide`.
 */
private def maxScanByteSize(filterApplicationSide: LogicalPlan): BigInt = {
  val unknownSizeMarker = conf.getConf(SQLConf.DEFAULT_SIZE_IN_BYTES)
  val leafSizes = filterApplicationSide.collect {
    case leaf: LeafNode =>
      val estimatedSize = leaf.stats.sizeInBytes
      // DEFAULT_SIZE_IN_BYTES means there's no byte size information in stats. We avoid
      // creating a Bloom filter when the filter application side is very small, so counting an
      // unknown size as 0 avoids a regression from applying a Bloom filter on a small table.
      if (estimatedSize == unknownSizeMarker) BigInt(0) else estimatedSize
  }
  leafSizes.max
}
| |
/**
 * Returns true if the largest scan under `filterApplicationSide` is at least
 * `SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD` bytes; false otherwise.
 *
 * The max is used rather than the sum: when the application side is a union of many small
 * tables, disseminating the Bloom filter to each small task might cost more than scanning the
 * tables themselves.
 */
private def satisfyByteSizeRequirement(filterApplicationSide: LogicalPlan): Boolean = {
  val scanSizeThreshold =
    conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD)
  maxScanByteSize(filterApplicationSide) >= scanSizeThreshold
}
| |
/**
 * Check that, in this order (later checks are skipped once one fails):
 * - The filterApplicationSideExp can be pushed down through joins, aggregates and windows
 *   (ie the expression references originate from a single leaf node)
 * - The filter creation side has a selective predicate
 * - The current join is a shuffle join or a broadcast join that has a shuffle below it
 * - The max filterApplicationSide scan size is greater than a configurable threshold
 */
private def filteringHasBenefit(
    filterApplicationSide: LogicalPlan,
    filterCreationSide: LogicalPlan,
    filterApplicationSideExp: Expression,
    hint: JoinHint): Boolean = {
  // Local defs rather than vals so that evaluation stays lazy and short-circuits.
  def keyTracesToLeaf: Boolean =
    findExpressionAndTrackLineageDown(filterApplicationSideExp, filterApplicationSide).isDefined
  def creationSideIsSelective: Boolean = isSelectiveFilterOverScan(filterCreationSide)
  def shuffleInvolved: Boolean =
    isProbablyShuffleJoin(filterApplicationSide, filterCreationSide, hint) ||
      probablyHasShuffle(filterApplicationSide)

  keyTracesToLeaf && creationSideIsSelective && shuffleInvolved &&
    satisfyByteSizeRequirement(filterApplicationSide)
}
| |
/**
 * Checks whether a runtime filter of the currently enabled flavour already exists on either
 * join key: a Bloom filter when Bloom filters are enabled, an IN subquery otherwise.
 */
def hasRuntimeFilter(
    left: LogicalPlan,
    right: LogicalPlan,
    leftKey: Expression,
    rightKey: Expression): Boolean = {
  val bloomFilterEnabled = conf.runtimeFilterBloomFilterEnabled
  (bloomFilterEnabled && hasBloomFilter(left, right, leftKey, rightKey)) ||
    (!bloomFilterEnabled && hasInSubquery(left, right, leftKey, rightKey))
}
| |
/**
 * This checks if there is already a DPP filter, as this rule is called just after DPP.
 *
 * Peels `Filter(DynamicPruningSubquery(...))` nodes off the top of the left plan first and
 * then off the right plan, returning true as soon as a pruning key equals the join key of the
 * side it was found on.
 */
def hasDynamicPruningSubquery(
    left: LogicalPlan,
    right: LogicalPlan,
    leftKey: Expression,
    rightKey: Expression): Boolean = {
  left match {
    case Filter(DynamicPruningSubquery(leftPruningKey, _, _, _, _, _), leftChild) =>
      leftPruningKey.fastEquals(leftKey) ||
        hasDynamicPruningSubquery(leftChild, right, leftKey, rightKey)
    case _ =>
      right match {
        case Filter(DynamicPruningSubquery(rightPruningKey, _, _, _, _, _), rightChild) =>
          rightPruningKey.fastEquals(rightKey) ||
            hasDynamicPruningSubquery(left, rightChild, leftKey, rightKey)
        case _ => false
      }
  }
}
| |
/**
 * Returns true when the left plan already has a Bloom filter probing `leftKey`, or the right
 * plan already has one probing `rightKey`. The left side is checked first.
 */
def hasBloomFilter(
    left: LogicalPlan,
    right: LogicalPlan,
    leftKey: Expression,
    rightKey: Expression): Boolean = {
  Seq(left -> leftKey, right -> rightKey).exists {
    case (sidePlan, sideKey) => findBloomFilterWithExp(sidePlan, sideKey)
  }
}
| |
/**
 * Returns true when any `Filter` in `plan` has a conjunct of the form
 * `BloomFilterMightContain(_, XxHash64(Seq(e), _))` where `e` equals `key`.
 */
private def findBloomFilterWithExp(plan: LogicalPlan, key: Expression): Boolean = {
  def probesKey(predicate: Expression): Boolean = predicate match {
    case BloomFilterMightContain(_, XxHash64(Seq(probedExpr), _)) => probedExpr.fastEquals(key)
    case _ => false
  }

  plan.exists {
    case Filter(condition, _) => splitConjunctivePredicates(condition).exists(probesKey)
    case _ => false
  }
}
| |
/**
 * Returns true when the top node of `left` is an IN-subquery runtime filter on `leftKey`, or
 * the top node of `right` is one on `rightKey`. The filter's key may be the join key itself or
 * its `Murmur3Hash`, mirroring what `injectInSubqueryFilter` produces.
 *
 * Each side is checked independently. Matching on the `(left, right)` pair would let an
 * IN-subquery filter on a different key at the top of `left` end the search with `false`
 * before `right` was ever looked at.
 *
 * @param left     left child of the join
 * @param right    right child of the join
 * @param leftKey  join key expression on the left side
 * @param rightKey join key expression on the right side
 */
def hasInSubquery(
    left: LogicalPlan,
    right: LogicalPlan,
    leftKey: Expression,
    rightKey: Expression): Boolean = {
  // Only the top node of the plan is inspected.
  def hasFilterOnKey(plan: LogicalPlan, joinKey: Expression): Boolean = plan match {
    case Filter(InSubquery(Seq(key),
        ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _) =>
      key.fastEquals(joinKey) || key.fastEquals(new Murmur3Hash(Seq(joinKey)))
    case _ => false
  }

  hasFilterOnKey(left, leftKey) || hasFilterOnKey(right, rightKey)
}
| |
/**
 * Visits every equi-join bottom-up and, for each pair of join keys, tries to inject a runtime
 * filter on the left side first and, if nothing was injected there, on the right side.
 *
 * The number of injected filters is counted across the whole plan and capped by
 * `SQLConf.RUNTIME_FILTER_NUMBER_THRESHOLD`.
 */
private def tryInjectRuntimeFilter(plan: LogicalPlan): LogicalPlan = {
  val maxFilters = conf.getConf(SQLConf.RUNTIME_FILTER_NUMBER_THRESHOLD)
  var injectedFilters = 0
  plan.transformUp {
    case join @ ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, left, right, hint) =>
      var filteredLeft = left
      var filteredRight = right
      leftKeys.zip(rightKeys).foreach { case (l, r) =>
        // A key pair is eligible only if:
        // 1. The filter budget is not used up yet
        // 2. There is no DPP filter on the key already
        // 3. There is no runtime filter (Bloom filter or IN subquery) on the key already
        // 4. The keys are simple cheap expressions
        val eligible = injectedFilters < maxFilters &&
          !hasDynamicPruningSubquery(left, right, l, r) &&
          !hasRuntimeFilter(filteredLeft, filteredRight, l, r) &&
          isSimpleExpression(l) && isSimpleExpression(r)
        if (eligible) {
          val leftBefore = filteredLeft
          val rightBefore = filteredRight
          // The benefit check looks at the original children; the filter is stacked on top of
          // whatever has been injected so far.
          if (canPruneLeft(joinType) && filteringHasBenefit(left, right, l, hint)) {
            filteredLeft = injectFilter(l, filteredLeft, r, right)
          }
          // Did we actually inject on the left? If not, try on the right
          val leftUntouched = filteredLeft.fastEquals(leftBefore)
          if (leftUntouched && canPruneRight(joinType) &&
              filteringHasBenefit(right, left, r, hint)) {
            filteredRight = injectFilter(r, filteredRight, l, left)
          }
          if (!leftUntouched || !filteredRight.fastEquals(rightBefore)) {
            injectedFilters += 1
          }
        }
      }
      join.withNewChildren(Seq(filteredLeft, filteredRight))
  }
}
| |
/**
 * Entry point of the rule. The plan is returned untouched when it is a correlated `Subquery`
 * or when neither runtime filter flavour is enabled. Otherwise runtime filters are injected
 * and, if semi-join reduction is enabled and the plan changed, `RewritePredicateSubquery` is
 * applied to the result.
 */
override def apply(plan: LogicalPlan): LogicalPlan = {
  val semiJoinReductionEnabled = conf.runtimeFilterSemiJoinReductionEnabled
  val anyFlavourEnabled = semiJoinReductionEnabled || conf.runtimeFilterBloomFilterEnabled
  val isCorrelatedSubquery = plan match {
    case s: Subquery => s.correlated
    case _ => false
  }

  if (isCorrelatedSubquery || !anyFlavourEnabled) {
    plan
  } else {
    val injected = tryInjectRuntimeFilter(plan)
    if (semiJoinReductionEnabled && !plan.fastEquals(injected)) {
      RewritePredicateSubquery(injected)
    } else {
      injected
    }
  }
}
| |
| } |