blob: 95ca6529588d04ed3dc2e8489e0c0a6011705be3 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.expressions
import org.apache.calcite.avatica.util.TimeUnit
import org.apache.calcite.rex._
import org.apache.calcite.sql.fun.SqlStdOperatorTable
import org.apache.calcite.tools.RelBuilder
import org.apache.flink.table.api.types.{DataTypes, InternalType, TimestampType}
import org.apache.flink.table.calcite.FlinkRelBuilder
import org.apache.flink.table.expressions.TimeIntervalUnit.TimeIntervalUnit
import org.apache.flink.table.functions.sql.ScalarSqlFunctions
import org.apache.flink.table.plan.logical.LogicalExprVisitor
import org.apache.flink.table.typeutils.TypeCheckUtils
import org.apache.flink.table.typeutils.TypeCheckUtils.isTimeInterval
import org.apache.flink.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess}
import scala.collection.JavaConversions._
/**
 * Extracts a single field, named by a time interval unit symbol, from a temporal value.
 *
 * @param timeIntervalUnit symbol expression naming the field to extract
 * @param temporal         the temporal value to extract the field from
 */
case class Extract(timeIntervalUnit: Expression, temporal: Expression) extends Expression {

  override private[flink] def children: Seq[Expression] = List(timeIntervalUnit, temporal)

  override private[flink] def resultType: InternalType = DataTypes.LONG

  /**
   * Accepts the units YEAR, QUARTER, MONTH, WEEK and DAY on DATE, TIMESTAMP, INTERVAL_MILLIS and
   * INTERVAL_MONTHS inputs, and the units HOUR, MINUTE and SECOND on TIME, TIMESTAMP and
   * INTERVAL_MILLIS inputs. Every other combination is rejected.
   */
  override private[flink] def validateInput(): ValidationResult = {
    val inputType = temporal.resultType
    if (!TypeCheckUtils.isTemporal(inputType)) {
      ValidationFailure(s"Extract operator requires Temporal input, " +
        s"but $temporal is of type $inputType")
    } else {
      // Input types are compared with `==` against the DataTypes constants.
      val supportsDateUnits = Seq(
          DataTypes.DATE,
          DataTypes.TIMESTAMP,
          DataTypes.INTERVAL_MILLIS,
          DataTypes.INTERVAL_MONTHS)
        .exists(inputType == _)
      val supportsTimeUnits =
        Seq(DataTypes.TIME, DataTypes.TIMESTAMP, DataTypes.INTERVAL_MILLIS).exists(inputType == _)

      timeIntervalUnit match {
        case SymbolExpression(
            TimeIntervalUnit.YEAR | TimeIntervalUnit.QUARTER | TimeIntervalUnit.MONTH |
            TimeIntervalUnit.WEEK | TimeIntervalUnit.DAY) if supportsDateUnits =>
          ValidationSuccess
        case SymbolExpression(
            TimeIntervalUnit.HOUR | TimeIntervalUnit.MINUTE | TimeIntervalUnit.SECOND)
            if supportsTimeUnits =>
          ValidationSuccess
        case _ =>
          ValidationFailure(s"Extract operator does not support unit '$timeIntervalUnit' for input" +
            s" of type '$inputType'.")
      }
    }
  }

  override def toString: String = s"($temporal).extract($timeIntervalUnit)"

  /** Translates to Calcite's EXTRACT call with the unit as first and the value as second operand. */
  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
    val operands = Seq(timeIntervalUnit.toRexNode, temporal.toRexNode)
    relBuilder.getRexBuilder.makeCall(SqlStdOperatorTable.EXTRACT, operands)
  }

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}
/**
 * Shared base of [[TemporalFloor]] and [[TemporalCeil]]. The result has the same type as the
 * `temporal` operand.
 *
 * @param timeIntervalUnit symbol expression naming the unit to round to
 * @param temporal         the time point value to round
 */
