blob: cb040733dad74353df29b4f5a147cc6da0dbc827 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan.metadata
import org.apache.flink.table.plan.nodes.physical.batch.{BatchExecGroupAggregateBase, BatchExecLocalHashWindowAggregate, BatchExecLocalSortWindowAggregate, BatchExecWindowAggregateBase}
import org.apache.flink.table.plan.stats.{ValueInterval, _}
import org.apache.flink.table.plan.util.FlinkRelOptUtil.checkAndSplitAggCalls
import org.apache.calcite.plan.RelOptPredicateList
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.core.{Aggregate, AggregateCall}
import org.apache.calcite.rel.metadata.RelMdUtil
import org.apache.calcite.rex._
import org.apache.calcite.sql.fun.SqlStdOperatorTable._
import org.apache.calcite.sql.{SqlKind, SqlOperator}
import java.lang.{Double => JDouble}
import scala.collection.JavaConversions._
/**
* Estimates selectivity of rows meeting a filter predicate on an Aggregate.
*
* A filter predicate on an Aggregate may contain two parts:
* one is on group by columns, another is on aggregate call's result.
* The first part is handled by [[SelectivityEstimator]],
* the second part is handled by this Estimator.
*
* @param agg aggregate node
* @param mq Metadata query
*/
class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery)
extends RexVisitorImpl[Option[Double]](true) {
// RexBuilder of the aggregate's cluster; evaluate() hands it to RexSimplify
// to simplify a predicate before estimating it
private val rexBuilder = agg.getCluster.getRexBuilder
// create SelectivityEstimator instance to use its default selectivity values
private val se = new SelectivityEstimator(agg, mq)
// selectivity (1%) returned in place of 0.0 when the estimated agg call interval does not
// match the predicate; the interval is only an estimate, so 0.0 would be too strict
private[flink] val defaultAggCallSelectivity = Some(0.01d)
/**
 * Looks up the aggregate call that produces the given output column of the aggregate node.
 *
 * The full grouping of the node (group keys, plus the input timestamp index for local
 * window aggregates, plus auxiliary group keys) is skipped first; the remaining offset
 * is used as the position in the node's aggregate call list.
 *
 * @param outputIdx index of an output column; must not be smaller than the number of
 *                  full grouping columns, otherwise an IllegalArgumentException is thrown
 * @return the aggregate call at that position if it exists and is supported
 *         (see [[isSupportedAggCall]]), otherwise None
 */
def getSupportedAggCall(outputIdx: Int): Option[AggregateCall] = {
  val (groupingCols, calls) = agg match {
    case aggregate: Aggregate =>
      val (auxGrouping, remainingCalls) = checkAndSplitAggCalls(aggregate)
      (aggregate.getGroupSet.toArray ++ auxGrouping, remainingCalls)
    case groupAgg: BatchExecGroupAggregateBase =>
      (groupAgg.getGrouping ++ groupAgg.getAuxGrouping, groupAgg.getAggCallList)
    // the local window aggregates must be matched before their common base class,
    // because they additionally contribute the input timestamp index to the grouping
    case localSortAgg: BatchExecLocalSortWindowAggregate =>
      val withTimestamp = localSortAgg.getGrouping ++
        Array(localSortAgg.inputTimestampIndex) ++ localSortAgg.getAuxGrouping
      (withTimestamp, localSortAgg.getAggCallList)
    case localHashAgg: BatchExecLocalHashWindowAggregate =>
      val withTimestamp = localHashAgg.getGrouping ++
        Array(localHashAgg.inputTimestampIndex) ++ localHashAgg.getAuxGrouping
      (withTimestamp, localHashAgg.getAggCallList)
    case windowAgg: BatchExecWindowAggregateBase =>
      (windowAgg.getGrouping ++ windowAgg.getAuxGrouping, windowAgg.getAggCallList)
    case _ => throw new IllegalArgumentException(s"Cannot handle ${agg.getRelTypeName}!")
  }
  require(outputIdx >= groupingCols.length)
  val callIdx = outputIdx - groupingCols.length
  if (callIdx < calls.length) {
    Option(calls.get(callIdx)).filter(isSupportedAggCall)
  } else {
    None
  }
}
/**
 * Tells whether this estimator can handle the given aggregate call.
 *
 * SUM, MAX, MIN and AVG are always supported; COUNT is supported only when it has
 * exactly one argument. Every other kind is unsupported.
 * TODO supports more
 *
 * @param aggCall the aggregate call to check
 * @return true if the aggregate call is supported
 */
def isSupportedAggCall(aggCall: AggregateCall): Boolean = {
  val kind = aggCall.getAggregation.getKind
  if (kind == SqlKind.COUNT) {
    aggCall.getArgList.size() == 1
  } else {
    kind == SqlKind.SUM || kind == SqlKind.MAX || kind == SqlKind.MIN || kind == SqlKind.AVG
  }
}
/**
 * Estimates the value interval of an aggregate call's result from the statistics of the
 * aggregate's input.
 *
 * For COUNT the interval is derived from the average number of input rows per group.
 * For the other kinds it is derived from the column interval of the call's first argument:
 * MAX, MIN and AVG keep the argument's bounds, SUM scales them by the average number of
 * rows per group.
 *
 * @param aggCall aggregate call whose result interval is to be estimated
 * @return the estimated interval, or null if the required statistics are not available
 */
def getAggCallInterval(aggCall: AggregateCall): ValueInterval = {
  val aggInput = agg.getInput(0)

  // Average number of input rows per group, assuming the data is uniformly distributed.
  // None if either row count is unknown, or if the aggregate's row count is not a positive
  // number: dividing by it would otherwise yield Infinity/NaN interval bounds.
  def getRowCntPerGroup: Option[Double] = {
    for {
      inputRowCnt <- Option(mq.getRowCount(aggInput))
      aggRowCnt <- Option(mq.getRowCount(agg))
      if aggRowCnt.doubleValue() > 0.0
    } yield inputRowCnt.doubleValue() / aggRowCnt.doubleValue()
  }

  if (aggCall.getAggregation.getKind == SqlKind.COUNT) {
    getRowCntPerGroup match {
      case Some(rowCntPerGroup) =>
        // assumes the min count is half of the average count per group,
        // the max count is double of the average count per group
        ValueInterval(math.max(rowCntPerGroup / 2, 1), rowCntPerGroup * 2, true, true)
      case _ => null
    }
  } else {
    val argInterval = mq.getColumnInterval(aggInput, aggCall.getArgList.head)
    argInterval match {
      case null => null
      case ValueInterval.infinite => ValueInterval.infinite
      case ValueInterval.empty => ValueInterval.empty
      case _ =>
        // a missing bound is represented by null and treated as inclusive
        val (min, includeMin) = argInterval match {
          case hasLower: WithLower =>
            (SelectivityEstimator.comparableToDouble(hasLower.lower), hasLower.includeLower)
          case _ => (null, true)
        }
        val (max, includeMax) = argInterval match {
          case hasUpper: WithUpper =>
            (SelectivityEstimator.comparableToDouble(hasUpper.upper), hasUpper.includeUpper)
          case _ => (null, true)
        }

        // maps a bound of the argument to the corresponding bound of the agg call result;
        // null means that the bound is unknown
        def getAggCallValue(v: JDouble): JDouble = {
          if (v == null) {
            null
          } else {
            aggCall.getAggregation.getKind match {
              case SqlKind.MAX | SqlKind.MIN | SqlKind.AVG => v
              case SqlKind.SUM =>
                getRowCntPerGroup match {
                  case Some(rowCntPerGroup) =>
                    // assume uniform distribution now
                    v * rowCntPerGroup
                  case _ => null
                }
              // the bound of any other aggregate kind cannot be derived from its argument:
              // report it as unknown instead of failing with a MatchError
              case _ => null
            }
          }
        }

        ValueInterval(getAggCallValue(min), getAggCallValue(max), includeMin, includeMax)
    }
  }
}
/**
 * Returns a percentage of rows meeting a filter predicate on aggregate.
 *
 * The predicate is simplified first; a predicate that is always true (or null) yields 1.0,
 * a predicate that is always false yields 0.0, anything else is estimated by visiting
 * the simplified expression.
 *
 * @param predicate predicate whose selectivity is to be estimated against aggregate calls.
 * @return estimated selectivity (between 0.0 and 1.0),
 *         or None if no reliable estimate can be determined.
 */
def evaluate(predicate: RexNode): Option[Double] = {
  try {
    if (predicate == null) {
      Some(1.0)
    } else {
      val rexSimplify = new RexSimplify(
        rexBuilder, RelOptPredicateList.EMPTY, true, RexUtil.EXECUTOR)
      val simplifiedPredicate = rexSimplify.simplify(predicate)
      if (simplifiedPredicate.isAlwaysTrue) {
        Some(1.0)
      } else if (simplifiedPredicate.isAlwaysFalse) {
        Some(0.0)
      } else {
        // visit methods that are not overridden by this class may return null;
        // never leak null to the caller, report "no estimate" instead
        Option(simplifiedPredicate.accept(this)).getOrElse(None)
      }
    }
  } catch {
    // if found unsupported operations, fallback.
    // Only non-fatal errors are swallowed: fatal JVM errors (e.g. OutOfMemoryError),
    // InterruptedException and control throwables must keep propagating.
    case scala.util.control.NonFatal(_) => None
  }
}
/**
 * Estimates the selectivity of a call.
 *
 * AND, OR and NOT are combined from the selectivities of their operands, assuming the
 * operands are independent; every other call is treated as a single predicate.
 *
 * @param call the call to estimate
 * @return the estimated selectivity of the call
 */
override def visitCall(call: RexCall): Option[Double] = {
  val operands = call.getOperands
  call.getOperator match {
    case AND =>
      val selectivity = operands.map(estimateOperand)
      Some(selectivity.product)
    case OR =>
      // P(A or B) = P(A) + P(B) - P(A) * P(B), applied pairwise over all operands.
      // This equals `sum - product` for two operands, and (unlike `sum - product`)
      // stays a valid probability for OR calls with more than two operands.
      val selectivity = operands.map(estimateOperand)
        .foldLeft(0.0)((acc, s) => acc + s - acc * s)
      Some(math.min(1.0, selectivity))
    case NOT =>
      val selectivity = estimateOperand(operands.head)
      Some(1.0 - selectivity)
    case _ =>
      estimateSinglePredicate(call)
  }
}
/**
 * Estimates the selectivity of one operand of a logical call.
 *
 * @param operand the operand to visit
 * @return the estimated selectivity, or 1.0 if visiting the operand yields no estimate
 *         (null or None)
 */
def estimateOperand(operand: RexNode): Double = {
  operand.accept(this) match {
    case Some(selectivity) => selectivity
    // covers both None and a null result of the visitor
    case _ => 1.0
  }
}
/**
 * Estimates the selectivity of a single (non AND/OR/NOT) condition.
 *
 * Comparisons are estimated on the first and the last operand of the call, `<>` is the
 * complement of `=`, the artificial selectivity function carries its own value, and
 * any other operator gets the default selectivity of [[SelectivityEstimator]].
 *
 * @param singlePredicate predicate whose selectivity is to be estimated against aggregate calls.
 * @return an optional double value to show the percentage of rows meeting a given condition.
 */
private def estimateSinglePredicate(singlePredicate: RexCall): Option[Double] = {
  val args = singlePredicate.getOperands

  // operands are only accessed when the operator is a comparison
  def estimate(op: SqlOperator): Option[Double] =
    estimateComparison(op, args.head, args.last)

  singlePredicate.getOperator match {
    case EQUALS => estimate(EQUALS)
    case NOT_EQUALS => Some(1.0 - estimate(EQUALS).getOrElse(1.0))
    case GREATER_THAN => estimate(GREATER_THAN)
    case GREATER_THAN_OR_EQUAL => estimate(GREATER_THAN_OR_EQUAL)
    case LESS_THAN => estimate(LESS_THAN)
    case LESS_THAN_OR_EQUAL => estimate(LESS_THAN_OR_EQUAL)
    case RelMdUtil.ARTIFICIAL_SELECTIVITY_FUNC =>
      Option(RelMdUtil.getSelectivityValue(singlePredicate))
    case _ => se.defaultSelectivity
  }
}
/**
 * Estimates the selectivity of a binary comparison between two expressions.
 *
 * Only comparisons between an input reference and a literal (in either order) are
 * estimated from statistics; when the literal is on the left side, the operator is
 * mirrored so that the input reference is always the left operand. Everything else
 * falls back to the default values of [[SelectivityEstimator]] (consistent with the
 * normal, non-aggregate case).
 *
 * @param op a binary comparison operator, including =, <, <=, >, >=
 * @param left the left operand
 * @param right the right operand
 * @return an optional double value to show the percentage of rows meeting a given condition.
 */
private def estimateComparison(op: SqlOperator, left: RexNode, right: RexNode): Option[Double] = {
  // default value of SelectivityEstimator for the cases which can't be handled here
  def fallback: Option[Double] = op match {
    case EQUALS => se.defaultEqualsSelectivity
    case _ => se.defaultComparisonSelectivity
  }

  val comparable = SelectivityEstimator.isSupportedComparisonType(left.getType) &&
    SelectivityEstimator.isSupportedComparisonType(right.getType)

  if (!comparable) {
    fallback
  } else {
    (left, right) match {
      case (ref: RexInputRef, lit: RexLiteral) =>
        op match {
          case EQUALS => estimateEquals(ref, lit)
          case LESS_THAN => estimateComparison(LESS_THAN, ref, lit)
          case LESS_THAN_OR_EQUAL => estimateComparison(LESS_THAN_OR_EQUAL, ref, lit)
          case GREATER_THAN => estimateComparison(GREATER_THAN, ref, lit)
          case GREATER_THAN_OR_EQUAL => estimateComparison(GREATER_THAN_OR_EQUAL, ref, lit)
          case _ => fallback
        }
      case (lit: RexLiteral, ref: RexInputRef) =>
        // `literal op ref` is estimated as `ref mirrored-op literal`
        op match {
          case EQUALS => estimateEquals(ref, lit)
          case LESS_THAN => estimateComparison(GREATER_THAN, ref, lit)
          case LESS_THAN_OR_EQUAL => estimateComparison(GREATER_THAN_OR_EQUAL, ref, lit)
          case GREATER_THAN => estimateComparison(LESS_THAN, ref, lit)
          case GREATER_THAN_OR_EQUAL => estimateComparison(LESS_THAN_OR_EQUAL, ref, lit)
          case _ => fallback
        }
      case _ => fallback
    }
  }
}
/**
 * Returns a percentage of rows meeting an equality (=) expression.
 * e.g. count(a) = 10
 *
 * The estimate is `1 / (max - min)` of the agg call's estimated interval when the literal
 * lies inside it, bounded below by defaultAggCallSelectivity and above by 1.0.
 *
 * @param inputRef a RexInputRef
 * @param literal a literal value (or constant)
 * @return an optional double value to show the percentage of rows meeting a given condition.
 *         Default values are returned if the agg call is not supported or its interval
 *         cannot be estimated.
 */
private def estimateEquals(inputRef: RexInputRef, literal: RexLiteral): Option[Double] = {
  if (literal.isNull) {
    se.defaultIsNullSelectivity
  } else {
    val aggCall = getSupportedAggCall(inputRef.getIndex)
    val numeric = SelectivityEstimator.canConvertToNumericType(inputRef.getType)
    // null if the interval can't be estimated (unsupported agg call, non-numeric type
    // or missing statistics)
    val aggCallInterval = aggCall match {
      case Some(call) if numeric => getAggCallInterval(call)
      case _ => null
    }
    if (aggCallInterval == null) {
      se.defaultEqualsSelectivity
    } else {
      val convertedInterval = SelectivityEstimator.convertValueInterval(
        aggCallInterval, inputRef.getType)
      convertedInterval match {
        case ValueInterval.infinite => se.defaultEqualsSelectivity
        case ValueInterval.empty =>
          // return defaultAggCallSelectivity instead of 0.0
          defaultAggCallSelectivity
        case i: FiniteValueInterval =>
          if (ValueInterval.contains(i, SelectivityEstimator.literalToComparable(literal))) {
            val min = SelectivityEstimator.comparableToDouble(i.lower)
            val max = SelectivityEstimator.comparableToDouble(i.upper)
            val range = max - min
            // the agg call interval is an estimated value, not a correct value.
            // if `1.0 / range` is too small, uses default value.
            // if the interval is narrower than 1 (or a single point, where `1.0 / range`
            // is Infinity), the ratio exceeds 1.0: a selectivity must never do that.
            Some(math.min(1.0, math.max(defaultAggCallSelectivity.get, 1.0 / range)))
          } else {
            defaultAggCallSelectivity
          }
        case _ => se.defaultEqualsSelectivity
      }
    }
  }
}
/**
 * Returns a percentage of rows meeting a binary comparison expression between an
 * agg call's result and a literal.
 * e.g. sum(a) > 10
 *
 * @param op a binary comparison operator, including <, <=, >, >=
 * @param inputRef a RexInputRef
 * @param literal a literal value (or constant), must not be a null literal
 * @return an optional double value to show the percentage of rows meeting a given condition.
 *         The default comparison selectivity is returned if the referenced agg call is
 *         not supported or the type can't be converted to a numeric type.
 */
private def estimateComparison(
    op: SqlOperator,
    inputRef: RexInputRef,
    literal: RexLiteral): Option[Double] = {
  if (literal.isNull) {
    throw new IllegalArgumentException("Numeric comparison does not support null literal here.")
  }
  val supportedCall = getSupportedAggCall(inputRef.getIndex)
  val numeric = SelectivityEstimator.canConvertToNumericType(inputRef.getType)
  supportedCall match {
    case Some(call) if numeric =>
      estimateNumericComparison(op, call, literal)
    case _ =>
      // TODO: It is difficult to support binary comparisons for non-numeric type
      // without advanced statistics like histogram.
      se.defaultComparisonSelectivity
  }
}
/**
 * Returns a percentage of rows meeting a binary numeric comparison expression.
 *
 * The literal is compared with the estimated interval of the agg call's result:
 * no overlap yields defaultAggCallSelectivity, complete overlap yields its complement,
 * and a partial overlap yields the matching fraction of the interval.
 *
 * @param op a binary comparison operator, including <, <=, >, >=
 * @param aggCall an AggregateCall
 * @param literal a literal value (or constant)
 * @return an optional double value to show the percentage of rows meeting a given condition.
 *         The default comparison selectivity is used if the interval is unknown.
 */
private def estimateNumericComparison(
    op: SqlOperator,
    aggCall: AggregateCall,
    literal: RexLiteral): Option[Double] = {
  getAggCallInterval(aggCall) match {
    case null => se.defaultComparisonSelectivity
    case ValueInterval.infinite => se.defaultComparisonSelectivity
    case ValueInterval.empty =>
      // return defaultAggCallSelectivity instead of 0.0
      defaultAggCallSelectivity
    case interval =>
      // a missing bound is represented by null and treated as inclusive
      val (lower, lowerInclusive) = interval match {
        case bounded: WithLower =>
          (SelectivityEstimator.comparableToDouble(bounded.lower), bounded.includeLower)
        case _ => (null, true)
      }
      val (upper, upperInclusive) = interval match {
        case bounded: WithUpper =>
          (SelectivityEstimator.comparableToDouble(bounded.upper), bounded.includeUpper)
        case _ => (null, true)
      }
      val value = SelectivityEstimator.literalToDouble(literal)

      // every value of the interval is greater than the literal
      def intervalAboveLiteral: Boolean =
        if (lowerInclusive) SelectivityEstimator.greaterThan(lower, value)
        else SelectivityEstimator.greaterThanOrEqualTo(lower, value)

      // every value of the interval is less than the literal
      def intervalBelowLiteral: Boolean =
        if (upperInclusive) SelectivityEstimator.lessThan(upper, value)
        else SelectivityEstimator.lessThanOrEqualTo(upper, value)

      val (noOverlap, completeOverlap) = op match {
        case LESS_THAN =>
          (SelectivityEstimator.greaterThanOrEqualTo(lower, value), intervalBelowLiteral)
        case LESS_THAN_OR_EQUAL =>
          (intervalAboveLiteral, SelectivityEstimator.lessThanOrEqualTo(upper, value))
        case GREATER_THAN =>
          (SelectivityEstimator.lessThanOrEqualTo(upper, value), intervalAboveLiteral)
        case GREATER_THAN_OR_EQUAL =>
          (intervalBelowLiteral, SelectivityEstimator.greaterThanOrEqualTo(lower, value))
      }

      val selectivity =
        if (noOverlap) {
          // return defaultAggCallSelectivity instead of 0.0
          defaultAggCallSelectivity.get
        } else if (completeOverlap) {
          // return 1 - defaultAggCallSelectivity instead of 1.0
          1.0 - defaultAggCallSelectivity.get
        } else if (lower == null || upper == null) {
          // partial overlap with a half-open interval: the fraction can't be computed
          se.defaultComparisonSelectivity.get
        } else {
          // partial overlap: fraction of the interval which satisfies the comparison
          op match {
            case LESS_THAN | LESS_THAN_OR_EQUAL => (value - lower) / (upper - lower)
            case GREATER_THAN | GREATER_THAN_OR_EQUAL => (upper - value) / (upper - lower)
          }
        }
      Some(selectivity)
  }
}
}