blob: 0ffd972e5f1f504aac86694e19233380603a1a37 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.expressions
import org.apache.calcite.rex.RexNode
import org.apache.calcite.sql.SqlOperator
import org.apache.calcite.sql.fun.SqlStdOperatorTable
import org.apache.calcite.tools.RelBuilder
import org.apache.flink.table.functions.sql.ScalarSqlFunctions
import org.apache.flink.table.plan.logical.LogicalExprVisitor
import org.apache.flink.table.api.types.{DataTypes, DecimalType, InternalType}
import org.apache.flink.table.typeutils.TypeCheckUtils._
import org.apache.flink.table.typeutils.TypeCoercion
import org.apache.flink.table.validate._
import scala.collection.JavaConversions._
/**
 * Base class for binary arithmetic expressions that are translated into a call of a single
 * Calcite [[SqlOperator]] on their two children.
 *
 * By default both operands must be numeric, and the result type is the wider of the two
 * operand types as computed by [[TypeCoercion.widerTypeOf]]. Subclasses may relax the
 * validation (e.g. to admit string or temporal operands) and/or override the translation.
 */
abstract class BinaryArithmetic extends BinaryExpression {

  /** The Calcite operator used by the default [[toRexNode]] translation. */
  private[flink] def sqlOperator: SqlOperator

  /** Translates this expression into a call of [[sqlOperator]] on the translated children. */
  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode =
    // Pass the operands as Java varargs rather than relying on the deprecated implicit
    // Seq -> java.lang.Iterable conversion provided by JavaConversions.
    relBuilder.call(sqlOperator, children.map(_.toRexNode): _*)

  /**
   * The wider of the two operand types.
   *
   * @throws RuntimeException if the operand types have no common wider type; validation is
   *                          presumably expected to reject such operand combinations first.
   */
  override private[flink] def resultType: InternalType =
    TypeCoercion.widerTypeOf(left.resultType, right.resultType).getOrElse(
      throw new RuntimeException(
        s"$this: no common result type for operand types " +
          s"${left.resultType} and ${right.resultType}. This should never happen."))

  /** Accepts the expression only when both operands have a numeric result type. */
  override private[flink] def validateInput(): ValidationResult =
    if (isNumeric(left.resultType) && isNumeric(right.resultType)) {
      ValidationSuccess
    } else {
      ValidationFailure(s"$this requires both operands to be numeric, but was " +
        s"$left : ${left.resultType} and $right : ${right.resultType}")
    }
}
/**
 * The `+` operator. Beyond numeric addition it also covers:
 *  - string concatenation when either operand is a string (the other side is cast to STRING);
 *  - addition of two intervals of the same type;
 *  - adding an interval to a time point, with the interval on either side.
 */
