blob: 39cd1ecbc77e86e74c9ef46261027d7447fef2c8 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan.metadata
import org.apache.flink.table.api.TableException
import org.apache.flink.table.functions.sql.ScalarSqlFunctions
import org.apache.flink.table.plan.metadata.FlinkMetadata.ColumnInterval
import org.apache.flink.table.plan.nodes.calcite.{Expand, LogicalWindowAggregate, Rank}
import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalSnapshot, FlinkLogicalWindowAggregate}
import org.apache.flink.table.plan.nodes.physical.batch._
import org.apache.flink.table.plan.nodes.physical.stream._
import org.apache.flink.table.plan.schema.FlinkRelOptTable
import org.apache.flink.table.plan.stats._
import org.apache.flink.table.plan.util.FlinkRelOptUtil._
import org.apache.flink.table.plan.util.{ColumnIntervalUtil, ConstantRankRange, FlinkRelMdUtil, VariableRankRange}
import org.apache.flink.util.Preconditions
import org.apache.calcite.plan.volcano.RelSubset
import org.apache.calcite.rel.core._
import org.apache.calcite.rel.metadata._
import org.apache.calcite.rel.{AbstractRelNode, RelNode, SingleRel}
import org.apache.calcite.rex._
import org.apache.calcite.sql.SqlKind._
import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.calcite.sql.{SqlBinaryOperator, SqlKind}
import org.apache.calcite.util.Util
import java.lang.{Boolean => JBool}
import scala.collection.JavaConversions._
/**
 * FlinkRelMdColumnInterval supplies a default implementation of
 * [[FlinkRelMetadataQuery.getColumnInterval]] for the standard logical algebra.
 *
 * Each handler returns the estimated [[ValueInterval]] of one output column of a
 * RelNode; a null result means the interval cannot be estimated.
 */
class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {

  override def getDef: MetadataDef[ColumnInterval] = FlinkMetadata.ColumnInterval.DEF

  /**
   * Gets interval of the given column in TableScan.
   *
   * The interval is read from the min/max column statistics attached to the table
   * through [[FlinkRelOptTable]]; returns null when no statistics (or no bounds)
   * are available.
   *
   * @param ts TableScan RelNode
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in TableScan
   */
  def getColumnInterval(ts: TableScan, mq: RelMetadataQuery, index: Int): ValueInterval = {
    val relOptTable = ts.getTable.asInstanceOf[FlinkRelOptTable]
    val fieldNames = relOptTable.getRowType.getFieldNames
    Preconditions.checkArgument(index >= 0 && index < fieldNames.size())
    val fieldName = fieldNames.get(index)
    val statistic = relOptTable.getFlinkStatistic
    val colStats = statistic.getColumnStats(fieldName)
    if (colStats != null) {
      if (colStats.min == null && colStats.max == null) {
        // stats exist but carry neither bound -> interval unknown
        null
      } else {
        ValueInterval(colStats.min, colStats.max)
      }
    } else {
      null
    }
  }

