| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.flink.table.plan.util |
| |
| import org.apache.calcite.plan.{RelOptPredicateList, RelOptUtil} |
| import org.apache.calcite.rel.`type`.RelDataType |
| import org.apache.calcite.rex._ |
| import org.apache.calcite.sql.fun.SqlStdOperatorTable |
| import org.apache.calcite.sql.fun.SqlStdOperatorTable._ |
| import org.apache.calcite.sql.{SqlKind, SqlOperator} |
| import org.apache.calcite.util.{ControlFlowException, Util} |
| |
| import com.google.common.base.Function |
| import com.google.common.collect.{ImmutableList, Lists} |
| |
| import java.lang.Iterable |
| import java.util |
| |
| import scala.collection.JavaConversions._ |
| import scala.collection.mutable |
| |
| /** |
| * Utility methods concerning [[RexNode]]. |
| */ |
object FlinkRexUtil {

  /**
    * Similar to [[RexUtil#toCnf(RexBuilder, Int, RexNode)]]; it lets you
    * specify a threshold in the number of nodes that can be created out of
    * the conversion. However, if the threshold is negative, this method uses
    * a default threshold of twice the number of RexCalls in the given node.
    *
    * <p>If the number of resulting RexCalls exceeds that threshold,
    * the conversion stops and the original expression is returned.
    *
    * <p>Leaf nodes (e.g. RexInputRef) in the expression do not count towards the threshold.
    *
    * <p>We strongly discourage using [[RexUtil#toCnf(RexBuilder, RexNode)]] and
    * [[RexUtil#toCnf(RexBuilder, Int, RexNode)]], because there are many bad cases when using
    * [[RexUtil#toCnf(RexBuilder, RexNode)]]: e.g. the predicate in TPC-DS q41.sql is
    * converted to an extremely complex expression (including 736450 RexCalls); and we cannot
    * give an appropriate value for `maxCnfNodeCount` when using
    * [[RexUtil#toCnf(RexBuilder, Int, RexNode)]].
    *
    * @param rexBuilder      builder used to create the rewritten AND/OR/NOT calls
    * @param maxCnfNodeCount maximum number of RexCalls allowed in the result;
    *                        a negative value selects the default threshold described above
    * @param rex             the expression to convert
    * @return the CNF form of `rex`, or `rex` itself if the threshold was exceeded
    */
  def toCnf(rexBuilder: RexBuilder, maxCnfNodeCount: Int, rex: RexNode): RexNode = {
    val maxCnfNodeCnt = if (maxCnfNodeCount < 0) {
      getNumberOfRexCall(rex) * 2
    } else {
      maxCnfNodeCount
    }
    new CnfHelper(rexBuilder, maxCnfNodeCnt).toCnf(rex)
  }

  /**
    * Counts the RexCall nodes in the given expression tree (visited deeply, so
    * nested calls are counted too). Leaves such as input refs and literals are not counted.
    */
  private def getNumberOfRexCall(rex: RexNode): Int = {
    var numberOfNodes = 0
    rex.accept(new RexVisitorImpl[Unit](true) {
      override def visitCall(call: RexCall): Unit = {
        numberOfNodes += 1
        super.visitCall(call)
      }
    })
    numberOfNodes
  }

  /**
    * Helps [[toCnf]]: performs the actual CNF conversion and aborts (by throwing
    * an internal control-flow exception) once an intermediate AND/OR result has
    * more than `maxNodeCount` RexCalls.
    */
  private class CnfHelper(rexBuilder: RexBuilder, maxNodeCount: Int) {

    /** Exception to catch when we pass the limit. */
    @SuppressWarnings(Array("serial"))
    private class OverflowError extends ControlFlowException {
    }

    // A single pre-allocated instance is thrown; it is only used for control flow.
    @SuppressWarnings(Array("ThrowableInstanceNeverThrown"))
    private val INSTANCE = new OverflowError

    // Wraps a node in NOT; used when pushing a negation through AND/OR (De Morgan).
    private val ADD_NOT = new Function[RexNode, RexNode]() {
      override def apply(input: RexNode): RexNode =
        rexBuilder.makeCall(input.getType, SqlStdOperatorTable.NOT, ImmutableList.of(input))
    }

    /** Converts `rex` to CNF, or returns `rex` unchanged if the node limit is exceeded. */
    def toCnf(rex: RexNode): RexNode = try {
      toCnf2(rex)
    } catch {
      case e: OverflowError =>
        Util.swallow(e, null)
        rex
    }