case class Plus(left: Expression, right: Expression) extends BinaryArithmetic {

  private[flink] val sqlOperator = SqlStdOperatorTable.PLUS

  override def toString = s"($left + $right)"

  /** Chooses CONCAT, PLUS or DATETIME_PLUS from the operand types, else the default. */
  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode =
    (left.resultType, right.resultType) match {
      case (lhsType, _) if isString(lhsType) =>
        relBuilder.call(
          SqlStdOperatorTable.CONCAT,
          left.toRexNode,
          Cast(right, DataTypes.STRING).toRexNode)
      case (_, rhsType) if isString(rhsType) =>
        relBuilder.call(
          SqlStdOperatorTable.CONCAT,
          Cast(left, DataTypes.STRING).toRexNode,
          right.toRexNode)
      case (lhsType, rhsType) if isTimeInterval(lhsType) && lhsType == rhsType =>
        relBuilder.call(SqlStdOperatorTable.PLUS, left.toRexNode, right.toRexNode)
      case (lhsType, rhsType) if isTimeInterval(lhsType) && isTemporal(rhsType) =>
        // Calcite cannot apply INTERVAL + DATETIME when the interval is on the left,
        // so the operands are swapped before emitting DATETIME_PLUS.
        relBuilder.call(SqlStdOperatorTable.DATETIME_PLUS, right.toRexNode, left.toRexNode)
      case (lhsType, rhsType) if isTemporal(lhsType) && isTemporal(rhsType) =>
        relBuilder.call(SqlStdOperatorTable.DATETIME_PLUS, left.toRexNode, right.toRexNode)
      case _ =>
        super.toRexNode
    }

  /** Accepts string, same-typed interval, time point/interval, or numeric operand pairs. */
  override private[flink] def validateInput(): ValidationResult = {
    val lhsType = left.resultType
    val rhsType = right.resultType
    val accepted =
      isString(lhsType) || isString(rhsType) ||
        (isTimeInterval(lhsType) && lhsType == rhsType) ||
        (isTimePoint(lhsType) && isTimeInterval(rhsType)) ||
        (isTimeInterval(lhsType) && isTimePoint(rhsType)) ||
        (isNumeric(lhsType) && isNumeric(rhsType))
    if (accepted) {
      ValidationSuccess
    } else {
      ValidationFailure(
        s"$this requires Numeric, String, Intervals of same type, " +
          s"or Interval and a time point input, " +
          s"get $left : ${left.resultType} and $right : ${right.resultType}")
    }
  }

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T = {
    logicalExprVisitor.visit(this)
  }
}
/** Arithmetic negation of a numeric or interval operand; the result keeps the operand's type. */
case class UnaryMinus(child: Expression) extends UnaryExpression {

  override def toString = s"-($child)"

  override private[flink] def resultType: InternalType = child.resultType

  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode =
    relBuilder.call(SqlStdOperatorTable.UNARY_MINUS, child.toRexNode)

  /** Only numeric values and intervals can be negated. */
  override private[flink] def validateInput(): ValidationResult = {
    val operandType = child.resultType
    if (isNumeric(operandType) || isTimeInterval(operandType)) {
      ValidationSuccess
    } else {
      ValidationFailure(s"$this requires Numeric, or Interval input, get ${child.resultType}")
    }
  }

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T = {
    logicalExprVisitor.visit(this)
  }
}
/**
 * The `-` operator: numeric subtraction, subtraction of two intervals of the same type, or
 * subtraction involving a time point and an interval.
 *
 * This class does not override `toRexNode`, so every accepted combination goes through the
 * inherited BinaryArithmetic translation, which calls [[sqlOperator]] on the children.
 */
case class Minus(left: Expression, right: Expression) extends BinaryArithmetic {

  private[flink] val sqlOperator = SqlStdOperatorTable.MINUS

  override def toString = s"($left - $right)"

  /** Admits the temporal combinations below; anything else falls back to the numeric check. */
  override private[flink] def validateInput(): ValidationResult = {
    val lhsType = left.resultType
    val rhsType = right.resultType
    val sameIntervals = isTimeInterval(lhsType) && lhsType == rhsType
    // NOTE(review): INTERVAL - <time point> is accepted as well; unlike Plus, the operands
    // are not swapped for subtraction — confirm the translation handles this order.
    val pointAndInterval =
      (isTimePoint(lhsType) && isTimeInterval(rhsType)) ||
        (isTimeInterval(lhsType) && isTimePoint(rhsType))
    if (sameIntervals || pointAndInterval) ValidationSuccess else super.validateInput()
  }

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T = {
    logicalExprVisitor.visit(this)
  }
}
/**
 * The `/` operator, translated via Flink's `ScalarSqlFunctions.DIVIDE`. Operand validation is
 * the inherited numeric check.
 */
case class Div(left: Expression, right: Expression) extends BinaryArithmetic {

  private[flink] val sqlOperator = ScalarSqlFunctions.DIVIDE

  override def toString = s"($left / $right)"

  /**
   * Keeps the widened operand type when it is a decimal; every other widened type
   * (e.g. two integers) yields DOUBLE.
   */
  override private[flink] def resultType: InternalType =
    Option(super.resultType)
      .collect { case decimal: DecimalType => decimal }
      .getOrElse(DataTypes.DOUBLE)

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T = {
    logicalExprVisitor.visit(this)
  }
}
/**
 * The `*` operator, translated via Calcite's MULTIPLY; it relies on the inherited numeric
 * validation and widened result type.
 */
case class Mul(left: Expression, right: Expression) extends BinaryArithmetic {

  private[flink] val sqlOperator = SqlStdOperatorTable.MULTIPLY

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T = {
    logicalExprVisitor.visit(this)
  }

  override def toString = s"($left * $right)"
}
/**
 * The `%` operator, translated via Calcite's MOD; it relies on the inherited numeric
 * validation and widened result type.
 */
case class Mod(left: Expression, right: Expression) extends BinaryArithmetic {

  private[flink] val sqlOperator = SqlStdOperatorTable.MOD

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T = {
    logicalExprVisitor.visit(this)
  }

  override def toString = s"($left % $right)"
}