abstract class TemporalCeilFloor(
    timeIntervalUnit: Expression,
    temporal: Expression)
  extends Expression {

  override private[flink] def children: Seq[Expression] = List(timeIntervalUnit, temporal)

  override private[flink] def resultType: InternalType = temporal.resultType

  /**
   * Requires a time point input and a [[TimeIntervalUnit]] symbol. Accepted combinations are
   * YEAR/MONTH on DATE or TIMESTAMP, DAY on TIMESTAMP, and HOUR/MINUTE/SECOND on TIME or
   * TIMESTAMP.
   */
  override private[flink] def validateInput(): ValidationResult = {
    val inputType = temporal.resultType
    if (!TypeCheckUtils.isTimePoint(inputType)) {
      ValidationFailure(s"Temporal ceil/floor operator requires Time Point input, " +
        s"but $temporal is of type $inputType")
    } else {
      timeIntervalUnit match {
        case SymbolExpression(unit: TimeIntervalUnit) =>
          (unit, inputType) match {
            case (TimeIntervalUnit.YEAR | TimeIntervalUnit.MONTH,
                DataTypes.DATE | _: TimestampType) =>
              ValidationSuccess
            case (TimeIntervalUnit.DAY, _: TimestampType) =>
              ValidationSuccess
            case (TimeIntervalUnit.HOUR | TimeIntervalUnit.MINUTE | TimeIntervalUnit.SECOND,
                DataTypes.TIME | _: TimestampType) =>
              ValidationSuccess
            case _ =>
              ValidationFailure(s"Temporal ceil/floor operator does not support " +
                s"unit '$timeIntervalUnit' for input of type '$inputType'.")
          }
        case _ =>
          // Not a symbol, or a symbol that is not a TimeIntervalUnit.
          ValidationFailure(s"Temporal ceil/floor operator requires Time Interval Unit " +
            s"input, but $timeIntervalUnit is of type ${timeIntervalUnit.resultType}")
      }
    }
  }
}
/**
 * Rounds a time point down to the given unit by translating to Calcite's FLOOR operator.
 *
 * @param timeIntervalUnit symbol expression naming the unit to round to
 * @param temporal         the time point value to round
 */
case class TemporalFloor(
    timeIntervalUnit: Expression,
    temporal: Expression)
  extends TemporalCeilFloor(timeIntervalUnit, temporal) {

  override def toString: String = s"($temporal).floor($timeIntervalUnit)"

  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
    // Operand order for FLOOR: the value first, then the unit symbol.
    val value = temporal.toRexNode
    val unit = timeIntervalUnit.toRexNode
    relBuilder.call(SqlStdOperatorTable.FLOOR, value, unit)
  }

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}
/**
 * Rounds a time point up to the given unit by translating to Calcite's CEIL operator.
 *
 * @param timeIntervalUnit symbol expression naming the unit to round to
 * @param temporal         the time point value to round
 */
case class TemporalCeil(
    timeIntervalUnit: Expression,
    temporal: Expression)
  extends TemporalCeilFloor(timeIntervalUnit, temporal) {

  override def toString: String = s"($temporal).ceil($timeIntervalUnit)"

  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
    // Operand order for CEIL: the value first, then the unit symbol.
    val value = temporal.toRexNode
    val unit = timeIntervalUnit.toRexNode
    relBuilder.call(SqlStdOperatorTable.CEIL, value, unit)
  }

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}
/**
 * Base of the zero-argument "current time point" expressions.
 *
 * @param targetType the produced type; validation requires a time point
 * @param local      selects LOCALTIME / LOCALTIMESTAMP instead of CURRENT_TIME /
 *                   CURRENT_TIMESTAMP; rejected by validation for DATE
 */
