blob: 306f5e834c322b3222d963c3d2ba4f2487cde0e7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan.metadata
import org.apache.flink.table.expressions.ExpressionUtils._
import org.apache.flink.table.plan.logical.{LogicalWindow, SlidingGroupWindow, TumblingGroupWindow}
import org.apache.flink.table.plan.nodes.calcite.{Expand, LogicalWindowAggregate, Rank}
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalWindowAggregate
import org.apache.flink.table.plan.nodes.physical.batch._
import org.apache.flink.table.plan.stats.ValueInterval
import org.apache.flink.table.plan.util.AggregateUtil._
import org.apache.flink.table.plan.util.FlinkRelMdUtil._
import org.apache.flink.table.plan.util.{FlinkRelMdUtil, FlinkRelOptUtil}
import org.apache.flink.table.util.NodeResourceUtil
import org.apache.calcite.adapter.enumerable.EnumerableLimit
import org.apache.calcite.plan.volcano.RelSubset
import org.apache.calcite.rel.core._
import org.apache.calcite.rel.metadata._
import org.apache.calcite.rel.{RelNode, SingleRel}
import org.apache.calcite.rex.{RexLiteral, RexNode}
import org.apache.calcite.util._
import java.lang.Double
import scala.collection.JavaConversions._
class FlinkRelMdRowCount private extends MetadataHandler[BuiltInMetadata.RowCount] {
/** Returns the metadata definition served by this handler, i.e. the built-in row count DEF. */
def getDef: MetadataDef[BuiltInMetadata.RowCount] = {
  BuiltInMetadata.RowCount.DEF
}
/** Row count of an [[Expand]] node: delegates to the node's own `estimateRowCount`. */
def getRowCount(rel: Expand, mq: RelMetadataQuery): Double = {
  rel.estimateRowCount(mq)
}
/** Row count of a [[Rank]] node: delegates to the node's own `estimateRowCount`. */
def getRowCount(rel: Rank, mq: RelMetadataQuery): Double = {
  rel.estimateRowCount(mq)
}
/**
 * Row count of an [[Aggregate]].
 *
 * Only the output row count of the (output, input) pair computed by
 * `getRowCountOfAgg` is returned; it may be null.
 */
def getRowCount(rel: Aggregate, mq: RelMetadataQuery): Double = {
  val (outputRowCount, _) =
    getRowCountOfAgg(rel, rel.getGroupSet, rel.getGroupSets.size(), mq)
  outputRowCount
}
/**
 * Computes both the output row count and the input row count of an aggregate.
 *
 * Without group keys the output is a single row. Otherwise the output of one grouping set is
 * the distinct row count of the group keys capped by the input row count, or, when no distinct
 * row count is available, the input row count scaled by the default aggregation ratio. That
 * value is then multiplied by the number of grouping sets.
 *
 * @param rel agg relNode
 * @param groupSet agg groupSet
 * @param groupSetsSize agg groupSets count
 * @param mq metadata query
 * @return a tuple, the first element is output rowCount (null if it cannot be estimated),
 *         second one is input rowCount
 */
private def getRowCountOfAgg(
    rel: SingleRel,
    groupSet: ImmutableBitSet,
    groupSetsSize: Int,
    mq: RelMetadataQuery): (Double, Double) = {
  val input = rel.getInput
  val inputRowCount = mq.getRowCount(input)
  val groupKeyCount = groupSet.cardinality()
  if (groupKeyCount == 0) {
    // global aggregate without group keys
    (1.0, inputRowCount)
  } else {
    // rowCount is the cardinality of the group by columns
    val ndvOfGroupKeys = mq.getDistinctRowCount(input, groupSet, null)
    val rowCountPerGroupSet: Double = if (ndvOfGroupKeys != null) {
      NumberUtil.min(ndvOfGroupKeys, inputRowCount)
    } else {
      NumberUtil.multiply(
        inputRowCount,
        FlinkRelMdUtil.getAggregationRatioIfNdvUnavailable(groupKeyCount))
    }
    if (rowCountPerGroupSet == null) {
      (null, inputRowCount)
    } else {
      // Grouping sets multiply
      (rowCountPerGroupSet * groupSetsSize, inputRowCount)
    }
  }
}
/** Row count of a batch group aggregate; the estimate comes from `getRowCountOfBatchExecAgg`. */
def getRowCount(rel: BatchExecGroupAggregateBase, mq: RelMetadataQuery): Double =
  getRowCountOfBatchExecAgg(rel, mq)
/**
 * Estimates the row count of a batch group aggregate or batch window aggregate.
 *
 * First the row count of the global (final) aggregate is derived from the distinct row count of
 * the group keys. A final aggregate returns that value directly. A non-final (local) aggregate
 * additionally takes the operator parallelism into account.
 *
 * @param rel a [[BatchExecGroupAggregateBase]] or [[BatchExecWindowAggregateBase]]
 * @param mq metadata query
 * @return the estimated row count, possibly null
 * @throws IllegalArgumentException if `rel` is neither of the two supported node types
 */
private def getRowCountOfBatchExecAgg(rel: SingleRel, mq: RelMetadataQuery): Double = {
  val (groupKeys, finalPhase, mergePhase) = rel match {
    case groupAgg: BatchExecGroupAggregateBase =>
      (ImmutableBitSet.of(groupAgg.getGrouping: _*), groupAgg.isFinal, groupAgg.isMerge)
    case winAgg: BatchExecWindowAggregateBase =>
      (ImmutableBitSet.of(winAgg.getGrouping: _*), winAgg.isFinal, winAgg.isMerge)
    case _ => throw new IllegalArgumentException(s"Unknown node type ${rel.getRelTypeName}!")
  }

  val globalAggRowCount: Double = if (groupKeys.isEmpty) {
    1.0
  } else {
    // rowCount is the cardinality of the group by columns
    val ndv = mq.getDistinctRowCount(rel.getInput, groupKeys, null)
    val inputRowCount = mq.getRowCount(rel.getInput)
    if (ndv != null) {
      NumberUtil.min(ndv, inputRowCount)
    } else if (finalPhase && mergePhase) {
      // Avoid apply aggregation ratio twice when calculate row count of global agg
      // which has local agg.
      inputRowCount
    } else {
      NumberUtil.multiply(inputRowCount, getAggregationRatioIfNdvUnavailable(groupKeys.length))
    }
  }

  if (finalPhase) {
    globalAggRowCount
  } else {
    val inputRowCount = mq.getRowCount(rel.getInput)
    val tableConfig = FlinkRelOptUtil.getTableConfig(rel)
    val parallelism = NodeResourceUtil.calOperatorParallelism(inputRowCount, tableConfig.getConf)
    if (parallelism == 1) {
      globalAggRowCount
    } else if (groupKeys.isEmpty) {
      // output rowcount of local agg is parallelism for agg which has no group keys
      parallelism.toDouble
    } else {
      val ndv = mq.getDistinctRowCount(rel.getInput, groupKeys, null)
      if (ndv == null) {
        globalAggRowCount
      } else {
        getRowCountOfLocalAgg(parallelism, inputRowCount, globalAggRowCount)
      }
    }
  }
}
/** Row count of a [[FlinkLogicalWindowAggregate]], based on its window and group set. */
def getRowCount(rel: FlinkLogicalWindowAggregate, mq: RelMetadataQuery): Double =
  getRowCountOfWindowAgg(rel, rel.getWindow, rel.getGroupSet, mq)
/** Row count of a [[LogicalWindowAggregate]], based on its window and group set. */
def getRowCount(rel: LogicalWindowAggregate, mq: RelMetadataQuery): Double =
  getRowCountOfWindowAgg(rel, rel.getWindow, rel.getGroupSet, mq)
/**
 * Row count of a batch window aggregate: the batch aggregate estimate for the group keys,
 * adjusted for the window by `estimateRowCountOfWindowAgg`.
 */
def getRowCount(rel: BatchExecWindowAggregateBase, mq: RelMetadataQuery): Double = {
  val groupKeyRowCount = getRowCountOfBatchExecAgg(rel, mq)
  estimateRowCountOfWindowAgg(groupKeyRowCount, mq.getRowCount(rel.getInput), rel.getWindow)
}
/**
 * Estimates the row count of a logical window aggregate.
 *
 * The aggregate part is estimated with a single grouping set; the resulting output and input
 * row counts are then handed to `estimateRowCountOfWindowAgg` together with the window.
 */
private def getRowCountOfWindowAgg(
    windowAgg: SingleRel,
    window: LogicalWindow,
    grouping: ImmutableBitSet,
    mq: RelMetadataQuery): Double = {
  val aggRowCounts = getRowCountOfAgg(windowAgg, grouping, 1, mq)
  estimateRowCountOfWindowAgg(aggRowCounts._1, aggRowCounts._2, window)
}
/**
 * Adjusts the row count of an aggregate for the kind of group window it is evaluated on.
 *
 * The group-key row count is multiplied by a fixed expand factor (2 for tumbling, session and
 * non-overlapping sliding windows, 4 for overlapping sliding windows). Except for overlapping
 * sliding windows, the result is capped by the input row count.
 *
 * @param ndv row count of the aggregate without window, may be null
 * @param inputRowCount row count of the aggregate input, may be null; a null value means the
 *                      cap is skipped instead of failing with a NullPointerException
 * @param window the group window of the aggregate
 * @return the estimated row count, or null if `ndv` is null
 */
private def estimateRowCountOfWindowAgg(
    ndv: Double,
    inputRowCount: Double,
    window: LogicalWindow): Double = {
  if (ndv == null) {
    null
  } else {
    // simply assume expand factor of TumblingWindow/SessionWindow/SlideWindowWithoutOverlap is 2
    // SlideWindowWithOverlap is 4.
    // Introduce expand factor here to distinguish output rowCount of normal agg with all kinds of
    // window aggregates.
    val expandFactorOfTumblingWindow = 2D
    val expandFactorOfNoOverLapSlidingWindow = 2D
    val expandFactorOfOverLapSlidingWindow = 4D
    val expandFactorOfSessionWindow = 2D

    // NumberUtil.min ignores a null operand, so an unknown input row count no longer triggers
    // an NPE through implicit unboxing as Math.min did.
    def expandAndCap(expandFactor: scala.Double): Double =
      NumberUtil.min(Double.valueOf(expandFactor * ndv.doubleValue()), inputRowCount)

    window match {
      case TumblingGroupWindow(_, _, size) if isTimeIntervalLiteral(size) =>
        expandAndCap(expandFactorOfTumblingWindow)
      case SlidingGroupWindow(_, _, size, slide) if isTimeIntervalLiteral(size) =>
        val sizeV = asLong(size)
        val slideV = asLong(slide)
        if (sizeV > slideV) {
          // only slideWindow which has overlap may generates more records than input
          expandFactorOfOverLapSlidingWindow * ndv
        } else {
          expandAndCap(expandFactorOfNoOverLapSlidingWindow)
        }
      case _ => expandAndCap(expandFactorOfSessionWindow)
    }
  }
}
/** Row count of a batch over aggregate; see `getRowCountOfOverWindow`. */
def getRowCount(rel: BatchExecOverAggregate, mq: RelMetadataQuery): Double = {
  getRowCountOfOverWindow(rel, mq)
}
/** Row count of a [[Window]] (over window) node; see `getRowCountOfOverWindow`. */
def getRowCount(rel: Window, mq: RelMetadataQuery): Double = {
  getRowCountOfOverWindow(rel, mq)
}
/** An over window is estimated to produce exactly as many rows as its input. */
private def getRowCountOfOverWindow(overWindow: SingleRel, mq: RelMetadataQuery): Double = {
  val input = overWindow.getInput
  mq.getRowCount(input)
}
/**
 * Row count of a [[Join]].
 *
 * Returns null when the row count of either input is unknown. If the condition contains
 * equi-join keys, the inner join row count is estimated first and then raised, depending on
 * the join type, so that outer sides are not under-counted. Without equi-join keys, the
 * estimate is the cross product of both inputs times the selectivity of the join condition.
 */
def getRowCount(join: Join, mq: RelMetadataQuery): Double = {
  val left = join.getLeft
  val right = join.getRight
  val leftRowCount = mq.getRowCount(left)
  val rightRowCount = mq.getRowCount(right)
  if (leftRowCount == null || rightRowCount == null) {
    null
  } else if (JoinInfo.of(left, right, join.getCondition).leftSet().nonEmpty) {
    val innerRows = getEquiInnerJoinRowCount(join, mq, leftRowCount, rightRowCount)
    require(innerRows != null)
    // lower bounds contributed by the preserved side(s) of an outer join
    val atLeastLeft = math.max(leftRowCount, innerRows)
    val atLeastRight = math.max(rightRowCount, innerRows)
    // Make sure outputRowCount won't be too small based on join type.
    join.getJoinType match {
      case JoinRelType.INNER => innerRows
      // All rows from left side should be in the result.
      case JoinRelType.LEFT => atLeastLeft
      // All rows from right side should be in the result.
      case JoinRelType.RIGHT => atLeastRight
      // T(A FULL JOIN B) = T(A LEFT JOIN B) + T(A RIGHT JOIN B) - T(A INNER JOIN B)
      case JoinRelType.FULL => atLeastLeft + atLeastRight - innerRows
    }
  } else {
    // no equi keys: evaluate the condition's selectivity on top of a cross join
    val alwaysTrue = join.getCluster.getRexBuilder.makeLiteral(true)
    val crossJoin = copyJoinWithNewCondition(join, alwaysTrue)
    val selectivity = mq.getSelectivity(crossJoin, join.getCondition)
    (leftRowCount * rightRowCount) * selectivity
  }
}
/**
 * Estimates the row count of the inner join for a join that has at least one equi-join key.
 *
 * The estimate is chosen by the first rule that applies:
 *  1. some key pair has known column intervals that do not intersect: 0 rows;
 *  2. distinct row counts of both key sets are known:
 *     leftRowCount * rightRowCount * min(1, 1 / max(leftNdv, rightNdv)) * non-equi selectivity;
 *  3. uniqueness of both key sets is known and at least one side is unique:
 *     both unique: min(leftRowCount, rightRowCount) * non-equi selectivity,
 *     only left unique: rightRowCount * non-equi selectivity,
 *     only right unique: leftRowCount * non-equi selectivity;
 *  4. otherwise: (leftRowCount + rightRowCount) * selectivity of the whole join condition.
 *
 * @param join the join, its condition must contain equi-join keys
 * @param mq metadata query
 * @param leftRowCount row count of the left input, must not be null
 * @param rightRowCount row count of the right input, must not be null
 * @return the estimated inner join row count
 */
private def getEquiInnerJoinRowCount(
    join: Join,
    mq: RelMetadataQuery,
    leftRowCount: Double,
    rightRowCount: Double): Double = {
  val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
  val left = join.getLeft
  val right = join.getRight
  val rexBuilder = join.getCluster.getRexBuilder
  val condition = join.getCondition
  val joinInfo = JoinInfo.of(left, right, condition)
  // the leftKeys length equals to rightKeys, so it's ok to only check leftKeys length
  require(joinInfo.leftKeys.nonEmpty)

  val hasDisjointKeyPair = joinInfo.leftKeys.zip(joinInfo.rightKeys).exists {
    case (leftKey, rightKey) =>
      val leftInterval = fmq.getColumnInterval(left, leftKey)
      val rightInterval = fmq.getColumnInterval(right, rightKey)
      // unknown intervals are never treated as disjoint
      leftInterval != null && rightInterval != null &&
        !ValueInterval.isIntersected(leftInterval, rightInterval)
  }

  if (hasDisjointKeyPair) {
    // One of the join key pairs is disjoint, thus the two sides of join is disjoint.
    0D
  } else {
    val leftKeySet = joinInfo.leftSet()
    val rightKeySet = joinInfo.rightSet()
    val leftNdv = fmq.getDistinctRowCount(left, leftKeySet, null)
    val rightNdv = fmq.getDistinctRowCount(right, rightKeySet, null)

    // estimate selectivity of non-equi
    val nonEquiSelectivity: Double = if (joinInfo.isEqui) {
      1D
    } else {
      val nonEquiPred = joinInfo.getRemaining(rexBuilder)
      val equiPred = RelMdUtil.minusPreds(rexBuilder, condition, nonEquiPred)
      fmq.getSelectivity(copyJoinWithNewCondition(join, equiPred), nonEquiPred)
    }

    if (leftNdv != null && rightNdv != null) {
      // selectivity of equi part is 1 / Max(leftNdv, rightNdv)
      val equiSelectivity = Math.min(1D, 1D / Math.max(leftNdv, rightNdv))
      leftRowCount * rightRowCount * equiSelectivity * nonEquiSelectivity
    } else {
      val leftKeysAreUnique = fmq.areColumnsUnique(left, leftKeySet)
      val rightKeysAreUnique = fmq.areColumnsUnique(right, rightKeySet)
      val uniquenessUsable = leftKeysAreUnique != null && rightKeysAreUnique != null &&
        (leftKeysAreUnique || rightKeysAreUnique)
      if (uniquenessUsable) {
        // a unique key side matches each row of the other side at most once
        val boundingRowCount: Double = if (leftKeysAreUnique && rightKeysAreUnique) {
          Math.min(leftRowCount, rightRowCount)
        } else if (leftKeysAreUnique) {
          rightRowCount
        } else {
          leftRowCount
        }
        boundingRowCount * nonEquiSelectivity
      } else {
        // if joinCondition has no ndv stats and no uniqueKeys stats,
        // rowCount = (leftRowCount + rightRowCount) * join condition selectivity
        val crossJoin = copyJoinWithNewCondition(join, rexBuilder.makeLiteral(true))
        val selectivity = fmq.getSelectivity(crossJoin, condition)
        (leftRowCount + rightRowCount) * selectivity
      }
    }
  }
}
/**
 * Copies `join`, replacing only its condition with `newCondition`; trait set, inputs, join type
 * and the semi-join-done flag are taken over from the original join.
 */
private def copyJoinWithNewCondition(join: Join, newCondition: RexNode): Join = {
  val traits = join.getTraitSet
  val leftInput = join.getLeft
  val rightInput = join.getRight
  val joinType = join.getJoinType
  join.copy(traits, newCondition, leftInput, rightInput, joinType, join.isSemiJoinDone)
}
/**
 * Row count of a [[SemiJoin]]: the row count of the left input multiplied by the selectivity,
 * on the left input, of the predicate built by `FlinkRelMdUtil.makeSemiJoinSelectivityRexNode`.
 * The product is computed with `NumberUtil.multiply`.
 */
def getRowCount(rel: SemiJoin, mq: RelMetadataQuery): Double = {
  val selectivityPredicate = FlinkRelMdUtil.makeSemiJoinSelectivityRexNode(mq, rel)
  val selectivity = mq.getSelectivity(rel.getLeft, selectivityPredicate)
  val leftRowCount = mq.getRowCount(rel.getLeft)
  NumberUtil.multiply(selectivity, leftRowCount)
}
/**
 * Catch-all implementation for
 * [[BuiltInMetadata.RowCount#getRowCount()]],
 * invoked using reflection.
 *
 * Falls back to the row count estimated by the node itself.
 *
 * @see org.apache.calcite.rel.metadata.RelMetadataQuery#getRowCount(RelNode)
 */
def getRowCount(rel: RelNode, mq: RelMetadataQuery): Double = {
  rel.estimateRowCount(mq)
}
/**
 * Row count of a [[RelSubset]].
 *
 * While `Bug.CALCITE_1048_FIXED` is false, the row count of the subset's best rel (or of its
 * original rel if there is no best one) is returned. Otherwise the smallest row count over
 * all rels of the subset is used, falling back to 1e6 if none could be computed.
 *
 * Only non-fatal failures of individual rels are skipped; fatal errors (for example
 * OutOfMemoryError or InterruptedException) are propagated instead of being swallowed.
 */
def getRowCount(subset: RelSubset, mq: RelMetadataQuery): Double = {
  import scala.util.control.NonFatal

  if (!Bug.CALCITE_1048_FIXED) {
    return mq.getRowCount(Util.first(subset.getBest, subset.getOriginal))
  }
  val v = subset.getRels.foldLeft(null.asInstanceOf[Double]) {
    (min, r) =>
      try {
        NumberUtil.min(min, mq.getRowCount(r))
      } catch {
        // ignore this rel; there will be other, non-cyclic ones
        case _: CyclicMetadataException => min
        // best effort: report the failure and keep the current minimum
        case NonFatal(e) =>
          e.printStackTrace()
          min
      }
  }
  // if set is empty, estimate large
  Util.first(v, 1e6d)
}
/**
 * Row count of a [[Union]]: the sum of the row counts of all inputs, or null as soon as the
 * row count of any input is unknown.
 */
def getRowCount(rel: Union, mq: RelMetadataQuery): Double = {
  val inputRowCounts = rel.getInputs.map(input => mq.getRowCount(input))
  val anyUnknown = inputRowCounts.exists(_ == null)
  if (anyUnknown) {
    null
  } else {
    inputRowCounts.foldLeft(0D)((total, rowCount) => total + rowCount)
  }
}
/**
 * Row count of an [[Intersect]]: the smallest known row count among its inputs.
 *
 * Inputs with an unknown (null) row count are ignored once a known value has been seen; the
 * result is null if no input has a known row count.
 */
def getRowCount(rel: Intersect, mq: RelMetadataQuery): Double = {
  rel.getInputs.foldLeft(null.asInstanceOf[Double]) { (smallest, input) =>
    val candidate = mq.getRowCount(input)
    // keep the current value unless it is unknown or the candidate is strictly smaller
    val keepCurrent = smallest != null && (candidate == null || !(candidate < smallest))
    if (keepCurrent) smallest else candidate
  }
}
/**
 * Row count of a [[Minus]]: the smallest known row count among its inputs.
 *
 * Inputs with an unknown (null) row count are ignored once a known value has been seen; the
 * result is null if no input has a known row count.
 */
def getRowCount(rel: Minus, mq: RelMetadataQuery): Double = {
  rel.getInputs.foldLeft(null.asInstanceOf[Double]) { (smallest, input) =>
    val candidate = mq.getRowCount(input)
    // keep the current value unless it is unknown or the candidate is strictly smaller
    val keepCurrent = smallest != null && (candidate == null || !(candidate < smallest))
    if (keepCurrent) smallest else candidate
  }
}
/** Row count of a [[Filter]]: input rows filtered by the condition, per `RelMdUtil`. */
def getRowCount(rel: Filter, mq: RelMetadataQuery): Double = {
  val input = rel.getInput
  val condition = rel.getCondition
  RelMdUtil.estimateFilteredRows(input, condition, mq)
}
/** Row count of a [[Calc]]: input rows filtered by the calc program, per `RelMdUtil`. */
def getRowCount(rel: Calc, mq: RelMetadataQuery): Double = {
  val input = rel.getInput
  val program = rel.getProgram
  RelMdUtil.estimateFilteredRows(input, program, mq)
}
/** A [[Project]] is estimated to produce exactly as many rows as its input. */
def getRowCount(rel: Project, mq: RelMetadataQuery): Double = {
  mq.getRowCount(rel.getInput)
}
/**
 * Row count of a [[Sort]], taking its offset and fetch into account.
 *
 * Returns null if the input row count is unknown. Otherwise the offset (0 if absent) is
 * subtracted from the input row count, never going below 0, and the result is capped by the
 * fetch size when a fetch is present.
 */
def getRowCount(rel: Sort, mq: RelMetadataQuery): Double = {
  val inputRowCount = mq.getRowCount(rel.getInput)
  if (inputRowCount == null) {
    null
  } else {
    val skipped = if (rel.offset == null) 0 else RexLiteral.intValue(rel.offset)
    val remaining = Math.max(inputRowCount - skipped, 0D)
    if (rel.fetch == null) {
      remaining
    } else {
      val fetchSize = RexLiteral.intValue(rel.fetch)
      if (fetchSize < remaining) fetchSize.toDouble else remaining
    }
  }
}
/**
 * Row count of an [[EnumerableLimit]], taking its offset and fetch into account.
 *
 * Returns null if the input row count is unknown. Otherwise the offset (0 if absent) is
 * subtracted from the input row count, never going below 0, and the result is capped by the
 * fetch size when a fetch is present.
 */
def getRowCount(rel: EnumerableLimit, mq: RelMetadataQuery): Double = {
  val inputRowCount: Double = mq.getRowCount(rel.getInput)
  if (inputRowCount == null) {
    null
  } else {
    val skipped = if (rel.offset == null) 0 else RexLiteral.intValue(rel.offset)
    val remaining = Math.max(inputRowCount - skipped, 0D)
    if (rel.fetch == null) {
      remaining
    } else {
      val fetchSize = RexLiteral.intValue(rel.fetch)
      if (fetchSize < remaining) fetchSize.toDouble else remaining
    }
  }
}
/** Default for single-input nodes: the row count of the input is passed through unchanged. */
def getRowCount(rel: SingleRel, mq: RelMetadataQuery): Double = {
  mq.getRowCount(rel.getInput)
}
/** Row count of a [[TableScan]]: delegates to the node's own `estimateRowCount`. */
def getRowCount(rel: TableScan, mq: RelMetadataQuery): Double = {
  rel.estimateRowCount(mq)
}
/** Row count of a [[Values]] node: delegates to the node's own `estimateRowCount`. */
def getRowCount(rel: Values, mq: RelMetadataQuery): Double = {
  rel.estimateRowCount(mq)
}
}
/**
 * Companion of the row count handler: holds the single handler instance and exposes it as a
 * reflective metadata provider.
 */
object FlinkRelMdRowCount {

  /** The only instance of the handler; the class constructor is private. */
  private val INSTANCE = new FlinkRelMdRowCount

  /** Metadata provider that dispatches `BuiltInMethod.ROW_COUNT` calls to the handler. */
  val SOURCE: RelMetadataProvider =
    ReflectiveRelMetadataProvider.reflectiveSource(BuiltInMethod.ROW_COUNT.method, INSTANCE)

}