blob: 7f57978680524b60d9639be73ade279a977629f5 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan.util
import org.apache.flink.table.api.{TableConfig, TableConfigOptions}
import org.apache.flink.table.functions.sql.internal.SqlAuxiliaryGroupAggFunction
import org.apache.flink.table.validate.{BuiltInFunctionCatalog, FunctionCatalog}
import org.apache.calcite.plan.RelOptUtil
import org.apache.calcite.rel.core.{Aggregate, AggregateCall}
import org.apache.calcite.rel.externalize.RelWriterImpl
import org.apache.calcite.rel.{RelNode, RelWriter}
import org.apache.calcite.rex._
import org.apache.calcite.sql.SqlExplainLevel
import org.apache.calcite.sql.SqlKind._
import org.apache.calcite.sql.`type`.SqlTypeName._
import org.apache.calcite.util.Pair
import java.io.{PrintWriter, StringWriter}
import java.lang.{Boolean => JBool, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort, String => JString}
import java.math.BigDecimal
import java.sql.{Date, Time, Timestamp}
import java.util
import java.util.Calendar
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable
object FlinkRelOptUtil {

  /**
   * Converts a relational expression to a string.
   * This is different from [[RelOptUtil]]#toString on two points:
   * 1. Generated string by this method is in a tree style
   * 2. Generated string by this method may have more information about RelNode, such as
   * resource, memory cost, RelNodeId, retractionTraits.
   *
   * @param rel the RelNode to convert
   * @param detailLevel detailLevel defines detail levels for EXPLAIN PLAN.
   * @param withRelNodeId whether including ID of RelNode
   * @param withRetractTraits whether including Retraction Traits of RelNode (only apply to
   *                          StreamPhysicalRel node at present)
   * @return explain plan of RelNode
   */
  def toString(
      rel: RelNode,
      detailLevel: SqlExplainLevel = SqlExplainLevel.EXPPLAN_ATTRIBUTES,
      withRelNodeId: Boolean = false,
      withRetractTraits: Boolean = false): String = {
    val sw = new StringWriter
    val planWriter = new RelTreeWriterImpl(
      new PrintWriter(sw),
      detailLevel,
      withRelNodeId,
      withRetractTraits)
    rel.explain(planWriter)
    sw.toString
  }

  /**
   * Builds a one-line digest of `rel` in the form `RelTypeName(name1=[value1], name2=[value2])`,
   * using the terms reported at [[SqlExplainLevel.DIGEST_ATTRIBUTES]].
   *
   * @param rel the RelNode to describe
   * @param withInput whether terms whose value is an input [[RelNode]] are included;
   *                  by default they are skipped
   * @return the digest string
   */
  def getDigest(rel: RelNode, withInput: Boolean = false): String = {
    val sw = new StringWriter
    // Named `writer` rather than `pw` so it is not confused with RelWriterImpl's inherited
    // `pw: PrintWriter`, which is the stream explain_ below actually writes to.
    val writer: RelWriter = new RelWriterImpl(
      new PrintWriter(sw), SqlExplainLevel.DIGEST_ATTRIBUTES, false) {
      override protected def explain_(
          node: RelNode, values: util.List[Pair[String, AnyRef]]): Unit = {
        val terms = values.filter { value =>
          value.right match {
            case _: RelNode => withInput // input nodes only appear on request
            case _ => true
          }
        }
        pw.write(node.getRelTypeName)
        pw.write(terms.map(term => term.left + "=[" + term.right + "]").mkString("(", ", ", ")"))
      }
    }
    rel.explain(writer)
    sw.toString
  }

  /**
   * Returns the [[TableConfig]] registered in the planner context of `rel`,
   * or [[TableConfig.DEFAULT]] if the context does not provide one.
   */
  def getTableConfig(rel: RelNode): TableConfig = {
    Option(rel.getCluster.getPlanner.getContext.unwrap(classOf[TableConfig]))
      .getOrElse(TableConfig.DEFAULT)
  }

  /**
   * Returns the [[FunctionCatalog]] registered in the planner context of `rel`,
   * or a new catalog containing only the built-in functions if none is registered.
   */
  def getFunctionCatalog(rel: RelNode): FunctionCatalog = {
    Option(rel.getCluster.getPlanner.getContext.unwrap(classOf[FunctionCatalog]))
      .getOrElse(BuiltInFunctionCatalog.withBuiltIns())
  }