    private def toCnf2(rex: RexNode): RexNode = {
      rex.getKind match {
        case SqlKind.AND =>
          // Convert every conjunct and splice nested ANDs into one flat conjunction.
          val cnfOperands: util.List[RexNode] = Lists.newArrayList()
          val operands = RexUtil.flattenAnd(rex.asInstanceOf[RexCall].operands)
          operands.foreach { node =>
            val cnf = toCnf2(node)
            cnf.getKind match {
              case SqlKind.AND =>
                cnfOperands.addAll(cnf.asInstanceOf[RexCall].operands)
              case _ =>
                cnfOperands.add(cnf)
            }
          }
          val node = and(cnfOperands)
          checkCnfRexCallCount(node)
          node
        case SqlKind.OR =>
          // Distribute OR over AND: (h1 AND h2) OR (t1 AND t2)
          //   -> (h1 OR t1) AND (h1 OR t2) AND (h2 OR t1) AND (h2 OR t2)
          val operands = RexUtil.flattenOr(rex.asInstanceOf[RexCall].operands)
          val head = operands.head
          val headCnf = toCnf2(head)
          val headCnfs: util.List[RexNode] = RelOptUtil.conjunctions(headCnf)
          val tail = or(Util.skip(operands))
          val tailCnf: RexNode = toCnf2(tail)
          val tailCnfs: util.List[RexNode] = RelOptUtil.conjunctions(tailCnf)
          val list: util.List[RexNode] = Lists.newArrayList()
          headCnfs.foreach { h =>
            tailCnfs.foreach {
              t => list.add(or(ImmutableList.of(h, t)))
            }
          }
          val node = and(list)
          checkCnfRexCallCount(node)
          node
        case SqlKind.NOT =>
          // Push the negation inwards: NOT NOT x -> x, and De Morgan for OR / AND.
          val arg = rex.asInstanceOf[RexCall].operands.head
          arg.getKind match {
            case SqlKind.NOT =>
              toCnf2(arg.asInstanceOf[RexCall].operands.head)
            case SqlKind.OR =>
              val operands = arg.asInstanceOf[RexCall].operands
              toCnf2(and(Lists.transform(RexUtil.flattenOr(operands), ADD_NOT)))
            case SqlKind.AND =>
              val operands = arg.asInstanceOf[RexCall].operands
              toCnf2(or(Lists.transform(RexUtil.flattenAnd(operands), ADD_NOT)))
            case _ => rex
          }
        case _ => rex
      }
    }

    /** Throws [[INSTANCE]] if `node` contains more RexCalls than `maxNodeCount`. */
    private def checkCnfRexCallCount(node: RexNode): Unit = {
      // TODO use more efficient solution to get number of RexCall in CNF node
      if (maxNodeCount >= 0 && getNumberOfRexCall(node) > maxNodeCount) {
        throw INSTANCE
      }
    }

    private def and(nodes: Iterable[_ <: RexNode]): RexNode =
      RexUtil.composeConjunction(rexBuilder, nodes, false)

    private def or(nodes: Iterable[_ <: RexNode]): RexNode =
      RexUtil.composeDisjunction(rexBuilder, nodes)
  }

  /**
    * Merges same expressions and then simplifies the result expression by [[RexSimplify]].
    *
    * Examples for merging same expressions:
    * 1. a = b AND b = a -> a = b
    * 2. a = b OR b = a -> a = b
    * 3. (a > b AND c < 10) AND b < a -> a > b AND c < 10
    * 4. (a > b OR c < 10) OR b < a -> a > b OR c < 10
    * 5. a = a, a >= a, a <= a -> true (only when `a` is a non-nullable input field)
    * 6. a <> a, a > a, a < a -> false (only when `a` is a non-nullable input field)
    *
    * @param rexBuilder builder used to create the rewritten expressions
    * @param expr       the expression to simplify
    * @return the simplified expression; `expr` itself if it is already always TRUE or FALSE
    */
  def simplify(rexBuilder: RexBuilder, expr: RexNode): RexNode = {
    if (expr.isAlwaysTrue || expr.isAlwaysFalse) {
      // Nothing to merge or reduce in a constant condition.
      expr
    } else {
      val equiExpr = expr.accept(new EquivalentExprShuttle(rexBuilder))
      val sameExprMerged = new SameExprMerger(rexBuilder).mergeSameExpr(equiExpr)
      val binaryComparisonExprReduced =
        sameExprMerged.accept(new BinaryComparisonExprReducer(rexBuilder))
      // NOTE(review): the boolean argument is assumed to be Calcite's `unknownAsFalse`
      // flag for this constructor overload; confirm against the Calcite version in use.
      val rexSimplify =
        new RexSimplify(rexBuilder, RelOptPredicateList.EMPTY, true, RexUtil.EXECUTOR)
      rexSimplify.simplify(binaryComparisonExprReduced)
    }
  }