abstract class CurrentTimePoint(
    targetType: InternalType,
    local: Boolean)
  extends LeafExpression {

  override private[flink] def resultType: InternalType = targetType

  override private[flink] def validateInput(): ValidationResult = targetType match {
    case nonTimePoint if !TypeCheckUtils.isTimePoint(nonTimePoint) =>
      ValidationFailure(s"CurrentTimePoint operator requires Time Point target type, " +
        s"but get $targetType.")
    case date if local && date == DataTypes.DATE =>
      ValidationFailure(s"Localized CurrentTimePoint operator requires Time or Timestamp target " +
        s"type, but get $targetType.")
    case _ =>
      ValidationSuccess
  }

  override def toString: String = {
    val prefix = if (local) "local" else "current"
    s"$prefix$targetType()"
  }

  /**
   * Maps DATE to CURRENT_DATE, and TIME / TIMESTAMP to their LOCAL* or CURRENT_* function
   * depending on `local`. Any other target type is not handled (MatchError).
   */
  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
    val sqlFunction = targetType match {
      case DataTypes.DATE =>
        SqlStdOperatorTable.CURRENT_DATE
      case DataTypes.TIME =>
        if (local) SqlStdOperatorTable.LOCALTIME else SqlStdOperatorTable.CURRENT_TIME
      case DataTypes.TIMESTAMP =>
        if (local) SqlStdOperatorTable.LOCALTIMESTAMP else SqlStdOperatorTable.CURRENT_TIMESTAMP
    }
    relBuilder.call(sqlFunction)
  }
}
/** The current date as a DATE value (non-localized, translated to CURRENT_DATE). */
case class CurrentDate() extends CurrentTimePoint(DataTypes.DATE, local = false) {

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T = {
    logicalExprVisitor.visit(this)
  }
}
/** The current time as a TIME value (non-localized, translated to CURRENT_TIME). */
case class CurrentTime() extends CurrentTimePoint(DataTypes.TIME, local = false) {

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T = {
    logicalExprVisitor.visit(this)
  }
}
/** The current timestamp as a TIMESTAMP value (non-localized, translated to CURRENT_TIMESTAMP). */
case class CurrentTimestamp() extends CurrentTimePoint(DataTypes.TIMESTAMP, local = false) {

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T = {
    logicalExprVisitor.visit(this)
  }
}
/** The localized current time as a TIME value (translated to LOCALTIME). */
case class LocalTime() extends CurrentTimePoint(DataTypes.TIME, local = true) {

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T = {
    logicalExprVisitor.visit(this)
  }
}
/** The localized current timestamp as a TIMESTAMP value (translated to LOCALTIMESTAMP). */
case class LocalTimestamp() extends CurrentTimePoint(DataTypes.TIMESTAMP, local = true) {

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T = {
    logicalExprVisitor.visit(this)
  }
}
/**
 * Extracts the quarter of a year from a SQL date.
 *
 * @param child the DATE expression to read the quarter from
 */
case class Quarter(child: Expression) extends UnaryExpression with InputTypeSpec {

  override private[flink] def expectedTypes: Seq[InternalType] = Seq(DataTypes.DATE)

  override private[flink] def resultType: InternalType = DataTypes.LONG

  override def toString: String = s"($child).quarter()"

  /**
   * Rewrites QUARTER as `CAST((EXTRACT(MONTH) - 1) / QUARTER_MULTIPLIER AS INT) + 1`, following
   * Calcite's standard conversion
   * ([[org.apache.calcite.sql2rel.StandardConvertletTable#convertQuarter()]]).
   */
  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
    val zeroBasedMonth = Minus(Extract(TimeIntervalUnit.MONTH, child), Literal(1L))
    // NOTE(review): Avatica's QUARTER multiplier is presumably the number of months per
    // quarter (3) — confirm against org.apache.calcite.avatica.util.TimeUnit.
    val monthsPerQuarter = Literal(TimeUnit.QUARTER.multiplier.longValue())
    // NOTE(review): the INT cast presumably truncates the quotient to a whole number;
    // confirm against Div's result type in this planner.
    val zeroBasedQuarter = Cast(Div(zeroBasedMonth, monthsPerQuarter), DataTypes.INT)
    Plus(zeroBasedQuarter, Literal(1L)).toRexNode
  }

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}
/**
 * Determines whether two anchored time intervals overlap.
 *
 * Each interval is given by a time point plus a second operand, which is either another time
 * point of the same type (the other end) or a time interval that is added to the point.
 *
 * @param leftTimePoint  anchor point of the left interval
 * @param leftTemporal   other end point, or length, of the left interval
 * @param rightTimePoint anchor point of the right interval
 * @param rightTemporal  other end point, or length, of the right interval
 */