  /**
   * Gets a field name that does not yet occur in `allFieldNames`.
   * If `toAddFieldName` is already taken, the suffixes `_0`, `_1`, ... are tried in order.
   *
   * NOTE: the returned name is also added to `allFieldNames`.
   *
   * @param allFieldNames the names already in use; mutated by this method
   * @param toAddFieldName the preferred field name
   * @return the unique field name that was added
   */
  def buildUniqueFieldName(
      allFieldNames: mutable.Set[String],
      toAddFieldName: String): String = {
    var name: String = toAddFieldName
    var i: Int = 0
    while (allFieldNames.contains(name)) {
      name = toAddFieldName + "_" + i
      i += 1
    }
    allFieldNames.add(name)
    name
  }

  /**
   * Checks that the AUXILIARY_GROUP aggCalls are at the front of the given agg's aggCallList.
   * Also checks that the aggCallList contains no AUXILIARY_GROUP call when the agg's groupSet
   * is empty or its indicator is true.
   * Returns the AUXILIARY_GROUP aggCalls' args and the other aggCalls.
   *
   * @param agg aggregate
   * @return the argument index of each AUXILIARY_GROUP aggCall (in order) and the other aggCalls
   * @throws IllegalArgumentException if any of the checks above fails
   */
  def checkAndSplitAggCalls(agg: Aggregate): (Array[Int], Seq[AggregateCall]) = {
    var nonAuxGroupCallsStartIdx = -1
    val aggCalls = agg.getAggCallList
    aggCalls.zipWithIndex.foreach {
      case (call, idx) =>
        if (call.getAggregation == SqlAuxiliaryGroupAggFunction) {
          require(call.getArgList.size == 1,
            "AUXILIARY_GROUP aggCall should have exactly one argument")
        }
        if (nonAuxGroupCallsStartIdx >= 0) {
          // the remaining aggCalls should not be AUXILIARY_GROUP
          require(call.getAggregation != SqlAuxiliaryGroupAggFunction,
            "AUXILIARY_GROUP should be in the front of aggCall list")
        }
        if (nonAuxGroupCallsStartIdx < 0 &&
          call.getAggregation != SqlAuxiliaryGroupAggFunction) {
          nonAuxGroupCallsStartIdx = idx
        }
    }
    if (nonAuxGroupCallsStartIdx < 0) {
      // every aggCall is AUXILIARY_GROUP
      nonAuxGroupCallsStartIdx = aggCalls.length
    }
    val (auxGroupCalls, otherAggCalls) = aggCalls.splitAt(nonAuxGroupCallsStartIdx)
    if (agg.getGroupCount == 0) {
      require(auxGroupCalls.isEmpty,
        "AUXILIARY_GROUP aggCalls should be empty when groupSet is empty")
    }
    if (agg.indicator) {
      require(auxGroupCalls.isEmpty,
        "AUXILIARY_GROUP aggCalls should be empty when indicator is true")
    }
    val auxGrouping = auxGroupCalls.map(_.getArgList.head.toInt).toArray
    require(auxGrouping.length + otherAggCalls.length == aggCalls.length)
    (auxGrouping, otherAggCalls)
  }

  /**
   * Returns the agg's groupSet indices followed by the argument indices of its
   * AUXILIARY_GROUP aggCalls (validated by [[checkAndSplitAggCalls]]).
   */
  def checkAndGetFullGroupSet(agg: Aggregate): Array[Int] = {
    val (auxGroupSet, _) = checkAndSplitAggCalls(agg)
    agg.getGroupSet.toArray ++ auxGroupSet
  }

  /** Gets the max CNF node limit (SQL_OPTIMIZER_CNF_NODES_LIMIT) from the table config of rel. */
  def getMaxCnfNodeCount(rel: RelNode): Int = {
    getTableConfig(rel).getConf.getInteger(TableConfigOptions.SQL_OPTIMIZER_CNF_NODES_LIMIT)
  }