  /**
   * Gets interval of the given column in FlinkLogicalSnapshot.
   * TODO: implement it; currently always unknown (null).
   *
   * @param snapshot Snapshot RelNode
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in FlinkLogicalSnapshot
   */
  def getColumnInterval(
      snapshot: FlinkLogicalSnapshot,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = null

  /**
   * Gets interval of the given column in Project.
   *
   * Note: Only supports simple RexNodes: input refs, literals and binary-operator
   * calls; anything else yields null (unknown).
   *
   * @param project Project RelNode
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in Project
   */
  def getColumnInterval(project: Project, mq: RelMetadataQuery, index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val projects = project.getProjects
    Preconditions.checkArgument(index >= 0 && index < projects.size())
    projects.get(index) match {
      case inputRef: RexInputRef => fmq.getColumnInterval(project.getInput, inputRef.getIndex)
      case literal: RexLiteral =>
        val literalValue = getLiteralValue(literal)
        if (literalValue == null) {
          // NULL literal -> the column holds no non-null value at all
          ValueInterval.empty
        } else {
          ValueInterval(literalValue, literalValue)
        }
      case rexCall: RexCall if rexCall.op.isInstanceOf[SqlBinaryOperator] =>
        getRexNodeInterval(rexCall, project, mq)
      case _ => null
    }
  }

  /**
   * Gets interval of the given column in Exchange.
   *
   * Exchange only redistributes rows, so the interval is that of its input column.
   *
   * @param exchange Exchange RelNode
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in Exchange
   */
  def getColumnInterval(exchange: Exchange, mq: RelMetadataQuery, index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    fmq.getColumnInterval(exchange.getInput, index)
  }

  /**
   * Gets interval of the given column in Union.
   *
   * @param union Union RelNode
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in Union
   */
  def getColumnInterval(union: Union, mq: RelMetadataQuery, index: Int): ValueInterval =
    estimateColumnIntervalOfUnion(union, mq, index)

  /**
   * Estimates the interval of the given column in a Union-like node as the union of
   * the per-input column intervals. Note: if any input interval is null (unknown),
   * the reduced result is null as well.
   *
   * @param union Union RelNode
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column across all inputs of the union
   */
  private def estimateColumnIntervalOfUnion(
      union: AbstractRelNode,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val subIntervals = union
      .getInputs
      .map(fmq.getColumnInterval(_, index))
    subIntervals.reduceLeft(ValueInterval.union)
  }

  /**
   * Gets interval of the given column in Values.
   *
   * The interval is the union of all non-null literal values at `index` across the
   * value tuples; empty tuples or all-null values yield [[EmptyValueInterval]].
   *
   * @param values Values RelNode
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in Values
   */
  def getColumnInterval(values: Values, mq: RelMetadataQuery, index: Int): ValueInterval = {
    val tuples = values.tuples
    if (tuples.isEmpty) {
      EmptyValueInterval
    } else {
      val vals = tuples.map(tuple => getLiteralValue(tuple.get(index))).filter(_ != null)
      if (vals.isEmpty) {
        EmptyValueInterval
      } else {
        vals.map(literal => ValueInterval(literal, literal)).reduceLeft(ValueInterval.union)
      }
    }
  }

  /**
   * Gets interval of the given column in Filter.
   *
   * The input interval is narrowed by the filter condition via
   * [[FlinkRelMdColumnInterval.getColumnIntervalWithFilter]].
   *
   * @param filter Filter RelNode
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in Filter
   */
  def getColumnInterval(filter: Filter, mq: RelMetadataQuery, index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val inputValueInterval = fmq.getColumnInterval(filter.getInput, index)
    FlinkRelMdColumnInterval.getColumnIntervalWithFilter(
      Option(inputValueInterval),
      filter.getCondition,
      index,
      filter.getCluster.getRexBuilder)
  }

  /**
   * Gets interval of the given column in Calc.
   *
   * @param calc Calc RelNode
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in Calc
   */
  def getColumnInterval(calc: Calc, mq: RelMetadataQuery, index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val rexProgram = calc.getProgram
    // split() expands the program's local refs into plain projection expressions
    val project = rexProgram.split().left.get(index)
    getColumnInterval(calc, fmq, project)
  }

  /**
   * Calculates the interval of the column which results from the given rex node
   * in Calc. Note that this function is called by the method above, is recursive
   * in case of an "AS" rex call, and is private.
   *
   * For an input ref, the input interval is additionally narrowed by the Calc's
   * filter condition (if any).
   */
  private def getColumnInterval(
      calc: Calc,
      mq: RelMetadataQuery,
      rex: RexNode): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    rex match {
      case rex: RexCall if rex.getKind == SqlKind.AS =>
        // AS is a pure rename: recurse into the aliased expression
        getColumnInterval(calc, fmq, rex.getOperands.head)
      case inputRef: RexInputRef =>
        val rexProgram = calc.getProgram
        val sourceFieldIndex = inputRef.getIndex
        val inputValueInterval = fmq.getColumnInterval(calc.getInput, sourceFieldIndex)
        val condition = rexProgram.getCondition
        if (condition != null) {
          val predicate = rexProgram.expandLocalRef(rexProgram.getCondition)
          FlinkRelMdColumnInterval.getColumnIntervalWithFilter(
            Option(inputValueInterval),
            predicate,
            sourceFieldIndex,
            calc.getCluster.getRexBuilder)
        } else {
          inputValueInterval
        }
      case literal: RexLiteral =>
        val literalValue = getLiteralValue(literal)
        if (literalValue == null) {
          ValueInterval.empty
        } else {
          ValueInterval(literalValue, literalValue)
        }
      case rexCall: RexCall =>
        getRexNodeInterval(rexCall, calc, mq)
      case _ => null
    }
  }