  /**
    * Folds a binary comparison whose two operands reference the same non-nullable
    * input field into a TRUE/FALSE literal; all other calls are visited recursively.
    */
  private class BinaryComparisonExprReducer(rexBuilder: RexBuilder) extends RexShuttle {
    override def visitCall(call: RexCall): RexNode = {
      val kind = call.getOperator.getKind
      if (!kind.belongsTo(SqlKind.BINARY_COMPARISON)) {
        super.visitCall(call)
      } else {
        val operand0 = call.getOperands.get(0)
        val operand1 = call.getOperands.get(1)
        (operand0, operand1) match {
          // Under SQL three-valued logic `a = a` is UNKNOWN (not TRUE) and `a <> a` is
          // UNKNOWN (not FALSE) when `a` is NULL, so only fold non-nullable fields.
          case (op0: RexInputRef, op1: RexInputRef)
            if op0.getIndex == op1.getIndex && !op0.getType.isNullable =>
            kind match {
              case SqlKind.EQUALS | SqlKind.LESS_THAN_OR_EQUAL | SqlKind.GREATER_THAN_OR_EQUAL =>
                rexBuilder.makeLiteral(true)
              case SqlKind.NOT_EQUALS | SqlKind.LESS_THAN | SqlKind.GREATER_THAN =>
                rexBuilder.makeLiteral(false)
              case _ => super.visitCall(call)
            }
          case _ => super.visitCall(call)
        }
      }
    }
  }

  /**
    * Replaces repeated operands of AND (resp. OR) with TRUE (resp. FALSE).
    * Expressions are considered the same when their `toString` digests are equal.
    */
  private class SameExprMerger(rexBuilder: RexBuilder) extends RexShuttle {
    private val sameExprMap = mutable.HashMap[String, RexNode]()

    /** Returns `equiExpr` if `expr` was already seen, otherwise records and returns `expr`. */
    private def mergeSameExpr(expr: RexNode, equiExpr: RexLiteral): RexNode = {
      if (sameExprMap.contains(expr.toString)) {
        equiExpr
      } else {
        sameExprMap.put(expr.toString, expr)
        expr
      }
    }

    def mergeSameExpr(expr: RexNode): RexNode = {
      // merges same expressions in the operands of AND and OR
      // e.g. a = b AND a = b -> a = b AND true
      //      a = b OR a = b -> a = b OR false
      val newExpr1 = expr.accept(this)

      // merges same expressions in conjunctions
      // e.g. (a > b AND c < 10) AND a > b -> a > b AND c < 10 AND true
      sameExprMap.clear()
      val newConjunctions = RelOptUtil.conjunctions(newExpr1).map {
        ex => mergeSameExpr(ex, rexBuilder.makeLiteral(true))
      }
      val newExpr2 = newConjunctions.size match {
        case 0 => newExpr1 // true AND true
        case 1 => newConjunctions.head
        case _ => rexBuilder.makeCall(AND, newConjunctions: _*)
      }

      // merges same expressions in disjunctions
      // e.g. (a > b OR c < 10) OR a > b -> a > b OR c < 10 OR false
      sameExprMap.clear()
      val newDisjunctions = RelOptUtil.disjunctions(newExpr2).map {
        ex => mergeSameExpr(ex, rexBuilder.makeLiteral(false))
      }
      val newExpr3 = newDisjunctions.size match {
        case 0 => newExpr2 // false OR false
        case 1 => newDisjunctions.head
        case _ => rexBuilder.makeCall(OR, newDisjunctions: _*)
      }
      newExpr3
    }

    override def visitCall(call: RexCall): RexNode = {
      val newCall = call.getOperator match {
        case AND | OR =>
          sameExprMap.clear()
          // A repeated operand is replaced by the operator's identity element:
          // TRUE for AND, FALSE for OR.
          val identity = rexBuilder.makeLiteral(call.getOperator == AND)
          val newOperands = call.getOperands.map(op => mergeSameExpr(op, identity))
          call.clone(call.getType, newOperands)
        case _ => call
      }
      super.visitCall(newCall)
    }
  }

