blob: 28b7f8d83930d758890b82effa553313e4b5a655 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan.util
import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.plan.nodes.calcite.{Expand, LogicalRank, LogicalWindowAggregate, Rank}
import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalRank, FlinkLogicalWindowAggregate}
import org.apache.flink.table.plan.nodes.physical.batch._
import org.apache.flink.table.plan.nodes.physical.stream.StreamExecRank
import org.apache.flink.table.plan.util.FlinkRelOptUtil.{checkAndGetFullGroupSet, checkAndSplitAggCalls}
import org.apache.calcite.avatica.util.TimeUnitRange._
import org.apache.calcite.plan.RelOptUtil
import org.apache.calcite.rel.core._
import org.apache.calcite.rel.metadata.{RelMdUtil, RelMetadataQuery}
import org.apache.calcite.rel.{RelNode, SingleRel}
import org.apache.calcite.rex._
import org.apache.calcite.sql.SqlKind
import org.apache.calcite.sql.`type`.SqlTypeName._
import org.apache.calcite.util.{ImmutableBitSet, NumberUtil}
import com.google.common.collect.ImmutableList
import java.lang.Double
import java.math.BigDecimal
import java.util
import scala.collection.JavaConversions._
import scala.collection.mutable
/**
* FlinkRelMdUtil provides utility methods used by the metadata provider methods.
*/
object FlinkRelMdUtil {
/**
 * Builds an artificial selectivity expression for a semi-join or anti-join.
 *
 * The resulting call can be added to a filter to simulate the effect of the
 * semi-join/anti-join during costing. It must never show up in a real plan,
 * because it has no physical implementation.
 *
 * @param mq  instance of metadata query
 * @param rel the semi-join or anti-join of interest
 * @return the constructed selectivity RexNode
 */
def makeSemiJoinSelectivityRexNode(
    mq: RelMetadataQuery,
    rel: SemiJoin): RexNode =
  makeSemiJoinSelectivityRexNode(
    mq,
    rel.analyzeCondition(),
    rel.getLeft,
    rel.getRight,
    rel.isAnti,
    rel.getCluster.getRexBuilder)
/**
 * Estimates the selectivity of a semi-join/anti-join from its join info and wraps
 * the estimate in a call to the artificial selectivity function.
 *
 * The semi-join selectivity is the product of the equi-key selectivity (1.0 when
 * there are no equi keys) and the guessed selectivity of the remaining condition.
 * For an anti-join the complement is used, with 0.1 substituted when the
 * complement is exactly 0.0.
 */
private def makeSemiJoinSelectivityRexNode(
    mq: RelMetadataQuery,
    joinInfo: JoinInfo,
    left: RelNode,
    right: RelNode,
    isAnti: Boolean,
    rexBuilder: RexBuilder): RexNode = {
  val keySelectivity: Double = if (joinInfo.leftKeys.isEmpty) {
    1D
  } else {
    RelMdUtil.computeSemiJoinSelectivity(mq, left, right, joinInfo.leftKeys, joinInfo.rightKeys)
  }
  val remainingSelectivity = RelMdUtil.guessSelectivity(joinInfo.getRemaining(rexBuilder))
  val semiSelectivity = keySelectivity * remainingSelectivity
  val antiSelectivity = 1.0 - semiSelectivity
  val estimate =
    if (!isAnti) {
      semiSelectivity
    } else if (antiSelectivity == 0.0) {
      // we don't expect that anti-join's selectivity is 0.0, so choose a default value 0.1
      0.1
    } else {
      antiSelectivity
    }
  val selectivityLiteral = rexBuilder.makeApproxLiteral(new BigDecimal(estimate))
  rexBuilder.makeCall(RelMdUtil.ARTIFICIAL_SELECTIVITY_FUNC, selectivityLiteral)
}
/**
 * Rewrites a predicate on a logical window aggregate so that the conditions on
 * NamedProperties are replaced by a single artificial selectivity expression.
 *
 * @param winAgg    window aggregate node
 * @param predicate a RexNode
 * @return the non-NamedProperties predicates combined with a predicate that
 *         stores the selectivity of the NamedProperties predicates
 */
def makeNamePropertiesSelectivityRexNode(
    winAgg: LogicalWindowAggregate,
    predicate: RexNode): RexNode =
  makeNamePropertiesSelectivityRexNode(
    winAgg,
    checkAndGetFullGroupSet(winAgg),
    winAgg.getNamedProperties,
    predicate)
/**
 * Rewrites a predicate on a Flink logical window aggregate so that the conditions
 * on NamedProperties are replaced by a single artificial selectivity expression.
 *
 * @param winAgg    window aggregate node
 * @param predicate a RexNode
 * @return the non-NamedProperties predicates combined with a predicate that
 *         stores the selectivity of the NamedProperties predicates
 */
def makeNamePropertiesSelectivityRexNode(
    winAgg: FlinkLogicalWindowAggregate,
    predicate: RexNode): RexNode =
  makeNamePropertiesSelectivityRexNode(
    winAgg,
    checkAndGetFullGroupSet(winAgg),
    winAgg.getNamedProperties,
    predicate)
/**
 * Rewrites a predicate on a batch window aggregate so that the conditions on
 * NamedProperties are replaced by a single artificial selectivity expression.
 *
 * Fails the `require` check unless `globalWinAgg.isFinal` holds.
 *
 * @param globalWinAgg global window aggregate node
 * @param predicate    a RexNode
 * @return the non-NamedProperties predicates combined with a predicate that
 *         stores the selectivity of the NamedProperties predicates
 */
def makeNamePropertiesSelectivityRexNode(
    globalWinAgg: BatchExecWindowAggregateBase,
    predicate: RexNode): RexNode = {
  require(globalWinAgg.isFinal, "local window agg does not contain NamedProperties!")
  makeNamePropertiesSelectivityRexNode(
    globalWinAgg,
    globalWinAgg.getGrouping ++ globalWinAgg.getAuxGrouping,
    globalWinAgg.getNamedProperties,
    predicate)
}
/**
 * Replaces the conditions of `predicate` that touch NamedProperties fields with a
 * single artificial selectivity expression.
 *
 * The predicate is returned unchanged when it is null, always true, or when there
 * are no named properties. Otherwise it is split by the field range
 * `[0, fieldCount - namedProperties.size)`: conditions inside that range are kept
 * as they are, the remaining conditions are collapsed into one call carrying
 * their guessed selectivity.
 *
 * @param winAgg          window aggregate node
 * @param fullGrouping    full groupSets (not read by this implementation)
 * @param namedProperties NamedWindowProperty list
 * @param predicate       a RexNode
 * @return the non-NamedProperties predicates combined with a predicate that
 *         stores the selectivity of the NamedProperties predicates
 */
def makeNamePropertiesSelectivityRexNode(
    winAgg: SingleRel,
    fullGrouping: Array[Int],
    namedProperties: Seq[NamedWindowProperty],
    predicate: RexNode): RexNode = {
  val nothingToRewrite = predicate == null || predicate.isAlwaysTrue || namedProperties.isEmpty
  if (nothingToRewrite) {
    predicate
  } else {
    val rexBuilder = winAgg.getCluster.getRexBuilder
    val firstPropertyIdx = winAgg.getRowType.getFieldCount - namedProperties.size
    // conditions on fields before firstPropertyIdx vs. conditions touching named properties
    val onOtherFields = new util.ArrayList[RexNode]
    val onProperties = new util.ArrayList[RexNode]
    RelOptUtil.splitFilters(
      ImmutableBitSet.range(0, firstPropertyIdx),
      predicate,
      onOtherFields,
      onProperties)
    if (!onProperties.isEmpty) {
      val propertyPredicate = RexUtil.composeConjunction(rexBuilder, onProperties, true)
      val guessed = RelMdUtil.guessSelectivity(propertyPredicate)
      val selectivityCall = rexBuilder.makeCall(
        RelMdUtil.ARTIFICIAL_SELECTIVITY_FUNC,
        rexBuilder.makeApproxLiteral(new BigDecimal(guessed)))
      onOtherFields.add(selectivityCall)
    }
    RexUtil.composeConjunction(rexBuilder, onOtherFields, true)
  }
}
/**
 * Computes the cardinality of one expression from the projection list of a calc
 * by running a `CardOfCalcExpr` visitor over it.
 *
 * @param mq   metadata query instance
 * @param calc calc RelNode
 * @param expr projection expression
 * @return cardinality
 */
def cardOfCalcExpr(mq: RelMetadataQuery, calc: Calc, expr: RexNode): Double = {
  val cardinalityVisitor = new CardOfCalcExpr(mq, calc)
  expr.accept(cardinalityVisitor)
}
/**
 * Visitor that walks over a scalar expression of a calc and computes the
 * cardinality (number of distinct values) of its result. A `null` result means
 * that no estimate is available.
 * The code is borrowed from RelMdUtil.
 *
 * @param mq   metadata query instance
 * @param calc calc relnode
 */
private class CardOfCalcExpr(
    mq: RelMetadataQuery,
    calc: Calc)
  extends RexVisitorImpl[Double](true) {

  private val program = calc.getProgram

  // the calc's condition with local references expanded, or null when there is none
  private val condition = if (program.getCondition != null) {
    program.expandLocalRef(program.getCondition)
  } else {
    null
  }

  override def visitInputRef(inputRef: RexInputRef): Double = {
    val col = ImmutableBitSet.of(inputRef.getIndex)
    val distinctRowCount = mq.getDistinctRowCount(calc.getInput, col, condition)
    if (distinctRowCount == null) {
      null
    } else {
      // numDistinctVals takes (domain size, number of selected rows): the second argument
      // has to be the calc's row count, as in visitCall, not its average row size in bytes
      RelMdUtil.numDistinctVals(distinctRowCount, mq.getRowCount(calc))
    }
  }

  override def visitLiteral(literal: RexLiteral): Double = {
    // a literal has exactly one distinct value
    RelMdUtil.numDistinctVals(1D, mq.getRowCount(calc))
  }

  override def visitCall(call: RexCall): Double = {
    val rowCount = mq.getRowCount(calc)
    val distinctRowCount: Double = if (call.isA(SqlKind.MINUS_PREFIX)) {
      cardOfCalcExpr(mq, calc, call.getOperands.get(0))
    } else if (call.isA(ImmutableList.of(SqlKind.PLUS, SqlKind.MINUS))) {
      val card0 = cardOfCalcExpr(mq, calc, call.getOperands.get(0))
      if (card0 == null) {
        null
      } else {
        val card1 = cardOfCalcExpr(mq, calc, call.getOperands.get(1))
        if (card1 == null) {
          null
        } else {
          Math.max(card0, card1)
        }
      }
    } else if (call.isA(ImmutableList.of(SqlKind.TIMES, SqlKind.DIVIDE))) {
      NumberUtil.multiply(
        cardOfCalcExpr(mq, calc, call.getOperands.get(0)),
        cardOfCalcExpr(mq, calc, call.getOperands.get(1)))
    } else if (call.isA(SqlKind.EXTRACT)) {
      val extractUnit = call.getOperands.get(0)
      val timeOperand = call.getOperands.get(1)
      extractUnit match {
        // go https://www.postgresql.org/docs/9.1/static/
        // functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT to get the definitions of timeunits
        case unit: RexLiteral =>
          val unitValue = unit.getValue
          val timeOperandType = timeOperand.getType.getSqlTypeName
          // assume the time values lie between 1970-01-01 and 2100-12-31
          unitValue match {
            case YEAR => 130D // [1970, 2100]
            case MONTH => 12D
            case DAY => 31D
            case HOUR => 24D
            case MINUTE => 60D
            case SECOND => timeOperandType match {
              case TIMESTAMP | TIME => 60 * 1000D // [0.000, 59.999]
              case _ => 60D // [0, 59]
            }
            case QUARTER => 4D
            case WEEK => 53D // [1, 53]
            case MILLISECOND => timeOperandType match {
              case TIMESTAMP | TIME => 60 * 1000D // seconds field in milliseconds
              case _ => 60D // [0, 59]
            }
            case MICROSECOND => timeOperandType match {
              case TIMESTAMP | TIME => 60 * 1000D * 1000D // seconds field in microseconds
              case _ => 60D // [0, 59]
            }
            case DOW => 7D // [0, 6]
            case DOY => 366D // [1, 366]
            case EPOCH => timeOperandType match {
              // the number of seconds since 1970-01-01 00:00:00 UTC, over about 130 years.
              // The product starts with a Double literal on purpose: 130 * 365 * 24 * 60 * 60
              // does not fit into an Int.
              case TIMESTAMP | TIME => 130D * 365 * 24 * 60 * 60 * 1000
              case _ => 130D * 365 * 24 * 60 * 60
            }
            case DECADE => 13D // The year field divided by 10
            case CENTURY => 2D
            case MILLENNIUM => 2D
            case _ => cardOfCalcExpr(mq, calc, timeOperand)
          }
        case _ => cardOfCalcExpr(mq, calc, timeOperand)
      }
    } else if (call.getOperands.size() == 1) {
      cardOfCalcExpr(mq, calc, call.getOperands.get(0))
    } else {
      // no better estimate for other calls: assume a tenth of the rows are distinct
      if (rowCount != null) rowCount / 10 else null
    }
    if (distinctRowCount == null) {
      null
    } else {
      RelMdUtil.numDistinctVals(distinctRowCount, rowCount)
    }
  }
}
/**
 * Maps a bitmap over the output fields of an aggregate onto the aggregate's input.
 *
 * Bits that fall inside the full group set (groupSet followed by auxGroupSet) are
 * translated to the corresponding input column. Bits beyond it select aggregate
 * calls, which are only collected when `groupKey` does not already contain every
 * full-group-set position.
 *
 * @param groupKey the original bitmap
 * @param aggRel   the aggregate
 * @return the child keys and the aggregate calls referenced by `groupKey`
 */
def setAggChildKeys(
    groupKey: ImmutableBitSet,
    aggRel: Aggregate): (ImmutableBitSet, Array[AggregateCall]) = {
  val groupSet = aggRel.getGroupSet.toArray
  val (auxGroupSet, otherAggCalls) = checkAndSplitAggCalls(aggRel)
  val fullGroupSet = groupSet ++ auxGroupSet
  val groupCount = fullGroupSet.length
  // does not need to take keys in aggregate call into consideration if groupKey contains all
  // groupSet element in aggregate
  val coversAllGroupKeys = fullGroupSet.indices.forall(groupKey.get)
  val childKeys = ImmutableBitSet.builder
  val referencedAggCalls = new mutable.ArrayBuffer[AggregateCall]()
  for (bit <- groupKey) {
    if (bit < groupCount) {
      childKeys.set(fullGroupSet(bit))
    } else if (!coversAllGroupKeys) {
      // getIndicatorCount return 0 if auxGroupSet is not empty
      val aggCallIdx = bit - (groupCount + aggRel.getIndicatorCount)
      referencedAggCalls += otherAggCalls.get(aggCallIdx)
    }
  }
  (childKeys.build(), referencedAggCalls.toArray)
}
/**
 * Maps a bitmap over the output fields of a batch group aggregate onto the
 * aggregate's input.
 *
 * Fails the `require` check when the aggregate is both final and merge
 * (per the error message: a global aggregate which has a local aggregate).
 *
 * @param groupKey the original bitmap
 * @param aggRel   the aggregate
 * @return the child keys and the aggregate calls referenced by `groupKey`
 */
def setAggChildKeys(
    groupKey: ImmutableBitSet,
    aggRel: BatchExecGroupAggregateBase): (ImmutableBitSet, Array[AggregateCall]) = {
  val isGlobalOverLocal = aggRel.isFinal && aggRel.isMerge
  require(!isGlobalOverLocal, "Cannot handle global agg which has local agg!")
  setChildKeysOfAgg(groupKey, aggRel)
}
/**
 * Maps a bitmap over the output fields of a batch window aggregate onto the
 * aggregate's input.
 *
 * Fails the `require` check when the aggregate is both final and merge
 * (per the error message: a global aggregate which has a local aggregate).
 *
 * @param groupKey the original bitmap
 * @param aggRel   the aggregate
 * @return the child keys and the aggregate calls referenced by `groupKey`
 */
def setAggChildKeys(
    groupKey: ImmutableBitSet,
    aggRel: BatchExecWindowAggregateBase): (ImmutableBitSet, Array[AggregateCall]) = {
  val isGlobalOverLocal = aggRel.isFinal && aggRel.isMerge
  require(!isGlobalOverLocal, "Cannot handle global agg which has local agg!")
  setChildKeysOfAgg(groupKey, aggRel)
}
/**
 * Shared implementation of `setAggChildKeys` for the batch physical aggregates.
 *
 * Determines the aggregate calls and the full group set of the given node, then
 * translates group-set bits of `groupKey` to input columns and collects the
 * aggregate calls selected by the remaining bits.
 *
 * @param groupKey the original bitmap
 * @param aggRel   a batch group aggregate or batch window aggregate
 * @return the child keys and the aggregate calls referenced by `groupKey`
 * @throws IllegalArgumentException if `aggRel` is none of the supported node types
 */
private def setChildKeysOfAgg(
    groupKey: ImmutableBitSet,
    aggRel: SingleRel): (ImmutableBitSet, Array[AggregateCall]) = {
  // NOTE(review): the local window aggregate cases are listed first; presumably they would
  // otherwise be caught by the BatchExecWindowAggregateBase case - confirm before reordering
  val (aggCalls, fullGroupSet) = aggRel match {
    case localSort: BatchExecLocalSortWindowAggregate =>
      // grouping + assignTs + auxGrouping
      val keys =
        localSort.getGrouping ++ Array(localSort.inputTimestampIndex) ++ localSort.getAuxGrouping
      (localSort.getAggCallList, keys)
    case localHash: BatchExecLocalHashWindowAggregate =>
      // grouping + assignTs + auxGrouping
      val keys =
        localHash.getGrouping ++ Array(localHash.inputTimestampIndex) ++ localHash.getAuxGrouping
      (localHash.getAggCallList, keys)
    case winAgg: BatchExecWindowAggregateBase =>
      (winAgg.getAggCallList, winAgg.getGrouping ++ winAgg.getAuxGrouping)
    case groupAgg: BatchExecGroupAggregateBase =>
      (groupAgg.getAggCallList, groupAgg.getGrouping ++ groupAgg.getAuxGrouping)
    case _ => throw new IllegalArgumentException(s"Unknown relnode type ${aggRel.getRelTypeName}")
  }
  val groupCount = fullGroupSet.length
  // does not need to take keys in aggregate call into consideration if groupKey contains all
  // groupSet element in aggregate
  val coversAllGroupKeys = fullGroupSet.indices.forall(groupKey.get)
  val childKeys = ImmutableBitSet.builder
  val referencedAggCalls = new mutable.ArrayBuffer[AggregateCall]()
  for (bit <- groupKey) {
    if (bit < groupCount) {
      childKeys.set(fullGroupSet(bit))
    } else if (!coversAllGroupKeys) {
      referencedAggCalls += aggCalls.get(bit - groupCount)
    }
  }
  (childKeys.build(), referencedAggCalls.toArray)
}
/**
 * Maps a bitmap over the output of a global window aggregate onto the output of
 * its local window aggregate.
 *
 * global win-agg output type: groupSet + auxGroupSet + aggCall + namedProperties
 * local win-agg output type: groupSet + assignTs + auxGroupSet + aggCalls
 *
 * Keys below the grouping length stay where they are, every other key moves one
 * position to the right in order to skip `assignTs`.
 *
 * @param groupKey     the original bitmap
 * @param globalWinAgg the global window aggregate; must satisfy `isMerge`
 * @return the bitmap expressed on the local window aggregate's output
 */
def setChildKeysOfWinAgg(
    groupKey: ImmutableBitSet,
    globalWinAgg: BatchExecWindowAggregateBase): ImmutableBitSet = {
  require(globalWinAgg.isMerge, "Cannot handle global agg which does not have local agg!")
  val groupingCount = globalWinAgg.getGrouping.length
  val childKeys = ImmutableBitSet.builder
  for (key <- groupKey.toArray) {
    // skips `assignTs` for everything located after the grouping columns
    val childKey = if (key < groupingCount) key else key + 1
    childKeys.set(childKey)
  }
  childKeys.build()
}
/**
 * Splits group keys on an Aggregate / BatchExecGroupAggregateBase /
 * BatchExecWindowAggregateBase into keys on the aggregate's input and the
 * aggregate calls they reference.
 *
 * When the resulting child keys contain the complete group set, the aux group
 * columns are removed from them.
 *
 * @param agg      the aggregate
 * @param groupKey the original bitmap
 * @return the child keys (without redundant aux keys) and the referenced aggregate calls
 * @throws IllegalArgumentException if `agg` is none of the supported aggregate types
 */
private[flink] def splitGroupKeysOnAggregate(
    agg: SingleRel,
    groupKey: ImmutableBitSet): (ImmutableBitSet, Array[AggregateCall]) = {

  // drops auxGroupSet from keys if keys contain both the full groupSet
  // and (part of) the auxGroupSet
  def dropAuxKeys(
      keys: ImmutableBitSet,
      groupSet: Array[Int],
      auxGroupSet: Array[Int]): ImmutableBitSet = {
    val coversGroupSet = keys.contains(ImmutableBitSet.of(groupSet: _*))
    if (coversGroupSet) keys.except(ImmutableBitSet.of(auxGroupSet: _*)) else keys
  }

  agg match {
    case logicalAgg: Aggregate =>
      val (auxGroupSet, _) = FlinkRelOptUtil.checkAndSplitAggCalls(logicalAgg)
      val (childKeys, aggCalls) = setAggChildKeys(groupKey, logicalAgg)
      (dropAuxKeys(childKeys, logicalAgg.getGroupSet.toArray, auxGroupSet), aggCalls)
    case groupAgg: BatchExecGroupAggregateBase =>
      // set the bits as they correspond to the child input
      val (childKeys, aggCalls) = setAggChildKeys(groupKey, groupAgg)
      (dropAuxKeys(childKeys, groupAgg.getGrouping, groupAgg.getAuxGrouping), aggCalls)
    case winAgg: BatchExecWindowAggregateBase =>
      val (childKeys, aggCalls) = setAggChildKeys(groupKey, winAgg)
      (dropAuxKeys(childKeys, winAgg.getGrouping, winAgg.getAuxGrouping), aggCalls)
    case _ => throw new IllegalArgumentException(
      s"Unknown aggregate type: ${agg.getRelTypeName}.")
  }
}
/**
 * Rewrites a predicate on the output of a global window aggregate so that it
 * refers to the output of its local window aggregate.
 *
 * global win-agg output type: groupSet + auxGroupSet + aggCall + namedProperties
 * local win-agg output type: groupSet + assignTs + auxGroupSet + aggCalls
 *
 * Every [[RexInputRef]] whose index is not lower than the length of the grouping
 * is shifted by one in order to skip `assignTs`. This is the same index mapping
 * that `setChildKeysOfWinAgg` applies to bitmaps.
 *
 * @param predicate    a RexNode, may be null
 * @param globalWinAgg the global window aggregate; must satisfy `isMerge`
 * @return the shifted predicate, or null if `predicate` is null
 */
def setChildPredicateOfWinAgg(
    predicate: RexNode,
    globalWinAgg: BatchExecWindowAggregateBase): RexNode = {
  require(globalWinAgg.isMerge, "Cannot handle global agg which does not have local agg!")
  if (predicate == null) {
    null
  } else {
    // `assignTs` sits directly behind the grouping columns (and in front of auxGrouping),
    // so the shift has to start at the grouping length; starting at the length of
    // grouping + auxGrouping would leave auxGrouping references pointing one column too far left
    RexUtil.shift(predicate, globalWinAgg.getGrouping.length, 1)
  }
}
/**
 * Splits a predicate on an Aggregate into a pushable part and a rest part.
 *
 * @param agg       Aggregate to analyze
 * @param predicate Predicate to analyze
 * @return a tuple, first element is the pushable part, second element is the rest part.
 *         Note, the pushable condition is converted to the input field positions.
 */
def splitPredicateOnAggregate(
    agg: Aggregate,
    predicate: RexNode): (Option[RexNode], Option[RexNode]) =
  splitPredicateOnAgg(checkAndGetFullGroupSet(agg), agg, predicate)
/**
 * Splits a predicate on a BatchExecGroupAggregateBase into a pushable part and a
 * rest part. Grouping and aux grouping columns together form the group keys.
 *
 * @param agg       Aggregate to analyze
 * @param predicate Predicate to analyze
 * @return a tuple, first element is the pushable part, second element is the rest part.
 *         Note, the pushable condition is converted to the input field positions.
 */
def splitPredicateOnAggregate(
    agg: BatchExecGroupAggregateBase,
    predicate: RexNode): (Option[RexNode], Option[RexNode]) = {
  val fullGrouping = agg.getGrouping ++ agg.getAuxGrouping
  splitPredicateOnAgg(fullGrouping, agg, predicate)
}
/**
 * Splits a predicate on a BatchExecWindowAggregateBase into a pushable part and a
 * rest part. Grouping and aux grouping columns together form the group keys.
 *
 * @param agg       Aggregate to analyze
 * @param predicate Predicate to analyze
 * @return a tuple, first element is the pushable part, second element is the rest part.
 *         Note, the pushable condition is converted to the input field positions.
 */
def splitPredicateOnAggregate(
    agg: BatchExecWindowAggregateBase,
    predicate: RexNode): (Option[RexNode], Option[RexNode]) = {
  val fullGrouping = agg.getGrouping ++ agg.getAuxGrouping
  splitPredicateOnAgg(fullGrouping, agg, predicate)
}
/**
 * Shared implementation of the `splitPredicateOnAggregate` overloads.
 *
 * Conditions that only reference the first `grouping.length` output fields are
 * pushable; they are rewritten from output field positions to the input field
 * positions given by `grouping`. All other conditions form the rest part.
 *
 * @param grouping  input column index of each group key, in output order
 * @param agg       the aggregate node
 * @param predicate the predicate to split
 * @return (pushable part on input fields, rest part); `None` where a part is empty
 */
private def splitPredicateOnAgg(
    grouping: Array[Int],
    agg: SingleRel,
    predicate: RexNode): (Option[RexNode], Option[RexNode]) = {
  val onOtherFields = new util.ArrayList[RexNode]
  val onGroupKeys = new util.ArrayList[RexNode]
  RelOptUtil.splitFilters(
    ImmutableBitSet.range(0, grouping.length),
    predicate,
    onGroupKeys,
    onOtherFields)
  val rexBuilder = agg.getCluster.getRexBuilder

  val childPred = if (onGroupKeys.isEmpty) {
    None
  } else {
    // Converts a list of expressions that are based on the output fields of a
    // Aggregate to equivalent expressions on the Aggregate's input fields.
    val outputFields = agg.getRowType.getFieldList
    val inputFields = agg.getInput.getRowType.getFieldList
    // offset to add to an output field index to reach its input field index
    val adjustments = new Array[Int](outputFields.size)
    for (index <- grouping.indices) {
      adjustments(index) = grouping(index) - index
    }
    val onInputFields = onGroupKeys.map { condition =>
      val converter =
        new RelOptUtil.RexInputConverter(rexBuilder, outputFields, inputFields, adjustments)
      condition.accept(converter)
    }
    Option(RexUtil.composeConjunction(rexBuilder, onInputFields, true))
  }

  val restPred =
    if (onOtherFields.isEmpty) None
    else Option(RexUtil.composeConjunction(rexBuilder, onOtherFields, true))

  (childPred, restPred)
}
/**
 * Returns the index of the rank function column in the output of a rank node,
 * or -1 if the node does not output that column.
 *
 * A LogicalRank is treated as always outputting the rank function column; the other
 * supported node types report it through `outputRankFunColumn`.
 *
 * @param rank the rank node
 * @return index of the rank function column, or -1
 * @throws IllegalArgumentException if `rank` is none of the supported rank node types
 */
def getRankFunColumnIndex(rank: Rank): Int = {
  rank match {
    case _: LogicalRank => getRankFunColumnIndex(rank, outputRankFunColumn = true)
    case r: FlinkLogicalRank => getRankFunColumnIndex(rank, r.outputRankFunColumn)
    case r: BatchExecRank => getRankFunColumnIndex(rank, r.outputRankFunColumn)
    case r: StreamExecRank => getRankFunColumnIndex(rank, r.outputRankFunColumn)
    // fail with a descriptive message instead of an opaque scala.MatchError
    case _ => throw new IllegalArgumentException(
      s"Unknown rank type: ${rank.getRelTypeName}.")
  }
}
/**
 * Returns the index of the rank function column, which is the last output field
 * when `outputRankFunColumn` is set, and -1 otherwise.
 *
 * Also checks that the output has exactly one field more than the input when the
 * rank function column is output, and the same number of fields otherwise.
 */
private def getRankFunColumnIndex(rank: Rank, outputRankFunColumn: Boolean): Int = {
  val outputFieldCount = rank.getRowType.getFieldCount
  val inputFieldCount = rank.getInput.getRowType.getFieldCount
  if (!outputRankFunColumn) {
    require(outputFieldCount == inputFieldCount)
    -1
  } else {
    require(outputFieldCount == inputFieldCount + 1)
    outputFieldCount - 1
  }
}
/**
 * Splits a predicate on a rank node into the part that does not reference the
 * rank function column and the part that does.
 *
 * When the predicate is null or always true, or when the rank node does not
 * output the rank function column, the whole predicate is returned as the first
 * element. A null predicate is returned as `None`, never as `Some(null)`.
 *
 * @param rank      the rank node
 * @param predicate the predicate to split, may be null
 * @return (conditions only on columns before the rank function column, remaining conditions);
 *         `None` where a part is empty
 */
def splitPredicateOnRank(
    rank: Rank,
    predicate: RexNode): (Option[RexNode], Option[RexNode]) = {
  val rankFunColumnIndex = getRankFunColumnIndex(rank)
  if (predicate == null || predicate.isAlwaysTrue || rankFunColumnIndex < 0) {
    // Option(...) maps a null predicate to None instead of wrapping it as Some(null)
    (Option(predicate), None)
  } else {
    val rankNodes = new util.ArrayList[RexNode]
    val nonRankNodes = new util.ArrayList[RexNode]
    RelOptUtil.splitFilters(
      ImmutableBitSet.range(0, rankFunColumnIndex),
      predicate,
      nonRankNodes,
      rankNodes)
    val rexBuilder = rank.getCluster.getRexBuilder
    val nonRankPred = if (nonRankNodes.isEmpty) {
      None
    } else {
      Option(RexUtil.composeConjunction(rexBuilder, nonRankNodes, true))
    }
    val rankPred = if (rankNodes.isEmpty) {
      None
    } else {
      Option(RexUtil.composeConjunction(rexBuilder, rankNodes, true))
    }
    (nonRankPred, rankPred)
  }
}
/**
 * Returns the number of distinct rank values of a rank range: the size of the
 * range for a constant range, and a default of 100 for any other range.
 */
def getRankRangeNdv(rankRange: RankRange): Double = rankRange match {
  case constant: ConstantRankRange =>
    val rankCount = constant.rankEnd - constant.rankStart + 1
    rankCount.toDouble
  case _ =>
    // default value now
    100D
}
/**
 * Splits a column set between left and right sets.
 *
 * Columns below `leftCount` go to the left set unchanged; all other columns go to
 * the right set with `leftCount` subtracted from their index.
 */
def splitColumnsIntoLeftAndRight(
    leftCount: Int,
    columns: ImmutableBitSet): (ImmutableBitSet, ImmutableBitSet) = {
  val leftColumns = ImmutableBitSet.builder
  val rightColumns = ImmutableBitSet.builder
  for (bit <- columns) {
    if (bit >= leftCount) {
      rightColumns.set(bit - leftCount)
    } else {
      leftColumns.set(bit)
    }
  }
  (leftColumns.build, rightColumns.build)
}
/**
 * Estimates the ratio outputRowCount / inputRowCount of an aggregate when the ndv
 * of its group keys is unavailable.
 *
 * The ratio is `1.0 - math.exp(-0.1 * groupingLength)`, which grows with the
 * number of grouping keys and approaches 1.0:
 * about 0.095 for 1 key, 0.18 for 2 keys, 0.26 for 3 keys, ...
 *
 * @param groupingLength grouping keys length of aggregate
 * @return the estimated ratio outputRowCount / inputRowCount
 */
def getAggregationRatioIfNdvUnavailable(groupingLength: Int): Double = {
  val decay = math.exp(-0.1 * groupingLength)
  1.0 - decay
}
/**
 * Estimates the output row count of a local aggregate.
 *
 * With x = parallelism, n = inputRowCount and m = globalAggRowCount the estimate is
 * (1 - pow(1 - 1/x, n/m)) * m * x, capped at n. It is based on two assumptions:
 * 1. even distribution of all distinct data
 * 2. even distribution of all data in each concurrent local agg worker
 *
 * @param parallelism       number of concurrent workers of the local aggregate
 * @param inputRowCount     row count of the input node of the aggregate
 * @param globalAggRowCount row count of the output of the global aggregate
 * @return output row count of the local aggregate
 */
def getRowCountOfLocalAgg(
    parallelism: Int,
    inputRowCount: Double,
    globalAggRowCount: Double): Double = {
  val rowsPerGroup = inputRowCount / globalAggRowCount
  // chance that a given group does not occur on a given worker
  val missProbability = math.pow(1 - 1.0 / parallelism, rowsPerGroup)
  val estimate = (1 - missProbability) * globalAggRowCount * parallelism
  Math.min(estimate, inputRowCount)
}
/**
 * Estimates the new distinctRowCount of a node after it applies a condition.
 * The estimation is based on one assumption:
 * even distribution of all distinct data
 *
 * The distinct row count is first capped at the row count. If the capped value is
 * not positive, 0 is returned directly: there are no distinct values to keep, and
 * the formula below would otherwise evaluate to NaN (for example 0 / 0).
 *
 * @param rowCount         rowcount of node.
 * @param distinctRowCount distinct rowcount of node.
 * @param selectivity      selectivity of condition expression.
 * @return new distinctRowCount
 */
def adaptNdvBasedOnSelectivity(
    rowCount: Double,
    distinctRowCount: Double,
    selectivity: Double): Double = {
  val ndv = Math.min(distinctRowCount, rowCount)
  if (ndv <= 0D) {
    0D
  } else {
    // rowCount / ndv rows share each distinct value; a value survives unless all of them
    // are filtered out
    (1 - Math.pow(1 - selectivity, rowCount / ndv)) * ndv
  }
}
/**
 * Returns the set of [[RexInputRef]] indices found at the given column position
 * across all projects of the expand node.
 * A project whose expression at that position is not a [[RexInputRef]]
 * contributes -1 to the set.
 */
def getInputRefIndices(index: Int, expand: Expand): util.Set[Int] = {
  val refIndices = new util.HashSet[Int]()
  expand.projects.foreach { project =>
    val refIndex = project.get(index) match {
      case ref: RexInputRef => ref.getIndex
      case _ => -1
    }
    refIndices.add(refIndex)
  }
  refIndices
}
}