/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.table.plan.util

import org.apache.calcite.plan.{RelOptPredicateList, RelOptUtil}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rex._
import org.apache.calcite.sql.fun.SqlStdOperatorTable
import org.apache.calcite.sql.fun.SqlStdOperatorTable._
import org.apache.calcite.sql.{SqlKind, SqlOperator}
import org.apache.calcite.util.{ControlFlowException, Util}

import com.google.common.base.Function
import com.google.common.collect.{ImmutableList, Lists}

import java.lang.Iterable
import java.util

import scala.collection.JavaConversions._
import scala.collection.mutable

/**
  * Utility methods concerning [[RexNode]].
  */
object FlinkRexUtil {

  /**
    * Converts `rex` to conjunctive normal form (CNF) in the spirit of
    * [[RexUtil#toCnf(RexBuilder, Int, RexNode)]], while bounding the number of RexCalls
    * the conversion is allowed to produce.
    *
    * <p>A negative `maxCnfNodeCount` selects a default threshold: twice the number of
    * RexCalls contained in `rex`.
    *
    * <p>If the number of RexCalls in the converted expression exceeds the threshold, the
    * conversion is abandoned and the original expression is returned. Only RexCalls are
    * counted; leaf nodes (e.g. RexInputRef) do not count towards the threshold.
    *
    * <p>Using [[RexUtil#toCnf(RexBuilder, RexNode)]] or
    * [[RexUtil#toCnf(RexBuilder, Int, RexNode)]] directly is strongly discouraged: the
    * unbounded variant has bad cases (e.g. the predicate in TPC-DS q41.sql is converted to an
    * extremely complex expression containing 736450 RexCalls), and there is no generally
    * appropriate value of `maxCnfNodeCount` to pass to the bounded variant.
    *
    * @param rexBuilder builder used to create the new expressions
    * @param maxCnfNodeCount maximum number of RexCalls allowed in the result; a negative value
    *   requests the default threshold described above
    * @param rex the expression to convert
    * @return the CNF form of `rex`, or `rex` itself if the threshold was exceeded
    */
  def toCnf(rexBuilder: RexBuilder, maxCnfNodeCount: Int, rex: RexNode): RexNode = {
    val threshold = maxCnfNodeCount match {
      // negative limit: derive the default from the size of the input expression
      case limit if limit < 0 => getNumberOfRexCall(rex) * 2
      case limit => limit
    }
    val helper = new CnfHelper(rexBuilder, threshold)
    helper.toCnf(rex)
  }

  /**
    * Counts the RexCall nodes contained in the given expression tree.
    *
    * The tree is walked with a deep [[RexVisitorImpl]]; every visited call (including nested
    * ones) increases the count by one, other node kinds are not counted.
    *
    * @param rex the expression to inspect
    * @return the number of RexCalls found in `rex`
    */
  private def getNumberOfRexCall(rex: RexNode): Int = {
    var callCount = 0
    val callCounter = new RexVisitorImpl[Unit](true) {
      override def visitCall(call: RexCall): Unit = {
        callCount += 1
        // keep descending so that nested calls are counted as well
        super.visitCall(call)
      }
    }
    rex.accept(callCounter)
    callCount
  }

  /**
    * Helps [[toCnf]].
    *
    * Converts an expression to conjunctive normal form and gives up as soon as an
    * intermediate AND/OR result contains more RexCalls than `maxNodeCount`.
    *
    * @param rexBuilder builder used to create the new calls
    * @param maxNodeCount maximum number of RexCalls tolerated in an intermediate result;
    *   the limit is only enforced when this value is non-negative
    */
  private class CnfHelper(rexBuilder: RexBuilder, maxNodeCount: Int) {

    /** Exception to catch when we pass the limit. */
    @SuppressWarnings(Array("serial"))
    private class OverflowError extends ControlFlowException {
    }

    // Single pre-allocated instance; it is thrown by checkCnfRexCallCount and caught in toCnf.
    @SuppressWarnings(Array("ThrowableInstanceNeverThrown"))
    private val INSTANCE = new OverflowError