case class TemporalOverlaps(
    leftTimePoint: Expression,
    leftTemporal: Expression,
    rightTimePoint: Expression,
    rightTemporal: Expression)
  extends Expression {

  override private[flink] def children: Seq[Expression] =
    Seq(leftTimePoint, leftTemporal, rightTimePoint, rightTemporal)

  override private[flink] def resultType: InternalType = DataTypes.BOOLEAN

  override private[flink] def validateInput(): ValidationResult = {
    val leftPointType = leftTimePoint.resultType
    val rightPointType = rightTimePoint.resultType

    if (!TypeCheckUtils.isTimePoint(leftPointType)) {
      ValidationFailure(s"TemporalOverlaps operator requires leftTimePoint to be of type " +
        s"Time Point, but get $leftPointType.")
    } else if (!TypeCheckUtils.isTimePoint(rightPointType)) {
      ValidationFailure(s"TemporalOverlaps operator requires rightTimePoint to be of " +
        s"type Time Point, but get $rightPointType.")
    } else if (leftPointType != rightPointType) {
      ValidationFailure(s"TemporalOverlaps operator requires leftTimePoint and " +
        s"rightTimePoint to be of same type.")
    } else {
      // The right operand is only inspected when the left one is valid (orElse is lazy).
      validateTemporalOperand("left", leftTemporal, leftPointType)
        .orElse(validateTemporalOperand("right", rightTemporal, rightPointType))
        .getOrElse(ValidationSuccess)
    }
  }

  /**
   * Checks the second operand of one interval: if it is a time point it must have the same type
   * as the interval's anchor point, otherwise it must be a time interval.
   *
   * @param side      "left" or "right", used to build the error message
   * @param temporal  the operand to check
   * @param pointType type of the interval's anchor point
   * @return the failure to report, or `None` when the operand is valid
   */
  private def validateTemporalOperand(
      side: String,
      temporal: Expression,
      pointType: InternalType): Option[ValidationResult] = {
    val temporalType = temporal.resultType
    if (TypeCheckUtils.isTimePoint(temporalType)) {
      if (temporalType != pointType) {
        Some(ValidationFailure(s"TemporalOverlaps operator requires ${side}Temporal and " +
          s"${side}TimePoint to be of same type if ${side}Temporal is of type Time Point."))
      } else {
        None
      }
    } else if (!isTimeInterval(temporalType)) {
      Some(ValidationFailure(s"TemporalOverlaps operator requires ${side}Temporal to be of " +
        s"type Time Point or Time Interval."))
    } else {
      None
    }
  }

  override def toString: String = s"temporalOverlaps(${children.mkString(", ")})"

  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
    val leftPointRex = leftTimePoint.toRexNode
    val leftTemporalRex = leftTemporal.toRexNode
    val rightPointRex = rightTimePoint.toRexNode
    val rightTemporalRex = rightTemporal.toRexNode
    convertOverlaps(
      leftPointRex,
      leftTemporalRex,
      rightPointRex,
      rightTemporalRex,
      relBuilder.asInstanceOf[FlinkRelBuilder])
  }

  /**
   * Standard conversion of the OVERLAPS operator.
   * Source: [[org.apache.calcite.sql2rel.StandardConvertletTable#convertOverlaps()]]
   */
  private def convertOverlaps(
      leftPoint: RexNode,
      leftOther: RexNode,
      rightPoint: RexNode,
      rightOther: RexNode,
      relBuilder: FlinkRelBuilder): RexNode = {
    // Turn interval-typed second operands into end points (point + interval).
    val leftEnd = convertOverlapsEnd(relBuilder, leftPoint, leftOther, leftTemporal.resultType)
    val rightEnd = convertOverlapsEnd(relBuilder, rightPoint, rightOther, rightTemporal.resultType)
    // Order each pair so that (s0 <= e0) and (s1 <= e1).
    val (s0, e0) = buildSwap(relBuilder, leftPoint, leftEnd)
    val (s1, e1) = buildSwap(relBuilder, rightPoint, rightEnd)
    // The intervals overlap iff (e0 >= s1) AND (e1 >= s0).
    relBuilder.call(
      SqlStdOperatorTable.AND,
      relBuilder.call(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, e0, s1),
      relBuilder.call(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, e1, s0))
  }

  /** Returns `start + end` when `endType` is a time interval, otherwise `end` unchanged. */
  private def convertOverlapsEnd(
      relBuilder: FlinkRelBuilder,
      start: RexNode,
      end: RexNode,
      endType: InternalType): RexNode =
    if (isTimeInterval(endType)) relBuilder.call(SqlStdOperatorTable.DATETIME_PLUS, start, end)
    else end

  /** Builds CASE expressions yielding (smaller, larger) of the two points. */
  private def buildSwap(
      relBuilder: FlinkRelBuilder,
      start: RexNode,
      end: RexNode): (RexNode, RexNode) = {
    val inOrder = relBuilder.call(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, start, end)
    def choose(whenInOrder: RexNode, otherwise: RexNode): RexNode =
      relBuilder.call(SqlStdOperatorTable.CASE, inOrder, whenInOrder, otherwise)
    (choose(start, end), choose(end, start))
  }

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}
/**
 * Formats a timestamp using a format expression via the DATE_FORMAT scalar function; the result
 * is a STRING.
 *
 * @param timestamp the value to format
 * @param format    the format pattern expression
 */