  /**
   * Gets the value of a RexLiteral as a Java object.
   *
   * @param literal input RexLiteral
   * @return the value of the input RexLiteral, or null if the literal is NULL
   * @throws IllegalArgumentException if the literal's SQL type is not supported
   */
  def getLiteralValue(literal: RexLiteral): Comparable[_] = {
    if (literal.isNull) {
      null
    } else {
      val literalType = literal.getType
      literalType.getSqlTypeName match {
        case BOOLEAN => RexLiteral.booleanValue(literal)
        case TINYINT => literal.getValueAs(classOf[JByte])
        case SMALLINT => literal.getValueAs(classOf[JShort])
        case INTEGER => literal.getValueAs(classOf[Integer])
        case BIGINT => literal.getValueAs(classOf[JLong])
        case FLOAT => literal.getValueAs(classOf[JFloat])
        case DOUBLE => literal.getValueAs(classOf[JDouble])
        case DECIMAL => literal.getValue3.asInstanceOf[BigDecimal]
        case VARCHAR | CHAR => literal.getValueAs(classOf[JString])

        // temporal types
        case DATE =>
          new Date(literal.getValueAs(classOf[Calendar]).getTimeInMillis)
        case TIME =>
          new Time(literal.getValueAs(classOf[Calendar]).getTimeInMillis)
        case TIMESTAMP =>
          new Timestamp(literal.getValueAs(classOf[Calendar]).getTimeInMillis)
        case _ =>
          throw new IllegalArgumentException(s"Literal type $literalType is not supported!")
      }
    }
  }

  /**
   * Partitions the [[RexNode]] into two [[RexNode]]s according to a predicate.
   * The result is a pair: the first RexNode consists of the parts that satisfy the
   * predicate, and the second RexNode consists of the parts that don't.
   *
   * NOT is first pushed down towards the leaves of the expression.
   *
   * A simple condition (not AND, OR or NOT) either satisfies the predicate entirely or not at all.
   *
   * For an AND condition, each operand is partitioned recursively. The parts that satisfy the
   * predicate are merged into the first result, and the remaining parts into the second result.
   *
   * For an OR condition, common factors are pulled up first. If the result is no longer an OR,
   * the problem reduces to partitioning that expression. Otherwise the whole OR goes to the
   * first result only if all of its operands satisfy the predicate; if not, it goes to the
   * second result.
   *
   * A NOT condition goes to the first result only if its operand fully satisfies the predicate.
   *
   * IS TRUE is partitioned as its operand. IS FALSE is partitioned as the negation of its
   * operand.
   *
   * A result part that is always true is returned as None.
   *
   * @param expr the expression to partition
   * @param rexBuilder rexBuilder
   * @param predicate the specified predicate on which to partition
   * @return a pair of RexNode: the first RexNode consists of RexNode that satisfy the predicate
   *         and the second RexNode consists of RexNode that don't
   */
  def partition(
      expr: RexNode,
      rexBuilder: RexBuilder,
      predicate: RexNode => JBool): (Option[RexNode], Option[RexNode]) = {
    val condition = pushNotToLeaf(expr, rexBuilder)
    val (left: Option[RexNode], right: Option[RexNode]) = condition.getKind match {
      case AND =>
        val (leftExprs, rightExprs) = partition(
          condition.asInstanceOf[RexCall].operands, rexBuilder, predicate)
        if (leftExprs.isEmpty) {
          (None, Option(condition))
        } else {
          val l = RexUtil.composeConjunction(rexBuilder, leftExprs.asJava, false)
          if (rightExprs.isEmpty) {
            (Option(l), None)
          } else {
            val r = RexUtil.composeConjunction(rexBuilder, rightExprs.asJava, false)
            (Option(l), Option(r))
          }
        }
      case OR =>
        val e = RexUtil.pullFactors(rexBuilder, condition)
        e.getKind match {
          case OR =>
            val operands = condition.asInstanceOf[RexCall].operands
            val (leftExprs, rightExprs) = partition(operands, rexBuilder, predicate)
            // A disjunction can only move to the left side as a whole. The helper drops operands
            // that reduce to always-true, so a size mismatch means leftExprs alone would no
            // longer be equivalent to the disjunction; keep it on the right side instead.
            if (leftExprs.isEmpty || rightExprs.nonEmpty || leftExprs.size != operands.size) {
              (None, Option(condition))
            } else {
              val l = RexUtil.composeDisjunction(rexBuilder, leftExprs.asJava, false)
              (Option(l), None)
            }
          case _ =>
            partition(e, rexBuilder, predicate)
        }
      case NOT =>
        val operand = condition.asInstanceOf[RexCall].operands.head
        partition(operand, rexBuilder, predicate) match {
          case (Some(_), None) => (Option(condition), None)
          case (_, _) => (None, Option(condition))
        }
      case IS_TRUE =>
        val operand = condition.asInstanceOf[RexCall].operands.head
        partition(operand, rexBuilder, predicate)
      case IS_FALSE =>
        val operand = condition.asInstanceOf[RexCall].operands.head
        val newCondition = pushNotToLeaf(operand, rexBuilder, needReverse = true)
        partition(newCondition, rexBuilder, predicate)
      case _ =>
        if (predicate(condition)) {
          (Option(condition), None)
        } else {
          (None, Option(condition))
        }
    }

    (convertRexNodeIfAlwaysTrue(left), convertRexNodeIfAlwaysTrue(right))
  }