    // Wraps a node into a NOT call that has the same type as the wrapped node.
    private val ADD_NOT = new Function[RexNode, RexNode]() {
      override def apply(input: RexNode): RexNode =
        rexBuilder.makeCall(input.getType, SqlStdOperatorTable.NOT, ImmutableList.of(input))
    }

    /**
      * Converts `rex` to CNF.
      *
      * @return the converted expression, or the unchanged `rex` if the conversion was aborted
      *         because the RexCall limit was exceeded
      */
    def toCnf(rex: RexNode): RexNode = try {
      toCnf2(rex)
    } catch {
      case e: OverflowError =>
        // the limit was exceeded: drop the exception and fall back to the original expression
        Util.swallow(e, null)
        rex
    }

    /**
      * Recursive worker of [[toCnf]]. Handles AND, OR and NOT; any other kind of node is
      * returned unchanged. Throws the overflow exception when a composed AND/OR result is
      * larger than the limit.
      */
    private def toCnf2(rex: RexNode): RexNode = {
      rex.getKind match {
        case SqlKind.AND =>
          // convert every operand and inline the operands of nested ANDs
          val cnfOperands: util.List[RexNode] = Lists.newArrayList()
          val operands = RexUtil.flattenAnd(rex.asInstanceOf[RexCall].operands)
          operands.foreach { node =>
            val cnf = toCnf2(node)
            cnf.getKind match {
              case SqlKind.AND =>
                cnfOperands.addAll(cnf.asInstanceOf[RexCall].operands)
              case _ =>
                cnfOperands.add(cnf)
            }
          }
          val node = and(cnfOperands)
          checkCnfRexCallCount(node)
          node
        case SqlKind.OR =>
          // split into head and the OR of the remaining operands, convert both sides, then
          // distribute: (h1 AND h2 ...) OR (t1 AND t2 ...) => AND of every (h OR t) pair
          val operands = RexUtil.flattenOr(rex.asInstanceOf[RexCall].operands)
          val head = operands.head
          val headCnf = toCnf2(head)
          val headCnfs: util.List[RexNode] = RelOptUtil.conjunctions(headCnf)
          val tail = or(Util.skip(operands))
          val tailCnf: RexNode = toCnf2(tail)
          val tailCnfs: util.List[RexNode] = RelOptUtil.conjunctions(tailCnf)
          val list: util.List[RexNode] = Lists.newArrayList()
          headCnfs.foreach { h =>
            tailCnfs.foreach {
              t => list.add(or(ImmutableList.of(h, t)))
            }
          }
          val node = and(list)
          checkCnfRexCallCount(node)
          node
        case SqlKind.NOT =>
          val arg = rex.asInstanceOf[RexCall].operands.head
          arg.getKind match {
            case SqlKind.NOT =>
              // NOT(NOT(x)) => x
              toCnf2(arg.asInstanceOf[RexCall].operands.head)
            case SqlKind.OR =>
              // NOT(a OR b) => NOT(a) AND NOT(b)
              val operands = arg.asInstanceOf[RexCall].operands
              toCnf2(and(Lists.transform(RexUtil.flattenOr(operands), ADD_NOT)))
            case SqlKind.AND =>
              // NOT(a AND b) => NOT(a) OR NOT(b)
              val operands = arg.asInstanceOf[RexCall].operands
              toCnf2(or(Lists.transform(RexUtil.flattenAnd(operands), ADD_NOT)))
            case _ => rex
          }
        case _ => rex
      }
    }

    /**
      * Aborts the conversion (by throwing the shared overflow exception) when the limit is
      * non-negative and `node` contains more RexCalls than the limit.
      */
    private def checkCnfRexCallCount(node: RexNode): Unit = {
      // TODO use more efficient solution to get number of RexCall in CNF node
      if (maxNodeCount >= 0 && getNumberOfRexCall(node) > maxNodeCount) {
        throw INSTANCE
      }
    }

    // Composes the given nodes into a conjunction via RexUtil.
    // NOTE(review): the `false` argument is presumably Calcite's `nullOnEmpty` flag - confirm
    // against the Calcite version in use.
    private def and(nodes: Iterable[_ <: RexNode]): RexNode =
      RexUtil.composeConjunction(rexBuilder, nodes, false)