case class DateFormat(timestamp: Expression, format: Expression) extends Expression {

  override private[flink] def children = List(timestamp, format)

  override private[flink] def resultType = DataTypes.STRING

  override def toString: String = s"$timestamp.dateFormat($format)"

  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
    val timestampRex = timestamp.toRexNode
    val formatRex = format.toRexNode
    relBuilder.call(ScalarSqlFunctions.DATE_FORMAT, timestampRex, formatRex)
  }

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}
/**
 * Difference between two time points measured in a [[TimePointUnit]]; the result type is INT.
 *
 * @param timePointUnit symbol expression naming the unit of the difference
 * @param timePoint1    first time point
 * @param timePoint2    second time point
 */
case class TimestampDiff(
    timePointUnit: Expression,
    timePoint1: Expression,
    timePoint2: Expression)
  extends Expression {

  override private[flink] def children: Seq[Expression] =
    List(timePointUnit, timePoint1, timePoint2)

  /**
   * Both operands must be time points, the unit must be one of YEAR .. SECOND, and at least one
   * operand must be exactly DATE or TIMESTAMP.
   */
  override private[flink] def validateInput(): ValidationResult = {
    val firstType = timePoint1.resultType
    val secondType = timePoint2.resultType
    if (!TypeCheckUtils.isTimePoint(firstType)) {
      ValidationFailure(
        s"$this requires an input time point type, " +
          s"but timePoint1 is of type '$firstType'.")
    } else if (!TypeCheckUtils.isTimePoint(secondType)) {
      ValidationFailure(
        s"$this requires an input time point type, " +
          s"but timePoint2 is of type '$secondType'.")
    } else {
      val hasDateOrTimestamp = Seq(firstType, secondType).exists { tpe =>
        tpe == DataTypes.DATE || tpe == DataTypes.TIMESTAMP
      }
      timePointUnit match {
        case SymbolExpression(
            TimePointUnit.YEAR | TimePointUnit.QUARTER | TimePointUnit.MONTH |
            TimePointUnit.WEEK | TimePointUnit.DAY | TimePointUnit.HOUR |
            TimePointUnit.MINUTE | TimePointUnit.SECOND) if hasDateOrTimestamp =>
          ValidationSuccess
        case _ =>
          ValidationFailure(s"$this operator does not support unit '$timePointUnit'" +
            s" for input of type ('$firstType', '$secondType').")
      }
    }
  }

  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
    // NOTE(review): the time points are passed to TIMESTAMP_DIFF in reverse order (timePoint2
    // before timePoint1). This looks deliberate to match the operand convention expected
    // downstream — confirm before changing.
    val operands = Seq(timePointUnit.toRexNode, timePoint2.toRexNode, timePoint1.toRexNode)
    relBuilder.getRexBuilder.makeCall(SqlStdOperatorTable.TIMESTAMP_DIFF, operands)
  }

  override def toString: String = s"timestampDiff(${children.mkString(", ")})"

  override private[flink] def resultType = DataTypes.INT

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}