  /**
   * Partitions each expression with [[partition]] and collects the non-empty left parts and
   * right parts separately, in input order. An expression whose parts are both None (it
   * reduced to an always-true condition) contributes to neither list.
   */
  private def partition(
      exprs: Iterable[RexNode],
      rexBuilder: RexBuilder,
      predicate: RexNode => JBool): (Iterable[RexNode], Iterable[RexNode]) = {
    val leftExprs = mutable.ListBuffer[RexNode]()
    val rightExprs = mutable.ListBuffer[RexNode]()
    exprs.foreach(expr => partition(expr, rexBuilder, predicate) match {
      case (Some(first), Some(second)) =>
        leftExprs += first
        rightExprs += second
      case (None, Some(rest)) =>
        rightExprs += rest
      case (Some(interested), None) =>
        leftExprs += interested
      case (None, None) =>
        // always-true operand: nothing to keep on either side
    })
    (leftExprs, rightExprs)
  }

  /** Maps an always-true expression to None; any other value is returned unchanged. */
  private def convertRexNodeIfAlwaysTrue(expr: Option[RexNode]): Option[RexNode] = {
    expr match {
      case Some(rex) if rex.isAlwaysTrue => None
      case _ => expr
    }
  }

  /**
   * Pushes NOT operators down towards the leaves of `expr`.
   *
   * @param expr the expression to rewrite
   * @param rexBuilder rexBuilder
   * @param needReverse if true, the negation of `expr` is returned. AND and OR are swapped
   *                    (De Morgan). A leaf call is negated with RexUtil.negate when possible,
   *                    otherwise it is wrapped in NOT.
   * @return the rewritten expression
   */
  private def pushNotToLeaf(
      expr: RexNode,
      rexBuilder: RexBuilder,
      needReverse: Boolean = false): RexNode = (expr.getKind, needReverse) match {
    case (AND, true) | (OR, false) =>
      val convertedExprs = expr.asInstanceOf[RexCall].operands
        .map(pushNotToLeaf(_, rexBuilder, needReverse))
      RexUtil.composeDisjunction(rexBuilder, convertedExprs, false)
    case (AND, false) | (OR, true) =>
      val convertedExprs = expr.asInstanceOf[RexCall].operands
        .map(pushNotToLeaf(_, rexBuilder, needReverse))
      RexUtil.composeConjunction(rexBuilder, convertedExprs, false)
    case (NOT, _) =>
      val child = expr.asInstanceOf[RexCall].operands.head
      pushNotToLeaf(child, rexBuilder, !needReverse)
    case (_, true) =>
      expr match {
        case call: RexCall =>
          Option(RexUtil.negate(rexBuilder, call)).getOrElse(RexUtil.not(expr))
        case _ => RexUtil.not(expr)
      }
    case (_, false) => expr
  }

  /**
   * A RexVisitor that judges whether a RexNode is related only to the input field at `index`.
   * A RexInputRef is related iff it references `index`, and every RexLiteral counts as related.
   * A RexCall is related iff all of its operands are related; an operand that evaluates to null
   * counts as unrelated. Node kinds not overridden here fall back to
   * [[RexVisitorImpl]]'s default behavior.
   */
  class ColumnRelatedVisitor(index: Int) extends RexVisitorImpl[JBool](true) {

    override def visitInputRef(inputRef: RexInputRef): JBool = inputRef.getIndex == index

    override def visitLiteral(literal: RexLiteral): JBool = true

    override def visitCall(call: RexCall): JBool = {
      call.operands.forall(operand => {
        val isRelated = operand.accept(this)
        isRelated != null && isRelated
      })
    }
  }
}