| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.flink.table.plan.util |
| |
| import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty |
| import org.apache.flink.table.plan.nodes.calcite.{Expand, LogicalRank, LogicalWindowAggregate, Rank} |
| import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalRank, FlinkLogicalWindowAggregate} |
| import org.apache.flink.table.plan.nodes.physical.batch._ |
| import org.apache.flink.table.plan.nodes.physical.stream.StreamExecRank |
| import org.apache.flink.table.plan.util.FlinkRelOptUtil.{checkAndGetFullGroupSet, checkAndSplitAggCalls} |
| |
| import org.apache.calcite.avatica.util.TimeUnitRange._ |
| import org.apache.calcite.plan.RelOptUtil |
| import org.apache.calcite.rel.core._ |
| import org.apache.calcite.rel.metadata.{RelMdUtil, RelMetadataQuery} |
| import org.apache.calcite.rel.{RelNode, SingleRel} |
| import org.apache.calcite.rex._ |
| import org.apache.calcite.sql.SqlKind |
| import org.apache.calcite.sql.`type`.SqlTypeName._ |
| import org.apache.calcite.util.{ImmutableBitSet, NumberUtil} |
| import com.google.common.collect.ImmutableList |
| |
| import java.lang.Double |
| import java.math.BigDecimal |
| import java.util |
| |
| import scala.collection.JavaConversions._ |
| import scala.collection.mutable |
| |
| /** |
| * FlinkRelMdUtil provides utility methods used by the metadata provider methods. |
| */ |
| object FlinkRelMdUtil { |
| |
| /** |
| * Creates a RexNode that stores a selectivity value corresponding to the |
| * selectivity of a semi-join/anti-join. This can be added to a filter to simulate the |
| * effect of the semi-join/anti-join during costing, but should never appear in a real |
| * plan since it has no physical implementation. |
| * |
| * @param mq instance of metadata query |
| * @param rel the semiJoin or antiJoin of interest |
| * @return constructed rexNode |
| */ |
| def makeSemiJoinSelectivityRexNode( |
| mq: RelMetadataQuery, |
| rel: SemiJoin): RexNode = { |
| val joinInfo = rel.analyzeCondition() |
| val rexBuilder = rel.getCluster.getRexBuilder |
| makeSemiJoinSelectivityRexNode(mq, joinInfo, rel.getLeft, rel.getRight, rel.isAnti, rexBuilder) |
| } |
| |
| private def makeSemiJoinSelectivityRexNode( |
| mq: RelMetadataQuery, |
| joinInfo: JoinInfo, |
| left: RelNode, |
| right: RelNode, |
| isAnti: Boolean, |
| rexBuilder: RexBuilder): RexNode = { |
| |
| val equiSelectivity: Double = if (!joinInfo.leftKeys.isEmpty) { |
| RelMdUtil.computeSemiJoinSelectivity(mq, left, right, joinInfo.leftKeys, joinInfo.rightKeys) |
| } else { |
| 1D |
| } |
| |
| val nonEquiSelectivity = RelMdUtil.guessSelectivity(joinInfo.getRemaining(rexBuilder)) |
| val semiJoinSelectivity = equiSelectivity * nonEquiSelectivity |
| |
| val selectivity = if (isAnti) { |
| val antiJoinSelectivity = 1.0 - semiJoinSelectivity |
| if (antiJoinSelectivity == 0.0) { |
| // we don't expect that anti-join's selectivity is 0.0, so choose a default value 0.1 |
| 0.1 |
| } else { |
| antiJoinSelectivity |
| } |
| } else { |
| semiJoinSelectivity |
| } |
| |
| rexBuilder.makeCall( |
| RelMdUtil.ARTIFICIAL_SELECTIVITY_FUNC, |
| rexBuilder.makeApproxLiteral(new BigDecimal(selectivity))) |
| } |
| |
| /** |
| * Creates a RexNode that stores a selectivity value corresponding to the |
| * selectivity of a NamedProperties predicate. |
| * |
| * @param winAgg window aggregate node |
| * @param predicate a RexNode |
| * @return constructed rexNode including non-NamedProperties predicates and |
| * a predicate that stores NamedProperties predicate's selectivity |
| */ |
| def makeNamePropertiesSelectivityRexNode( |
| winAgg: LogicalWindowAggregate, |
| predicate: RexNode): RexNode = { |
| val fullGroupSet = checkAndGetFullGroupSet(winAgg) |
| makeNamePropertiesSelectivityRexNode(winAgg, fullGroupSet, winAgg.getNamedProperties, predicate) |
| } |
| |
| /** |
| * Creates a RexNode that stores a selectivity value corresponding to the |
| * selectivity of a NamedProperties predicate. |
| * |
| * @param winAgg window aggregate node |
| * @param predicate a RexNode |
| * @return constructed rexNode including non-NamedProperties predicates and |
| * a predicate that stores NamedProperties predicate's selectivity |
| */ |
| def makeNamePropertiesSelectivityRexNode( |
| winAgg: FlinkLogicalWindowAggregate, |
| predicate: RexNode): RexNode = { |
| val fullGroupSet = checkAndGetFullGroupSet(winAgg) |
| makeNamePropertiesSelectivityRexNode(winAgg, fullGroupSet, winAgg.getNamedProperties, predicate) |
| } |
| |
| /** |
| * Creates a RexNode that stores a selectivity value corresponding to the |
| * selectivity of a NamedProperties predicate. |
| * |
| * @param globalWinAgg global window aggregate node |
| * @param predicate a RexNode |
| * @return constructed rexNode including non-NamedProperties predicates and |
| * a predicate that stores NamedProperties predicate's selectivity |
| */ |
| def makeNamePropertiesSelectivityRexNode( |
| globalWinAgg: BatchExecWindowAggregateBase, |
| predicate: RexNode): RexNode = { |
| require(globalWinAgg.isFinal, "local window agg does not contain NamedProperties!") |
| val fullGrouping = globalWinAgg.getGrouping ++ globalWinAgg.getAuxGrouping |
| makeNamePropertiesSelectivityRexNode( |
| globalWinAgg, fullGrouping, globalWinAgg.getNamedProperties, predicate) |
| } |
| |
| /** |
| * Creates a RexNode that stores a selectivity value corresponding to the |
| * selectivity of a NamedProperties predicate. |
| * |
| * @param winAgg window aggregate node |
| * @param fullGrouping full groupSets |
| * @param namedProperties NamedWindowProperty list |
| * @param predicate a RexNode |
| * @return constructed rexNode including non-NamedProperties predicates and |
| * a predicate that stores NamedProperties predicate's selectivity |
| */ |
| def makeNamePropertiesSelectivityRexNode( |
| winAgg: SingleRel, |
| fullGrouping: Array[Int], |
| namedProperties: Seq[NamedWindowProperty], |
| predicate: RexNode): RexNode = { |
| if (predicate == null || predicate.isAlwaysTrue || namedProperties.isEmpty) { |
| return predicate |
| } |
| val rexBuilder = winAgg.getCluster.getRexBuilder |
| val namePropertiesStartIdx = winAgg.getRowType.getFieldCount - namedProperties.size |
| // split non-nameProperties predicates and nameProperties predicates |
| val pushable = new util.ArrayList[RexNode] |
| val notPushable = new util.ArrayList[RexNode] |
| RelOptUtil.splitFilters( |
| ImmutableBitSet.range(0, namePropertiesStartIdx), |
| predicate, |
| pushable, |
| notPushable) |
| if (notPushable.nonEmpty) { |
| val pred = RexUtil.composeConjunction(rexBuilder, notPushable, true) |
| val selectivity = RelMdUtil.guessSelectivity(pred) |
| val fun = rexBuilder.makeCall( |
| RelMdUtil.ARTIFICIAL_SELECTIVITY_FUNC, |
| rexBuilder.makeApproxLiteral(new BigDecimal(selectivity))) |
| pushable.add(fun) |
| } |
| RexUtil.composeConjunction(rexBuilder, pushable, true) |
| } |
| |
| /** |
| * Computes the cardinality of a particular expression from the projection |
| * list. |
| * |
| * @param mq metadata query instance |
| * @param calc calc RelNode |
| * @param expr projection expression |
| * @return cardinality |
| */ |
| def cardOfCalcExpr(mq: RelMetadataQuery, calc: Calc, expr: RexNode): Double = { |
| expr.accept(new CardOfCalcExpr(mq, calc)) |
| } |
| |
| /** |
| * Visitor that walks over a scalar expression and computes the |
| * cardinality of its result. |
| * The code is borrowed from RelMdUtil |
| * |
| * @param mq metadata query instance |
| * @param calc calc relnode |
| */ |
| private class CardOfCalcExpr( |
| mq: RelMetadataQuery, |
| calc: Calc) |
| extends RexVisitorImpl[Double](true) { |
| private val program = calc.getProgram |
| |
| private val condition = if (program.getCondition != null) { |
| program.expandLocalRef(program.getCondition) |
| } else { |
| null |
| } |
| |
| override def visitInputRef(inputRef: RexInputRef): Double = { |
| val col = ImmutableBitSet.of(inputRef.getIndex) |
| val distinctRowCount = mq.getDistinctRowCount(calc.getInput, col, condition) |
| if (distinctRowCount == null) { |
| null |
| } else { |
| RelMdUtil.numDistinctVals(distinctRowCount, mq.getAverageRowSize(calc)) |
| } |
| } |
| |
| override def visitLiteral(literal: RexLiteral): Double = { |
| RelMdUtil.numDistinctVals(1D, mq.getAverageRowSize(calc)) |
| } |
| |
| override def visitCall(call: RexCall): Double = { |
| val rowCount = mq.getRowCount(calc) |
| val distinctRowCount: Double = if (call.isA(SqlKind.MINUS_PREFIX)) { |
| cardOfCalcExpr(mq, calc, call.getOperands.get(0)) |
| } else if (call.isA(ImmutableList.of(SqlKind.PLUS, SqlKind.MINUS))) { |
| val card0 = cardOfCalcExpr(mq, calc, call.getOperands.get(0)) |
| if (card0 == null) { |
| null |
| } else { |
| val card1 = cardOfCalcExpr(mq, calc, call.getOperands.get(1)) |
| if (card1 == null) { |
| null |
| } else { |
| Math.max(card0, card1) |
| } |
| } |
| } else if (call.isA(ImmutableList.of(SqlKind.TIMES, SqlKind.DIVIDE))) { |
| NumberUtil.multiply( |
| cardOfCalcExpr(mq, calc, call.getOperands.get(0)), |
| cardOfCalcExpr(mq, calc, call.getOperands.get(1))) |
| } else if (call.isA(SqlKind.EXTRACT)) { |
| val extractUnit = call.getOperands.get(0) |
| val timeOperand = call.getOperands.get(1) |
| extractUnit match { |
| // go https://www.postgresql.org/docs/9.1/static/ |
| // functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT to get the definitions of timeunits |
| case unit: RexLiteral => |
| val unitValue = unit.getValue |
| val timeOperandType = timeOperand.getType.getSqlTypeName |
| // assume min time is 1970-01-01 00:00:00, max time is 2100-12-31 21:59:59 |
| unitValue match { |
| case YEAR => 130D // [1970, 2100] |
| case MONTH => 12D |
| case DAY => 31D |
| case HOUR => 24D |
| case MINUTE => 60D |
| case SECOND => timeOperandType match { |
| case TIMESTAMP | TIME => 60 * 1000D // [0.000, 59.999] |
| case _ => 60D // [0, 59] |
| } |
| case QUARTER => 4D |
| case WEEK => 53D // [1, 53] |
| case MILLISECOND => timeOperandType match { |
| case TIMESTAMP | TIME => 60 * 1000D // [0.000, 59.999] |
| case _ => 60D // [0, 59] |
| } |
| case MICROSECOND => timeOperandType match { |
| case TIMESTAMP | TIME => 60 * 1000D * 1000D // [0.000, 59.999] |
| case _ => 60D // [0, 59] |
| } |
| case DOW => 7D // [0, 6] |
| case DOY => 366D // [1, 366] |
| case EPOCH => timeOperandType match { |
| // the number of seconds since 1970-01-01 00:00:00 UTC |
| case TIMESTAMP | TIME => 130 * 24 * 60 * 60 * 1000D |
| case _ => 130 * 24 * 60 * 60D |
| } |
| case DECADE => 13D // The year field divided by 10 |
| case CENTURY => 2D |
| case MILLENNIUM => 2D |
| case _ => cardOfCalcExpr(mq, calc, timeOperand) |
| } |
| case _ => cardOfCalcExpr(mq, calc, timeOperand) |
| } |
| } else if (call.getOperands.size() == 1) { |
| cardOfCalcExpr(mq, calc, call.getOperands.get(0)) |
| } else { |
| if (rowCount != null) rowCount / 10 else null |
| } |
| if (distinctRowCount == null) { |
| null |
| } else { |
| RelMdUtil.numDistinctVals(distinctRowCount, rowCount) |
| } |
| } |
| |
| } |
| |
| /** |
| * Takes a bitmap representing a set of input references and extracts the |
| * ones that reference the group by columns in an aggregate. |
| * |
| * |
| * @param groupKey the original bitmap |
| * @param aggRel the aggregate |
| */ |
| def setAggChildKeys( |
| groupKey: ImmutableBitSet, |
| aggRel: Aggregate): (ImmutableBitSet, Array[AggregateCall]) = { |
| val childKeyBuilder = ImmutableBitSet.builder |
| val aggCalls = new mutable.ArrayBuffer[AggregateCall]() |
| val groupSet = aggRel.getGroupSet.toArray |
| val (auxGroupSet, otherAggCalls) = checkAndSplitAggCalls(aggRel) |
| val fullGroupSet = groupSet ++ auxGroupSet |
| // does not need to take keys in aggregate call into consideration if groupKey contains all |
| // groupSet element in aggregate |
| val containsAllAggGroupKeys = fullGroupSet.indices.forall(groupKey.get) |
| groupKey.foreach( |
| bit => |
| if (bit < fullGroupSet.length) { |
| childKeyBuilder.set(fullGroupSet(bit)) |
| } else if (!containsAllAggGroupKeys) { |
| // getIndicatorCount return 0 if auxGroupSet is not empty |
| val agg = otherAggCalls.get(bit - (fullGroupSet.length + aggRel.getIndicatorCount)) |
| aggCalls += agg |
| } |
| ) |
| (childKeyBuilder.build(), aggCalls.toArray) |
| } |
| |
| /** |
| * Takes a bitmap representing a set of input references and extracts the |
| * ones that reference the group by columns in an aggregate. |
| * |
| * @param groupKey the original bitmap |
| * @param aggRel the aggregate |
| */ |
| def setAggChildKeys( |
| groupKey: ImmutableBitSet, |
| aggRel: BatchExecGroupAggregateBase): (ImmutableBitSet, Array[AggregateCall]) = { |
| require(!aggRel.isFinal || !aggRel.isMerge, "Cannot handle global agg which has local agg!") |
| setChildKeysOfAgg(groupKey, aggRel) |
| } |
| |
| /** |
| * Takes a bitmap representing a set of input references and extracts the |
| * ones that reference the group by columns in an aggregate. |
| * |
| * @param groupKey the original bitmap |
| * @param aggRel the aggregate |
| */ |
| def setAggChildKeys( |
| groupKey: ImmutableBitSet, |
| aggRel: BatchExecWindowAggregateBase): (ImmutableBitSet, Array[AggregateCall]) = { |
| require(!aggRel.isFinal || !aggRel.isMerge, "Cannot handle global agg which has local agg!") |
| setChildKeysOfAgg(groupKey, aggRel) |
| } |
| |
| private def setChildKeysOfAgg( |
| groupKey: ImmutableBitSet, |
| aggRel: SingleRel): (ImmutableBitSet, Array[AggregateCall]) = { |
| val (aggCalls, fullGroupSet) = aggRel match { |
| case agg: BatchExecLocalSortWindowAggregate => |
| // grouping + assignTs + auxGrouping |
| (agg.getAggCallList, |
| agg.getGrouping ++ Array(agg.inputTimestampIndex) ++ agg.getAuxGrouping) |
| case agg: BatchExecLocalHashWindowAggregate => |
| // grouping + assignTs + auxGrouping |
| (agg.getAggCallList, |
| agg.getGrouping ++ Array(agg.inputTimestampIndex) ++ agg.getAuxGrouping) |
| case agg: BatchExecWindowAggregateBase => |
| (agg.getAggCallList, agg.getGrouping ++ agg.getAuxGrouping) |
| case agg: BatchExecGroupAggregateBase => |
| (agg.getAggCallList, agg.getGrouping ++ agg.getAuxGrouping) |
| case _ => throw new IllegalArgumentException(s"Unknown relnode type ${aggRel.getRelTypeName}") |
| } |
| // does not need to take keys in aggregate call into consideration if groupKey contains all |
| // groupSet element in aggregate |
| val containsAllAggGroupKeys = fullGroupSet.indices.forall(groupKey.get) |
| val childKeyBuilder = ImmutableBitSet.builder |
| val aggs = new mutable.ArrayBuffer[AggregateCall]() |
| groupKey.foreach( |
| bit => |
| if (bit < fullGroupSet.length) { |
| childKeyBuilder.set(fullGroupSet(bit)) |
| } else if (!containsAllAggGroupKeys) { |
| val agg = aggCalls.get(bit - fullGroupSet.length) |
| aggs += agg |
| } |
| ) |
| (childKeyBuilder.build(), aggs.toArray) |
| } |
| |
| /** |
| * Takes a bitmap representing a set of local window aggregate references. |
| * |
| * global win-agg output type: groupSet + auxGroupSet + aggCall + namedProperties |
| * local win-agg output type: groupSet + assignTs + auxGroupSet + aggCalls |
| * |
| * Skips `assignTs` when mapping `groupKey` to `childKey`. |
| * |
| * @param groupKey the original bitmap |
| * @param globalWinAgg the global window aggregate |
| */ |
| def setChildKeysOfWinAgg( |
| groupKey: ImmutableBitSet, |
| globalWinAgg: BatchExecWindowAggregateBase): ImmutableBitSet = { |
| require(globalWinAgg.isMerge, "Cannot handle global agg which does not have local agg!") |
| val childKeyBuilder = ImmutableBitSet.builder |
| groupKey.toArray.foreach { key => |
| if (key < globalWinAgg.getGrouping.length) { |
| childKeyBuilder.set(key) |
| } else { |
| // skips `assignTs` |
| childKeyBuilder.set(key + 1) |
| } |
| } |
| childKeyBuilder.build() |
| } |
| |
| /** |
| * Split groupKeys on Agregate/ BatchExecGroupAggregateBase/ BatchExecWindowAggregateBase |
| * into keys on aggregate's groupKey and aggregate's aggregateCalls. |
| * |
| * @param agg the aggregate |
| * @param groupKey the original bitmap |
| */ |
| private[flink] def splitGroupKeysOnAggregate( |
| agg: SingleRel, |
| groupKey: ImmutableBitSet): (ImmutableBitSet, Array[AggregateCall]) = { |
| |
| def removeAuxKey( |
| groupKey: ImmutableBitSet, |
| groupSet: Array[Int], |
| auxGroupSet: Array[Int]): ImmutableBitSet = { |
| if (groupKey.contains(ImmutableBitSet.of(groupSet: _*))) { |
| // remove auxGroupSet from groupKey if groupKey contain both full-groupSet |
| // and (partial-)auxGroupSet |
| groupKey.except(ImmutableBitSet.of(auxGroupSet: _*)) |
| } else { |
| groupKey |
| } |
| } |
| |
| agg match { |
| case rel: Aggregate => |
| val (auxGroupSet, _) = FlinkRelOptUtil.checkAndSplitAggCalls(rel) |
| val (childKeys, aggCalls) = setAggChildKeys(groupKey, rel) |
| val childKeyExcludeAuxKey = removeAuxKey(childKeys, rel.getGroupSet.toArray, auxGroupSet) |
| (childKeyExcludeAuxKey, aggCalls) |
| case rel: BatchExecGroupAggregateBase => |
| // set the bits as they correspond to the child input |
| val (childKeys, aggCalls) = setAggChildKeys(groupKey, rel) |
| val childKeyExcludeAuxKey = removeAuxKey(childKeys, rel.getGrouping, rel.getAuxGrouping) |
| (childKeyExcludeAuxKey, aggCalls) |
| case rel: BatchExecWindowAggregateBase => |
| val (childKeys, aggCalls) = setAggChildKeys(groupKey, rel) |
| val childKeyExcludeAuxKey = removeAuxKey(childKeys, rel.getGrouping, rel.getAuxGrouping) |
| (childKeyExcludeAuxKey, aggCalls) |
| case _ => throw new IllegalArgumentException( |
| s"Unknown aggregate type: ${agg.getRelTypeName}.") |
| } |
| } |
| |
| /** |
| * Shifts every [[RexInputRef]] in an expression higher than length of full grouping |
| * (for skips `assignTs`). |
| * |
| * global win-agg output type: groupSet + auxGroupSet + aggCall + namedProperties |
| * local win-agg output type: groupSet + assignTs + auxGroupSet + aggCalls |
| * |
| * @param predicate a RexNode |
| * @param globalWinAgg the global window aggregate |
| */ |
| def setChildPredicateOfWinAgg( |
| predicate: RexNode, |
| globalWinAgg: BatchExecWindowAggregateBase): RexNode = { |
| require(globalWinAgg.isMerge, "Cannot handle global agg which does not have local agg!") |
| if (predicate == null) { |
| return null |
| } |
| // grouping + assignTs + auxGrouping |
| val fullGrouping = globalWinAgg.getGrouping ++ globalWinAgg.getAuxGrouping |
| // skips `assignTs` |
| RexUtil.shift(predicate, fullGrouping.length, 1) |
| } |
| |
| /** |
| * Split a predicate on Aggregate into two parts, the first one is pushable part, |
| * the second one is rest part. |
| * |
| * @param agg Aggregate which to analyze |
| * @param predicate Predicate which to analyze |
| * @return a tuple, first element is pushable part, second element is rest part. |
| * Note, pushable condition will be converted based on the input field position. |
| */ |
| def splitPredicateOnAggregate( |
| agg: Aggregate, |
| predicate: RexNode): (Option[RexNode], Option[RexNode]) = { |
| val fullGroupSet = checkAndGetFullGroupSet(agg) |
| splitPredicateOnAgg(fullGroupSet, agg, predicate) |
| } |
| |
| /** |
| * Split a predicate on BatchExecGroupAggregateBase into two parts, |
| * the first one is pushable part, the second one is rest part. |
| * |
| * @param agg Aggregate which to analyze |
| * @param predicate Predicate which to analyze |
| * @return a tuple, first element is pushable part, second element is rest part. |
| * Note, pushable condition will be converted based on the input field position. |
| */ |
| def splitPredicateOnAggregate( |
| agg: BatchExecGroupAggregateBase, |
| predicate: RexNode): (Option[RexNode], Option[RexNode]) = { |
| splitPredicateOnAgg(agg.getGrouping ++ agg.getAuxGrouping, agg, predicate) |
| } |
| |
| /** |
| * Split a predicate on WindowAggregateBatchExecBase into two parts, |
| * the first one is pushable part, the second one is rest part. |
| * |
| * @param agg Aggregate which to analyze |
| * @param predicate Predicate which to analyze |
| * @return a tuple, first element is pushable part, second element is rest part. |
| * Note, pushable condition will be converted based on the input field position. |
| */ |
| def splitPredicateOnAggregate( |
| agg: BatchExecWindowAggregateBase, |
| predicate: RexNode): (Option[RexNode], Option[RexNode]) = { |
| splitPredicateOnAgg(agg.getGrouping ++ agg.getAuxGrouping, agg, predicate) |
| } |
| |
| private def splitPredicateOnAgg( |
| grouping: Array[Int], |
| agg: SingleRel, |
| predicate: RexNode): (Option[RexNode], Option[RexNode]) = { |
| val notPushable = new util.ArrayList[RexNode] |
| val pushable = new util.ArrayList[RexNode] |
| val numOfGroupKey = grouping.length |
| RelOptUtil.splitFilters( |
| ImmutableBitSet.range(0, numOfGroupKey), |
| predicate, |
| pushable, |
| notPushable) |
| val rexBuilder = agg.getCluster.getRexBuilder |
| val childPred = if (pushable.isEmpty) { |
| None |
| } else { |
| // Converts a list of expressions that are based on the output fields of a |
| // Aggregate to equivalent expressions on the Aggregate's input fields. |
| val aggOutputFields = agg.getRowType.getFieldList |
| val aggInputFields = agg.getInput.getRowType.getFieldList |
| val adjustments = new Array[Int](aggOutputFields.size) |
| grouping.zipWithIndex foreach { |
| case (bit, index) => adjustments(index) = bit - index |
| } |
| val pushableConditions = pushable map { |
| pushCondition => |
| pushCondition.accept( |
| new RelOptUtil.RexInputConverter( |
| rexBuilder, |
| aggOutputFields, |
| aggInputFields, |
| adjustments)) |
| } |
| Option(RexUtil.composeConjunction(rexBuilder, pushableConditions, true)) |
| } |
| val restPred = if (notPushable.isEmpty) { |
| None |
| } else { |
| Option(RexUtil.composeConjunction(rexBuilder, notPushable, true)) |
| } |
| (childPred, restPred) |
| } |
| |
| def getRankFunColumnIndex(rank: Rank): Int = { |
| rank match { |
| case r: LogicalRank => getRankFunColumnIndex(rank, outputRankFunColumn = true) |
| case r: FlinkLogicalRank => getRankFunColumnIndex(rank, r.outputRankFunColumn) |
| case r: BatchExecRank => getRankFunColumnIndex(rank, r.outputRankFunColumn) |
| case r: StreamExecRank => getRankFunColumnIndex(rank, r.outputRankFunColumn) |
| } |
| } |
| |
| private def getRankFunColumnIndex(rank: Rank, outputRankFunColumn: Boolean): Int = { |
| if (outputRankFunColumn) { |
| require(rank.getRowType.getFieldCount == rank.getInput.getRowType.getFieldCount + 1) |
| rank.getRowType.getFieldCount - 1 |
| } else { |
| require(rank.getRowType.getFieldCount == rank.getInput.getRowType.getFieldCount) |
| -1 |
| } |
| } |
| |
| def splitPredicateOnRank( |
| rank: Rank, |
| predicate: RexNode): (Option[RexNode], Option[RexNode]) = { |
| val rankFunColumnIndex = getRankFunColumnIndex(rank) |
| if (predicate == null || predicate.isAlwaysTrue || rankFunColumnIndex < 0) { |
| return (Some(predicate), None) |
| } |
| |
| val rankNodes = new util.ArrayList[RexNode] |
| val nonRankNodes = new util.ArrayList[RexNode] |
| RelOptUtil.splitFilters( |
| ImmutableBitSet.range(0, rankFunColumnIndex), |
| predicate, |
| nonRankNodes, |
| rankNodes) |
| val rexBuilder = rank.getCluster.getRexBuilder |
| val nonRankPred = if (nonRankNodes.isEmpty) { |
| None |
| } else { |
| Option(RexUtil.composeConjunction(rexBuilder, nonRankNodes, true)) |
| } |
| val rankPred = if (rankNodes.isEmpty) { |
| None |
| } else { |
| Option(RexUtil.composeConjunction(rexBuilder, rankNodes, true)) |
| } |
| (nonRankPred, rankPred) |
| } |
| |
| def getRankRangeNdv(rankRange: RankRange): Double = rankRange match { |
| case r: ConstantRankRange => (r.rankEnd - r.rankStart + 1).toDouble |
| case _ => 100D // default value now |
| } |
| |
| /** Splits a column set between left and right sets. */ |
| def splitColumnsIntoLeftAndRight( |
| leftCount: Int, |
| columns: ImmutableBitSet): (ImmutableBitSet, ImmutableBitSet) = { |
| val leftBuilder = ImmutableBitSet.builder |
| val rightBuilder = ImmutableBitSet.builder |
| columns.foreach { |
| bit => if (bit < leftCount) leftBuilder.set(bit) else rightBuilder.set(bit - leftCount) |
| } |
| (leftBuilder.build, rightBuilder.build) |
| } |
| |
| |
| /** |
| * Estimates ratio outputRowCount/ inputRowCount of agg when ndv of groupKeys is unavailable. |
| * |
| * the value of `1.0 - math.exp(-0.1 * groupCount)` increases with groupCount |
| * from 0.095 until close to 1.0. when groupCount is 1, the formula result is 0.095, |
| * when groupCount is 2, the formula result is 0.18, |
| * when groupCount is 3, the formula result is 0.25. |
| * ... |
| * |
| * @param groupingLength grouping keys length of aggregate |
| * @return the ratio outputRowCount/ inputRowCount of agg when ndv of groupKeys is unavailable. |
| */ |
| def getAggregationRatioIfNdvUnavailable(groupingLength: Int): Double = |
| 1.0 - math.exp(-0.1 * groupingLength) |
| |
| /** |
| * Estimates outputRowCount of local aggregate. |
| * |
| * output rowcount of local agg is (1 - pow((1 - 1/x) , n/m)) * m * x, based on two assumption: |
| * 1. even distribution of all distinct data |
| * 2. even distribution of all data in each concurrent local agg worker |
| * |
| * @param parallelism number of concurrent worker of local aggregate |
| * @param inputRowCount rowcount of input node of aggregate. |
| * @param globalAggRowCount rowcount of output of global aggregate. |
| * @return outputRowCount of local aggregate. |
| */ |
| def getRowCountOfLocalAgg( |
| parallelism: Int, |
| inputRowCount: Double, |
| globalAggRowCount: Double): Double = |
| Math.min((1 - math.pow(1 - 1.0 / parallelism, inputRowCount / globalAggRowCount)) |
| * globalAggRowCount * parallelism, inputRowCount) |
| |
| /** |
| * Estimates new distinctRowCount of currentNode after it applies a condition. |
| * The estimation based on one assumption: |
| * even distribution of all distinct data |
| * |
| * @param rowCount rowcount of node. |
| * @param distinctRowCount distinct rowcount of node. |
| * @param selectivity selectivity of condition expression. |
| * @return new distinctRowCount |
| */ |
| def adaptNdvBasedOnSelectivity( |
| rowCount: Double, |
| distinctRowCount: Double, |
| selectivity: Double): Double = { |
| val ndv = Math.min(distinctRowCount, rowCount) |
| (1 - Math.pow(1 - selectivity, rowCount / ndv)) * ndv |
| } |
| |
| /** |
| * Returns [[RexInputRef]] index set of projects corresponding to the given column index. |
| * The index will be set as -1 if the given column in project is not a [[RexInputRef]]. |
| */ |
| def getInputRefIndices(index: Int, expand: Expand): util.Set[Int] = { |
| val inputRefs = new util.HashSet[Int]() |
| for (project <- expand.projects) { |
| project.get(index) match { |
| case inputRef: RexInputRef => inputRefs.add(inputRef.getIndex) |
| case _ => inputRefs.add(-1) |
| } |
| } |
| inputRefs |
| } |
| |
| } |