    // Composes the given nodes into a disjunction via RexUtil.
    private def or(nodes: Iterable[_ <: RexNode]): RexNode =
      RexUtil.composeDisjunction(rexBuilder, nodes)
  }

  /**
    * Merges same expressions and then simplifies the result expression by [[RexSimplify]].
    *
    * Expressions that are already always-true or always-false are returned unchanged.
    * Otherwise the expression goes through the following stages, in this order:
    * equivalent comparisons are rewritten to one common form (`EquivalentExprShuttle`),
    * repeated expressions are merged (`SameExprMerger`), comparisons of an input reference
    * with itself are folded (`BinaryComparisonExprReducer`), and finally [[RexSimplify]]
    * is applied.
    *
    * Examples of the intended rewrites (the exact conditions are defined by the helper
    * classes named above):
    * 1. a = b AND b = a -> a = b
    * 2. a = b OR b = a -> a = b
    * 3. (a > b AND c < 10) AND b < a -> a > b AND c < 10
    * 4. (a > b OR c < 10) OR b < a -> a > b OR c < 10
    * 5. a = a, a >= a, a <= a -> true
    * 6. a <> a, a > a, a < a -> false
    *
    * @param rexBuilder builder used to create the new expressions
    * @param expr the expression to simplify
    * @return the simplified expression
    */
  def simplify(rexBuilder: RexBuilder, expr: RexNode): RexNode = {
    if (expr.isAlwaysTrue || expr.isAlwaysFalse) {
      // nothing to simplify
      expr
    } else {
      val normalized = expr.accept(new EquivalentExprShuttle(rexBuilder))
      val deduplicated = new SameExprMerger(rexBuilder).mergeSameExpr(normalized)
      val selfComparisonsFolded = deduplicated.accept(new BinaryComparisonExprReducer(rexBuilder))

      val simplifier =
        new RexSimplify(rexBuilder, RelOptPredicateList.EMPTY, true, RexUtil.EXECUTOR)
      simplifier.simplify(selfComparisonsFolded)
    }
  }

  /**
    * Folds a binary comparison of an input reference with itself into a boolean literal:
    * `a = a`, `a >= a`, `a <= a` become `true`; `a <> a`, `a > a`, `a < a` become `false`.
    *
    * The folding is only applied when the referenced field is NOT nullable. For a nullable
    * field the comparison evaluates to UNKNOWN (not TRUE/FALSE) whenever the field is NULL,
    * so replacing it with a literal would change the result of the expression; such calls are
    * left untouched.
    *
    * Calls that are not binary comparisons are traversed as usual, so comparisons nested in
    * their operands are still reduced.
    */
  private class BinaryComparisonExprReducer(rexBuilder: RexBuilder) extends RexShuttle {
    override def visitCall(call: RexCall): RexNode = {
      val kind = call.getOperator.getKind
      if (!kind.belongsTo(SqlKind.BINARY_COMPARISON)) {
        super.visitCall(call)
      } else {
        val operand0 = call.getOperands.get(0)
        val operand1 = call.getOperands.get(1)
        (operand0, operand1) match {
          // Both sides reference the same, non-nullable input field. The nullability check
          // keeps SQL three-valued logic intact: NULL = NULL is UNKNOWN, not TRUE.
          case (op0: RexInputRef, op1: RexInputRef)
              if op0.getIndex == op1.getIndex && !op0.getType.isNullable =>
            kind match {
              case SqlKind.EQUALS | SqlKind.LESS_THAN_OR_EQUAL | SqlKind.GREATER_THAN_OR_EQUAL =>
                rexBuilder.makeLiteral(true)
              case SqlKind.NOT_EQUALS | SqlKind.LESS_THAN | SqlKind.GREATER_THAN =>
                rexBuilder.makeLiteral(false)
              // other binary comparison kinds are not folded here
              case _ => super.visitCall(call)
            }
          case _ => super.visitCall(call)
        }
      }
    }
  }