  /**
   * Evaluates the value interval of an expression over the single input of
   * `baseNode`. Handles input refs, literals, CASE, Flink's IF function and
   * binary-operator calls; any other expression yields null (unknown).
   */
  private def getRexNodeInterval(
      rexNode: RexNode,
      baseNode: SingleRel,
      mq: RelMetadataQuery): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    rexNode match {
      case inputRef: RexInputRef =>
        fmq.getColumnInterval(baseNode.getInput, inputRef.getIndex)
      case literal: RexLiteral =>
        val literalValue = getLiteralValue(literal)
        if (literalValue == null) {
          ValueInterval.empty
        } else {
          ValueInterval(literalValue, literalValue)
        }
      case caseCall: RexCall if caseCall.getKind == SqlKind.CASE =>
        // compute all the possible result values of this case-when clause,
        // the union of the result values is the value interval
        val operands = caseCall.getOperands
        val operandCount = operands.size()
        val possibleValueIntervals = operands.indices
          // CASE operands alternate condition/result; keep the result expressions
          // (odd positions) plus the trailing ELSE expression
          .filter(i => i % 2 != 0 || i == operandCount - 1)
          .map(operands(_))
          .map(getRexNodeInterval(_, baseNode, mq))
        possibleValueIntervals.reduceLeft(ValueInterval.union)
      case ifCall: RexCall if ifCall.getOperator == ScalarSqlFunctions.IF =>
        // compute all the possible result values of this IF clause,
        // the union of the result values is the value interval
        val trueValueInterval = getRexNodeInterval(ifCall.getOperands.get(1), baseNode, mq)
        val falseValueInterval = getRexNodeInterval(ifCall.getOperands.get(2), baseNode, mq)
        ValueInterval.union(trueValueInterval, falseValueInterval)
      case rexCall: RexCall if rexCall.op.isInstanceOf[SqlBinaryOperator] =>
        val leftValueInterval = getRexNodeInterval(rexCall.operands.get(0), baseNode, mq)
        val rightValueInterval = getRexNodeInterval(rexCall.operands.get(1), baseNode, mq)
        ColumnIntervalUtil.getValueIntervalOfRexCall(
          rexCall,
          leftValueInterval,
          rightValueInterval)
      case _ => null
    }
  }

  /**
   * Gets intervals of the given column in Join.
   *
   * The column is resolved against the left or right input by field offset, and
   * the resulting interval is narrowed by the join condition (if any).
   *
   * @param join Join RelNode
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in Join
   */
  def getColumnInterval(join: Join, mq: RelMetadataQuery, index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val joinCondition = join.getCondition
    val nLeftColumns = join.getLeft.getRowType.getFieldCount
    val inputValueInterval = if (index < nLeftColumns) {
      fmq.getColumnInterval(join.getLeft, index)
    } else {
      fmq.getColumnInterval(join.getRight, index - nLeftColumns)
    }
    // TODO if the column at index position is an equi-join key in an inner join,
    // its interval is the origin interval intersected with the interval of the
    // paired join key. For example, if join is an InnerJoin with condition l.A = r.A
    // the valueInterval of l.A is the intersection of l.A with r.A.
    if (joinCondition == null || joinCondition.isAlwaysTrue) {
      inputValueInterval
    } else {
      FlinkRelMdColumnInterval.getColumnIntervalWithFilter(
        Option(inputValueInterval),
        joinCondition,
        index,
        join.getCluster.getRexBuilder)
    }
  }

  /**
   * Gets intervals of the given column in Aggregate.
   *
   * @param aggregate Aggregate RelNode
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in Aggregate
   */
  def getColumnInterval(aggregate: Aggregate, mq: RelMetadataQuery, index: Int): ValueInterval =
    estimateColumnIntervalOfAggregate(aggregate, mq, index)

  /**
   * Gets intervals of the given column in FlinkLogicalWindowAggregate.
   *
   * @param agg Aggregate RelNode
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in FlinkLogicalWindowAggregate
   */
  def getColumnInterval(
      agg: FlinkLogicalWindowAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = estimateColumnIntervalOfAggregate(agg, mq, index)

  /**
   * Gets intervals of the given column in LogicalWindowAggregate.
   *
   * @param agg Aggregate RelNode
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in LogicalWindowAggregate
   */
  def getColumnInterval(
      agg: LogicalWindowAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = estimateColumnIntervalOfAggregate(agg, mq, index)

  /**
   * Gets intervals of the given column in BatchExecWindowAggregateBase.
   *
   * @param agg Aggregate RelNode
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in BatchExecWindowAggregateBase
   */
  def getColumnInterval(
      agg: BatchExecWindowAggregateBase,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = estimateColumnIntervalOfAggregate(agg, mq, index)

  /**
   * Gets intervals of the given column in batch OverWindowAggregate.
   *
   * @param aggregate Aggregate RelNode
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in batch OverWindowAggregate
   */
  def getColumnInterval(
      aggregate: BatchExecOverAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = getColumnIntervalOfOverWindow(aggregate, mq, index)

  /**
   * Gets intervals of the given column in calcite Window.
   *
   * @param window Window RelNode
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in Window
   */
  def getColumnInterval(
      window: Window,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    getColumnIntervalOfOverWindow(window, mq, index)
  }

  /**
   * Over windows forward all input columns: a pass-through column keeps the
   * interval of the underlying input column, while the intervals of the window
   * aggregate call columns (appended after the input fields) cannot be estimated.
   */
  private def getColumnIntervalOfOverWindow(
      overWindow: SingleRel,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val input = overWindow.getInput()
    val fieldsCountOfInput = input.getRowType.getFieldCount
    if (index < fieldsCountOfInput) {
      fmq.getColumnInterval(input, index)
    } else {
      // cannot estimate aggregate function calls columnInterval.
      null
    }
  }

  /**
   * Gets intervals of the given column in batch Aggregate.
   *
   * @param aggregate Aggregate RelNode
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in batch Aggregate
   */
  def getColumnInterval(
      aggregate: BatchExecGroupAggregateBase,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index)

  /**
   * Gets intervals of the given column in stream group Aggregate.
   *
   * @param aggregate Aggregate RelNode
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in stream group Aggregate
   */
  def getColumnInterval(
      aggregate: StreamExecGroupAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index)

  /**
   * Gets intervals of the given column in stream local group Aggregate.
   *
   * @param aggregate Aggregate RelNode
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in stream local group Aggregate
   */
  def getColumnInterval(
      aggregate: StreamExecLocalGroupAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index)

  /**
   * Gets intervals of the given column in stream global group Aggregate.
   *
   * @param aggregate Aggregate RelNode
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in stream global group Aggregate
   */
  def getColumnInterval(
      aggregate: StreamExecGlobalGroupAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    // global aggregate can't estimate the column interval of agg arguments,
    // and the global groupingSet mapping is same to index, so delegate it to local aggregate
    fmq.getColumnInterval(aggregate.getInput, index)
  }

  /**
   * Gets intervals of the given column in stream group window Aggregate.
   *
   * @param aggregate Aggregate RelNode
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in stream group window Aggregate
   */
  def getColumnInterval(
      aggregate: StreamExecGroupWindowAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index)

  /**
   * Estimates the column interval of an aggregate node. Group-key columns delegate
   * to the corresponding input column; for aggregate-call columns only SUM/SUM0
   * over a sign-restricted input and COUNT (always >= 0) yield a bound.
   *
   * NOTE(review): the groupSet match below is deliberately non-exhaustive
   * (StreamExecGlobalGroupAggregate is served by its dedicated method above);
   * an unexpected SingleRel subtype would throw a MatchError here.
   */
  private def estimateColumnIntervalOfAggregate(
      aggregate: SingleRel,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val groupSet = aggregate match {
      case agg: StreamExecGroupAggregate => agg.getGroupings
      case agg: StreamExecLocalGroupAggregate => agg.getGroupings
      case agg: StreamExecIncrementalGroupAggregate => agg.shuffleKey
      case agg: StreamExecGroupWindowAggregate => agg.getGroupings
      case agg: BatchExecGroupAggregateBase => agg.getGrouping ++ agg.getAuxGrouping
      case agg: Aggregate => checkAndGetFullGroupSet(agg)
      case agg: BatchExecLocalSortWindowAggregate =>
        // grouping + assignTs + auxGrouping
        agg.getGrouping ++ Array(agg.inputTimestampIndex) ++ agg.getAuxGrouping
      case agg: BatchExecLocalHashWindowAggregate =>
        // grouping + assignTs + auxGrouping
        agg.getGrouping ++ Array(agg.inputTimestampIndex) ++ agg.getAuxGrouping
      case agg: BatchExecWindowAggregateBase => agg.getGrouping ++ agg.getAuxGrouping
      // do not match StreamExecGlobalGroupAggregate
    }
    if (index < groupSet.length) {
      // estimates group keys according to the input relNodes.
      val sourceFieldIndex = groupSet(index)
      fmq.getColumnInterval(aggregate.getInput, sourceFieldIndex)
    } else {
      // the column is produced by an aggregate function call
      val aggCallIndex = index - groupSet.length
      val aggregateCall = aggregate match {
        case agg: StreamExecGroupAggregate
          if agg.aggCalls.length > aggCallIndex =>
          agg.aggCalls(aggCallIndex)
        case agg: StreamExecLocalGroupAggregate
          if agg.aggInfoList.getActualAggregateCalls.length > aggCallIndex =>
          agg.aggInfoList.getActualAggregateCalls(aggCallIndex)
        case agg: StreamExecIncrementalGroupAggregate
          if agg.partialAggInfoList.getActualAggregateCalls.length > aggCallIndex =>
          agg.partialAggInfoList.getActualAggregateCalls(aggCallIndex)
        case agg: StreamExecGroupWindowAggregate
          if agg.aggCalls.length > aggCallIndex =>
          agg.aggCalls(aggCallIndex)
        case agg: BatchExecGroupAggregateBase
          if agg.aggregateCalls.length > aggCallIndex =>
          agg.aggregateCalls(aggCallIndex)
        case agg: Aggregate
          if agg.getAggCallList.length > aggCallIndex =>
          agg.getAggCallList.get(aggCallIndex)
        case agg: BatchExecWindowAggregateBase
          if agg.aggregateCalls.length > aggCallIndex =>
          agg.aggregateCalls(aggCallIndex)
        // do not match StreamExecGlobalGroupAggregate
        case _ => null
      }
      if (aggregateCall != null) {
        aggregateCall.getAggregation.getKind match {
          case SUM | SUM0 =>
            val inputInterval: ValueInterval = fmq.getColumnInterval(
              aggregate.getInput,
              aggregateCall.getArgList.get(0))
            if (inputInterval != null) {
              inputInterval match {
                // all addends >= lower >= 0 -> the sum is bounded below by lower
                case withLower: WithLower if withLower.lower.isInstanceOf[Number] =>
                  if (withLower.lower.asInstanceOf[Number].doubleValue() >= 0.0) {
                    RightSemiInfiniteValueInterval(withLower.lower, withLower.includeLower)
                  } else {
                    null.asInstanceOf[ValueInterval]
                  }
                // all addends <= upper <= 0 -> the sum is bounded above by upper
                case withUpper: WithUpper if withUpper.upper.isInstanceOf[Number] =>
                  if (withUpper.upper.asInstanceOf[Number].doubleValue() <= 0.0) {
                    LeftSemiInfiniteValueInterval(withUpper.upper, withUpper.includeUpper)
                  } else {
                    null
                  }
                case _ => null
              }
            } else {
              null
            }
          // COUNT is always non-negative
          case COUNT => RightSemiInfiniteValueInterval(0, includeLower = true)
          // todo add more built-in agg function
          case _ => null
        }
      } else {
        // cannot estimate aggregate function calls columnInterval.
        null
      }
    }
  }

  /**
   * Gets intervals of the given column of Sort.
   *
   * Sort (including limit/offset) never changes column values, so the interval of
   * the input column is forwarded unchanged.
   *
   * @param sort Sort to analyze
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in Sort
   */
  def getColumnInterval(sort: Sort, mq: RelMetadataQuery, index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    fmq.getColumnInterval(sort.getInput, index)
  }

  /**
   * Gets intervals of the given column of Expand.
   *
   * The interval is the union over all expand projects of the expression at
   * `index`: input refs delegate to the input column, DECIMAL literals contribute
   * a point interval, NULL literals contribute nothing, and any other expression
   * is rejected with a TableException.
   *
   * NOTE(review): the DECIMAL literal is read back as java.lang.Long — presumably
   * expand only ever emits integral expand_id literals; verify against the Expand
   * node's construction.
   *
   * @param expand expand to analyze
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in Expand
   */
  def getColumnInterval(
      expand: Expand,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val intervals = expand.projects.flatMap { project =>
      project(index) match {
        case inputRef: RexInputRef =>
          Some(fmq.getColumnInterval(expand.getInput, inputRef.getIndex))
        case l: RexLiteral if l.getTypeName eq SqlTypeName.DECIMAL =>
          val v = l.getValueAs(classOf[java.lang.Long])
          Some(ValueInterval(v, v))
        case l: RexLiteral if l.getValue == null =>
          None
        case p@_ =>
          throw new TableException(s"Column interval can't handle $p type in expand.")
      }
    }
    if (intervals.contains(null)) {
      // null union any value interval is null
      null
    } else {
      intervals.reduce((a, b) => ValueInterval.union(a, b))
    }
  }

  /**
   * Gets intervals of the given column of Rank.
   *
   * The rank-function output column is bounded by the rank range: a constant
   * range yields [rankStart, rankEnd]; a variable range yields [1, upper bound of
   * the rank-end column] when that upper bound is known. All other columns
   * forward the interval of the corresponding input column.
   *
   * @param rank [[Rank]] instance to analyze
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column in Rank
   */
  def getColumnInterval(
      rank: Rank,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val rankFunColumnIndex = FlinkRelMdUtil.getRankFunColumnIndex(rank)
    if (index == rankFunColumnIndex) {
      rank.rankRange match {
        case r: ConstantRankRange => ValueInterval(r.rankStart, r.rankEnd)
        case v: VariableRankRange =>
          val interval = fmq.getColumnInterval(rank.getInput, v.rankEndIndex)
          interval match {
            case hasUpper: WithUpper =>
              // convert "1" into the same numeric type as the upper bound
              val lower = ColumnIntervalUtil.convertStringToNumber("1", hasUpper.upper.getClass)
              lower match {
                case Some(l) =>
                  ValueInterval(l, hasUpper.upper, includeUpper = hasUpper.includeUpper)
                case _ => null
              }
            case _ => null
          }
      }
    } else {
      fmq.getColumnInterval(rank.getInput, index)
    }
  }

  /**
   * Gets intervals of the given column of RelSubset.
   *
   * @param subset RelSubset to analyze
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return If the best relNode exists, then transmit to it, else transmit to the original relNode
   */
  def getColumnInterval(subset: RelSubset, mq: RelMetadataQuery, index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    fmq.getColumnInterval(Util.first(subset.getBest, subset.getOriginal), index)
  }

  /**
   * Catch-all rule when none of the others apply.
   *
   * @param rel RelNode to analyze
   * @param mq RelMetadataQuery instance
   * @param index the index of the given column
   * @return Always returns null
   */
  def getColumnInterval(rel: RelNode, mq: RelMetadataQuery, index: Int): ValueInterval = null
}
object FlinkRelMdColumnInterval {

  // singleton handler exposed to Calcite's reflective metadata dispatch
  private val INSTANCE = new FlinkRelMdColumnInterval

  val SOURCE: RelMetadataProvider = ReflectiveRelMetadataProvider.reflectiveSource(
    FlinkMetadata.ColumnInterval.METHOD, INSTANCE)

  /**
   * Calculates the interval of the column which is referred to in the predicate
   * expression, and intersects the result with the origin interval of the column.
   *
   * The predicate is first restricted to the sub-expression that references the
   * column, converted to DNF, and then evaluated as a union (over OR parts) of
   * intersections (over AND parts) of single-comparison intervals.
   *
   * e.g. for condition $1 <= 2 and $1 >= -1
   * the interval of $1 is originInterval intersected with [-1, 2]
   *
   * for condition: $1 <= 2 and not ($1 < -1 or $2 is true),
   * the interval of $1 is originInterval intersected with (-Inf, -1]
   *
   * for condition $1 <= 2 or $1 > -1
   * the interval of $1 is (originInterval intersected with (-Inf, 2]) union
   * (originInterval intersected with (-1, Inf])
   *
   * @param originInterval origin interval of the column
   * @param predicate the predicate expression
   * @param inputRef the index of the given column
   * @param rexBuilder RexBuilder instance to analyze the predicate expression
   * @return the narrowed interval, or null when the result is unbounded (infinite)
   */
  def getColumnIntervalWithFilter(
      originInterval: Option[ValueInterval],
      predicate: RexNode,
      inputRef: Int,
      rexBuilder: RexBuilder): ValueInterval = {
    // keep only the part of the predicate that mentions column `inputRef`
    val isRelated = (r: RexNode) => r.accept(new ColumnRelatedVisitor(inputRef))
    val relatedSubRexNode = partition(predicate, rexBuilder, isRelated)._1
    val beginInterval = originInterval match {
      case Some(interval) => interval
      case _ => ValueInterval.infinite
    }
    relatedSubRexNode match {
      case Some(rexNode) =>
        val orParts = RexUtil.flattenOr(Vector(RexUtil.toDnf(rexBuilder, rexNode)))
        val interval = orParts.map(or => {
          val andParts = RexUtil.flattenAnd(Vector(or))
          andParts.map(and => columnIntervalOfSinglePredicate(and))
            .filter(_ != null)
            .foldLeft(beginInterval)(ValueInterval.intersect)
        }).reduceLeft(ValueInterval.union)
        // an infinite result carries no information -> report as unknown
        if (interval == ValueInterval.infinite) null else interval
      case None => beginInterval
    }
  }

  /**
   * Converts a single binary comparison (column op literal, in either operand
   * order, possibly wrapped in AS) into a value interval; returns null when the
   * condition is not such a comparison or the literal value cannot be read.
   *
   * NOTE(review): a condition that is not a RexCall would fail the asInstanceOf
   * cast here rather than yield null — presumably callers only pass flattened
   * AND parts that are calls; verify if reused elsewhere.
   */
  private def columnIntervalOfSinglePredicate(condition: RexNode): ValueInterval = {
    val convertedCondition = condition.asInstanceOf[RexCall]
    if (convertedCondition == null || convertedCondition.operands.size() != 2) {
      null
    } else {
      val (literalValue, op) = (convertedCondition.operands.head, convertedCondition.operands.last)
      match {
        case (_: RexInputRef, literal: RexLiteral) =>
          (getLiteralValue(literal), convertedCondition.getKind)
        case (rex: RexCall, literal: RexLiteral) if rex.getKind == SqlKind.AS =>
          (getLiteralValue(literal), convertedCondition.getKind)
        // literal on the left: mirror the comparison (e.g. 2 < $1  ->  $1 > 2)
        case (literal: RexLiteral, _: RexInputRef) =>
          (getLiteralValue(literal), convertedCondition.getKind.reverse())
        case (literal: RexLiteral, rex: RexCall) if rex.getKind == SqlKind.AS =>
          (getLiteralValue(literal), convertedCondition.getKind.reverse())
        case _ => (null, null)
      }
      if (op == null || literalValue == null) {
        null
      } else {
        op match {
          case EQUALS => ValueInterval(literalValue, literalValue)
          case LESS_THAN => ValueInterval(null, literalValue, includeUpper = false)
          case LESS_THAN_OR_EQUAL => ValueInterval(null, literalValue)
          case GREATER_THAN => ValueInterval(literalValue, null, includeLower = false)
          case GREATER_THAN_OR_EQUAL => ValueInterval(literalValue, null)
          case _ => null
        }
      }
    }
  }
}