  /**
    * Adjusts the condition's field indices according to `fieldsOldToNewIndexMapping`.
    *
    * @param c The condition to be adjusted.
    * @param fieldsOldToNewIndexMapping A map from the old field indices to the new field
    *                                   indices; it must contain every input field
    *                                   referenced by `c`.
    * @param rowType The row type of the new output.
    * @return Return new condition with new field indices.
    */
  private[flink] def adjustInputRefs(
      c: RexNode,
      fieldsOldToNewIndexMapping: Map[Int, Int],
      rowType: RelDataType): RexNode = c.accept(
    new RexShuttle() {

      override def visitInputRef(inputRef: RexInputRef): RexNode = {
        val oldIndex = inputRef.getIndex
        assert(fieldsOldToNewIndexMapping.contains(oldIndex),
          s"No new index is mapped for input field $oldIndex")
        val ref = RexInputRef.of(fieldsOldToNewIndexMapping(oldIndex), rowType)
        if (ref.getIndex == oldIndex && (ref.getType eq inputRef.getType)) {
          // re-use old object, to prevent needless expr cloning
          inputRef
        } else {
          ref
        }
      }
    })

  /**
    * Rewrites a binary comparison into the swapped form of an equivalent comparison
    * seen earlier (e.g. `b < a` after `a > b` becomes `a > b`), so that later passes
    * can detect it as a duplicate. Operands of comparisons are not visited.
    */
  private class EquivalentExprShuttle(rexBuilder: RexBuilder) extends RexShuttle {
    private val equiExprMap = mutable.HashMap[String, RexNode]()

    override def visitCall(call: RexCall): RexNode = {
      call.getOperator match {
        case EQUALS | NOT_EQUALS | GREATER_THAN | LESS_THAN |
             GREATER_THAN_OR_EQUAL | LESS_THAN_OR_EQUAL =>
          val swapped = swapOperands(call)
          if (equiExprMap.contains(swapped.toString)) {
            swapped
          } else {
            equiExprMap.put(call.toString, call)
            call
          }
        case _ => super.visitCall(call)
      }
    }

    /** Builds `op1 <mirrored op> op0`, which is equivalent to `op0 <op> op1`. */
    private def swapOperands(call: RexCall): RexCall = {
      val newOp = call.getOperator match {
        case EQUALS | NOT_EQUALS => call.getOperator
        case GREATER_THAN => LESS_THAN
        case GREATER_THAN_OR_EQUAL => LESS_THAN_OR_EQUAL
        case LESS_THAN => GREATER_THAN
        case LESS_THAN_OR_EQUAL => GREATER_THAN_OR_EQUAL
        case _ => throw new IllegalArgumentException(s"Unsupported operator: ${call.getOperator}")
      }
      val operands = call.getOperands
      rexBuilder.makeCall(newOp, operands.last, operands.head).asInstanceOf[RexCall]
    }
  }

  /**
    * Returns whether a given expression has dynamic function.
    *
    * @param e Expression
    * @return true if tree has dynamic function, false otherwise
    */
  def hasDynamicFunction(e: RexNode): Boolean = try {
    val visitor = new RexVisitorImpl[Void](true) {
      override def visitCall(call: RexCall): Void = {
        if (call.getOperator.isDynamicFunction) {
          // abort the traversal as soon as one dynamic function is found
          throw Util.FoundOne.NULL
        }
        super.visitCall(call)
      }
    }
    e.accept(visitor)
    false
  } catch {
    case ex: Util.FoundOne =>
      Util.swallow(ex, null)
      true
  }

  /**
    * Returns true if the given RexNode is null, or contains neither a
    * non-deterministic `SqlOperator` nor a dynamic function `SqlOperator`.
    */
  def isDeterministicOperator(rex: RexNode): Boolean = {
    if (rex == null) {
      true
    } else {
      RexUtil.isDeterministic(rex) && !FlinkRexUtil.hasDynamicFunction(rex)
    }
  }

  /**
    * Returns true if the given operator is null, or is deterministic and not a dynamic function.
    */
  def isDeterministicOperator(op: SqlOperator): Boolean = {
    if (op == null) {
      true
    } else {
      op.isDeterministic && !op.isDynamicFunction
    }
  }
}