  /**
    * Replaces repeated expressions by a neutral boolean literal: `true` for a repeated
    * operand/conjunct of an AND, `false` for a repeated operand/disjunct of an OR.
    * Two expressions are considered the same when their `toString` representations are equal.
    */
  private class SameExprMerger(rexBuilder: RexBuilder) extends RexShuttle {
    // expressions seen so far in the scope currently being merged, keyed by their string form
    private val sameExprMap = mutable.HashMap[String, RexNode]()

    /**
      * Returns `equiExpr` if an expression with the same string form as `expr` has already
      * been registered; otherwise registers `expr` and returns it unchanged.
      */
    private def mergeSameExpr(expr: RexNode, equiExpr: RexLiteral): RexNode = {
      val digest = expr.toString
      if (!sameExprMap.contains(digest)) {
        sameExprMap.put(digest, expr)
        expr
      } else {
        equiExpr
      }
    }

    /**
      * Merges same expressions in three passes: among the direct operands of every AND/OR
      * call, among the conjunctions of the result, and among the disjunctions of that result.
      *
      * @param expr the expression to process
      * @return the expression with repeated sub-expressions replaced by boolean literals
      */
    def mergeSameExpr(expr: RexNode): RexNode = {
      // pass 1: same expressions in the operands of AND and OR
      // e.g. a = b AND a = b -> a = b AND true
      //      a = b OR a = b -> a = b OR false
      val operandsMerged = expr.accept(this)

      // pass 2: same expressions in conjunctions
      // e.g. (a > b AND c < 10) AND a > b -> a > b AND c < 10 AND true
      sameExprMap.clear()
      val conjuncts = RelOptUtil.conjunctions(operandsMerged).map { conjunct =>
        mergeSameExpr(conjunct, rexBuilder.makeLiteral(true))
      }
      val conjunctsMerged =
        if (conjuncts.size == 0) {
          operandsMerged // true AND true
        } else if (conjuncts.size == 1) {
          conjuncts.head
        } else {
          rexBuilder.makeCall(AND, conjuncts: _*)
        }

      // pass 3: same expressions in disjunctions
      // e.g. (a > b OR c < 10) OR a > b -> a > b OR c < 10 OR false
      sameExprMap.clear()
      val disjuncts = RelOptUtil.disjunctions(conjunctsMerged).map { disjunct =>
        mergeSameExpr(disjunct, rexBuilder.makeLiteral(false))
      }
      if (disjuncts.size == 0) {
        conjunctsMerged // false OR false
      } else if (disjuncts.size == 1) {
        disjuncts.head
      } else {
        rexBuilder.makeCall(OR, disjuncts: _*)
      }
    }

    /**
      * For AND/OR calls, replaces repeated direct operands by `true` (AND) or `false` (OR)
      * before continuing the traversal; other calls are traversed unchanged.
      */
    override def visitCall(call: RexCall): RexNode = {
      val operandsMerged = call.getOperator match {
        case AND | OR =>
          // every AND/OR call starts with an empty set of seen operands
          sameExprMap.clear()
          val mergedOperands = call.getOperands.map { operand =>
            val neutralValue = call.getOperator == AND
            mergeSameExpr(operand, rexBuilder.makeLiteral(neutralValue))
          }
          call.clone(call.getType, mergedOperands)
        case _ => call
      }
      super.visitCall(operandsMerged)
    }
  }

  /**
    * Adjusts the field indices of all input references in the condition according to
    * `fieldsOldToNewIndexMapping`.
    *
    * @param c The condition to be adjusted.
    * @param fieldsOldToNewIndexMapping A map from the old field indices to the new field
    *   indices. It must contain an entry for every input reference used in `c`.
    * @param rowType The row type of the new output.
    * @return A new condition with the new field indices.
    */
  private[flink] def adjustInputRefs(
      c: RexNode,
      fieldsOldToNewIndexMapping: Map[Int, Int],
      rowType: RelDataType): RexNode = c.accept(
    new RexShuttle() {

      override def visitInputRef(inputRef: RexInputRef): RexNode = {
        val oldIndex = inputRef.getIndex
        // `contains` is the native Scala Map lookup; no implicit Java conversion is needed
        assert(
          fieldsOldToNewIndexMapping.contains(oldIndex),
          s"No new index is mapped for input field $oldIndex")
        val newIndex = fieldsOldToNewIndexMapping(oldIndex)
        val ref = RexInputRef.of(newIndex, rowType)
        if (ref.getIndex == oldIndex && (ref.getType eq inputRef.getType)) {
          // re-use old object, to prevent needless expr cloning
          inputRef
        } else {
          ref
        }
      }
    })

  /**
    * Rewrites equivalent comparisons to one common form, e.g. after `a > b` has been visited,
    * a later `b < a` is replaced by `a > b`.
    *
    * Only calls using `=`, `<>`, `>`, `<`, `>=` or `<=` are handled here, and their operands
    * are not visited any further; all other calls are traversed as usual.
    */
  private class EquivalentExprShuttle(rexBuilder: RexBuilder) extends RexShuttle {
    // comparisons visited so far, keyed by their string form
    private val equiExprMap = mutable.HashMap[String, RexNode]()

    override def visitCall(call: RexCall): RexNode = call.getOperator match {
      case EQUALS | NOT_EQUALS | GREATER_THAN | LESS_THAN |
           GREATER_THAN_OR_EQUAL | LESS_THAN_OR_EQUAL =>
        val mirrored = swapOperands(call)
        if (!equiExprMap.contains(mirrored.toString)) {
          // first occurrence of this comparison (in either direction): remember it as is
          equiExprMap.put(call.toString, call)
          call
        } else {
          // the mirrored form has been seen before: use that form instead
          mirrored
        }
      case _ => super.visitCall(call)
    }

    /**
      * Builds the comparison with swapped operands and the correspondingly mirrored operator
      * (e.g. `a > b` becomes `b < a`; `=` and `<>` keep their operator).
      *
      * @throws IllegalArgumentException if the operator is not one of the six comparisons
      */
    private def swapOperands(call: RexCall): RexCall = {
      val mirroredOperator = call.getOperator match {
        case LESS_THAN => GREATER_THAN
        case LESS_THAN_OR_EQUAL => GREATER_THAN_OR_EQUAL
        case GREATER_THAN => LESS_THAN
        case GREATER_THAN_OR_EQUAL => LESS_THAN_OR_EQUAL
        case EQUALS | NOT_EQUALS => call.getOperator
        case _ => throw new IllegalArgumentException(s"Unsupported operator: ${call.getOperator}")
      }
      val operands = call.getOperands
      val first = operands.head
      val second = operands.last
      rexBuilder.makeCall(mirroredOperator, second, first).asInstanceOf[RexCall]
    }
  }

  /**
    * Returns whether the given expression tree contains a call to a dynamic function, i.e. a
    * call whose operator reports `isDynamicFunction`.
    *
    * The search stops at the first match: the visitor signals it by throwing
    * `Util.FoundOne`, which is caught and translated into the result here.
    *
    * @param e Expression
    * @return true if the tree has a dynamic function, false otherwise
    */
  def hasDynamicFunction(e: RexNode): Boolean = {
    val dynamicFunctionFinder = new RexVisitorImpl[Void](true) {
      override def visitCall(call: RexCall): Void = {
        if (!call.getOperator.isDynamicFunction) {
          // not a match: keep searching in the operands
          super.visitCall(call)
        } else {
          throw Util.FoundOne.NULL
        }
      }
    }

    try {
      e.accept(dynamicFunctionFinder)
      false
    } catch {
      case found: Util.FoundOne =>
        Util.swallow(found, null)
        true
    }
  }

  /**
    * Returns true if the given RexNode is null, or if it is deterministic according to
    * [[RexUtil#isDeterministic]] and does not contain a dynamic function
    * (see [[hasDynamicFunction]]).
    *
    * @param rex the expression to check, may be null
    */
  def isDeterministicOperator(rex: RexNode): Boolean =
    rex == null || (RexUtil.isDeterministic(rex) && !FlinkRexUtil.hasDynamicFunction(rex))

  /**
    * Returns true if the given operator is null, or if it is deterministic and is not a
    * dynamic function.
    *
    * @param op the operator to check, may be null
    */
  def isDeterministicOperator(op: SqlOperator): Boolean =
    op == null || (op.isDeterministic && !op.isDynamicFunction)
}
