/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.expressions
import java.util
import java.math.BigDecimal
import scala.collection.mutable
import com.google.common.collect.ImmutableList
import org.apache.calcite.rel.RelFieldCollation
import org.apache.calcite.rex.RexWindowBound._
import org.apache.calcite.rex.{RexCall, RexFieldCollation, RexNode, RexWindowBound}
import org.apache.calcite.sql._
import org.apache.calcite.sql.`type`.OrdinalReturnTypeInference
import org.apache.calcite.sql.parser.SqlParserPos
import org.apache.calcite.tools.RelBuilder
import org.apache.flink.table.api._
import org.apache.flink.table.api.functions.{ScalarFunction, TableFunction}
import org.apache.flink.table.api.scala.{CURRENT_ROW, UNBOUNDED_ROW}
import org.apache.flink.table.calcite.{FlinkPlannerImpl, FlinkTypeFactory}
import org.apache.flink.table.functions.sql.internal.SqlThrowExceptionFunction
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
import org.apache.flink.table.plan.logical.{LogicalExprVisitor, LogicalNode, LogicalTableFunctionCall}
import org.apache.flink.table.api.types._
import org.apache.flink.table.typeutils.TypeCheckUtils.isTimeInterval
import org.apache.flink.table.typeutils.TypeUtils
import org.apache.flink.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess}
import _root_.scala.collection.JavaConverters._
/**
* General expression for unresolved function calls. The function can be a built-in
* scalar function or a user-defined scalar function.
*/
case class Call(functionName: String, args: Seq[Expression]) extends Expression {
override private[flink] def children: Seq[Expression] = args
override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
throw UnresolvedException(s"trying to convert UnresolvedFunction $functionName to RexNode")
}
override def toString = s"\\$functionName(${args.mkString(", ")})"
override private[flink] def resultType =
throw UnresolvedException(s"calling resultType on UnresolvedFunction $functionName")
override private[flink] def validateInput(): ValidationResult =
ValidationFailure(s"Unresolved function call: $functionName")
override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
logicalExprVisitor.visit(this)
}
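
// A minimal sketch (names are hypothetical) of how an unresolved Call arises:
// the expression parser maps "myFunc(a)" to Call("myFunc", ...), which the
// function catalog replaces with a resolved function call during validation.
//
//   val unresolved = Call("myFunc", Seq(UnresolvedFieldReference("a")))
//   unresolved.validateInput()  // ValidationFailure until the catalog resolves it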
/**
* Over call with unresolved alias for over window.
*
* @param agg The aggregation of the over call.
* @param alias The alias of the referenced over window.
*/
case class UnresolvedOverCall(agg: Expression, alias: Expression) extends Expression {
override private[flink] def validateInput() =
ValidationFailure(s"Over window with alias $alias could not be resolved.")
override private[flink] def resultType = agg.resultType
override private[flink] def children = Seq()
override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
logicalExprVisitor.visit(this)
}
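
// A minimal sketch (field and window names are hypothetical): in the Table API,
// "'b.count over 'w" first parses to UnresolvedOverCall(Count('b), 'w); the
// resolver then substitutes a fully specified OverCall by looking up the window
// definition bound to the alias 'w in the enclosing window(...) clause.
//
//   table
//     .window(Over partitionBy 'a orderBy 'rowtime preceding UNBOUNDED_ROW as 'w)
//     .select('b.count over 'w)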
/**
* Over expression for Calcite over transform.
*
 * @param agg The aggregation expression of the over window
 * @param partitionBy The fields by which the over window is partitioned
 * @param orderBy The fields by which the over window is sorted
 * @param preceding The lower bound of the window
 * @param following The upper bound of the window
*/
case class OverCall(
agg: Expression,
partitionBy: Seq[Expression],
orderBy: Seq[Expression],
var preceding: Expression,
var following: Expression,
tableEnv: TableEnvironment) extends Expression {
override def toString: String = s"$agg OVER (" +
s"PARTITION BY (${partitionBy.mkString(", ")}) " +
s"ORDER BY (${orderBy.mkString(", ")})" +
s"PRECEDING $preceding " +
s"FOLLOWING $following)"
override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
val rexBuilder = relBuilder.getRexBuilder
// assemble aggregation
val operator: SqlAggFunction = agg.asInstanceOf[Aggregation].getSqlAggFunction()
val aggResultType = relBuilder
.getTypeFactory.asInstanceOf[FlinkTypeFactory]
.createTypeFromInternalType(agg.resultType, isNullable = true)
// assemble exprs by agg children
val aggExprs = agg.asInstanceOf[Aggregation].children.map(_.toRexNode(relBuilder)).asJava
// assemble order by key
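    // Recursively peels DESCENDING / NULLS_FIRST / NULLS_LAST wrapper calls off
    // an order-by expression, recording the encountered SqlKinds and defaulting
    // the null direction when none is specified.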
    def collation(
        node: RexNode,
        direction: RelFieldCollation.Direction,
        nullDirection: RelFieldCollation.NullDirection,
        kinds: mutable.Set[SqlKind]): RexNode = {
node.getKind match {
case SqlKind.DESCENDING =>
kinds.add(node.getKind)
collation(
node.asInstanceOf[RexCall].getOperands.get(0),
RelFieldCollation.Direction.DESCENDING, nullDirection, kinds)
case SqlKind.NULLS_FIRST =>
kinds.add(node.getKind)
collation(
node.asInstanceOf[RexCall].getOperands.get(0), direction,
RelFieldCollation.NullDirection.FIRST, kinds)
case SqlKind.NULLS_LAST =>
kinds.add(node.getKind)
collation(
node.asInstanceOf[RexCall].getOperands.get(0), direction,
RelFieldCollation.NullDirection.LAST, kinds)
case _ =>
if (nullDirection == null) {
            // Set the null direction if not specified.
            // Consistent with Hive, Spark, MySQL, and the Blink runtime.
if (FlinkPlannerImpl.defaultNullCollation.last(
direction.equals(RelFieldCollation.Direction.DESCENDING))) {
kinds.add(SqlKind.NULLS_LAST)
} else {
kinds.add(SqlKind.NULLS_FIRST)
}
}
node
}
}
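    // e.g. "'a.desc" peels down to the field reference with kinds containing
    // DESCENDING plus whichever of NULLS_FIRST/NULLS_LAST the default null
    // collation implies for a descending key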
val orderKeysBuilder = new ImmutableList.Builder[RexFieldCollation]()
for (orderExp <- orderBy) {
val kinds: mutable.Set[SqlKind] = mutable.Set()
val rexNode = collation(
orderExp.toRexNode, RelFieldCollation.Direction.ASCENDING, null, kinds)
val orderKey = new RexFieldCollation(rexNode, kinds.asJava)
orderKeysBuilder.add(orderKey)
}
val orderKeys = orderKeysBuilder.build()
// assemble partition by keys
val partitionKeys = partitionBy.map(_.toRexNode).asJava
// assemble bounds
var isPhysical: Boolean = false
var lowerBound, upperBound: RexWindowBound = null
agg match {
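      // Rank-style functions are always evaluated over the physical frame
      // ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, ignoring any
      // user-specified bounds.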
case _: RowNumber | _: Rank | _: DenseRank =>
isPhysical = true
lowerBound = createBound(relBuilder, UNBOUNDED_ROW, SqlKind.PRECEDING)
upperBound = createBound(relBuilder, CURRENT_ROW, SqlKind.FOLLOWING)
case _ =>
isPhysical = preceding.resultType == DataTypes.INTERVAL_ROWS
lowerBound = createBound(relBuilder, preceding, SqlKind.PRECEDING)
upperBound = createBound(relBuilder, following, SqlKind.FOLLOWING)
}
// build RexOver
rexBuilder.makeOver(
aggResultType,
operator,
aggExprs,
partitionKeys,
orderKeys,
lowerBound,
upperBound,
isPhysical,
      true,   // allowPartial
      false,  // nullWhenCountZero
      false)  // distinct
}
private def createBound(
relBuilder: RelBuilder,
bound: Expression,
sqlKind: SqlKind): RexWindowBound = {
bound match {
case _: UnboundedRow | _: UnboundedRange =>
val unbounded = if (sqlKind == SqlKind.PRECEDING) {
SqlWindow.createUnboundedPreceding(SqlParserPos.ZERO)
} else {
SqlWindow.createUnboundedFollowing(SqlParserPos.ZERO)
}
create(unbounded, null)
case _: CurrentRow | _: CurrentRange =>
val currentRow = SqlWindow.createCurrentRow(SqlParserPos.ZERO)
create(currentRow, null)
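      // A finite bound such as "10 PRECEDING": build a PRECEDING/FOLLOWING
      // postfix call whose SqlNode carries a placeholder operand while the
      // attached RexNode holds the actual bound value.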
case b: Literal =>
val DECIMAL_PRECISION_NEEDED_FOR_LONG = 19
val returnType = relBuilder
.getTypeFactory.asInstanceOf[FlinkTypeFactory]
.createTypeFromInternalType(
DecimalType.of(DECIMAL_PRECISION_NEEDED_FOR_LONG, 0),
isNullable = true)
val sqlOperator = new SqlPostfixOperator(
sqlKind.name,
sqlKind,
2,
new OrdinalReturnTypeInference(0),
null,
null)
val operands: Array[SqlNode] = new Array[SqlNode](1)
operands(0) = SqlLiteral.createExactNumeric("1", SqlParserPos.ZERO)
val node = new SqlBasicCall(sqlOperator, operands, SqlParserPos.ZERO)
val expressions: util.ArrayList[RexNode] = new util.ArrayList[RexNode]()
b.value match {
          case v: Double => expressions.add(relBuilder.literal(BigDecimal.valueOf(v)))
case _ => expressions.add(relBuilder.literal(b.value))
}
val rexNode = relBuilder.getRexBuilder.makeCall(returnType, sqlOperator, expressions)
create(node, rexNode)
}
}
override private[flink] def children: Seq[Expression] =
Seq(agg) ++ orderBy ++ partitionBy ++ Seq(preceding) ++ Seq(following)
override private[flink] def resultType = agg.resultType
override private[flink] def validateInput(): ValidationResult = {
// check that agg expression is aggregation
agg match {
case _: Aggregation =>
ValidationSuccess
case _ =>
return ValidationFailure(s"OVER can only be applied on an aggregation.")
}
// check partitionBy expression keys are resolved field reference
partitionBy.foreach {
case r: ResolvedFieldReference
if TypeConverters.createExternalTypeInfoFromDataType(r.resultType).isKeyType =>
ValidationSuccess
case r: ResolvedFieldReference =>
return ValidationFailure(s"Invalid PartitionBy expression: $r. " +
s"Expression must return key type.")
case r =>
return ValidationFailure(s"Invalid PartitionBy expression: $r. " +
s"Expression must be a resolved field reference.")
}
// check preceding is valid
preceding match {
case _: CurrentRow | _: CurrentRange | _: UnboundedRow | _: UnboundedRange =>
ValidationSuccess
case Literal(_: Long, DataTypes.INTERVAL_ROWS) =>
ValidationSuccess
case Literal(_: Long, DataTypes.INTERVAL_RANGE) |
Literal(_: Double, DataTypes.INTERVAL_RANGE) =>
ValidationSuccess
case Literal(v: Long, b) if isTimeInterval(b) && v >= 0 =>
ValidationSuccess
case Literal(_, b) if isTimeInterval(b) =>
return ValidationFailure("Preceding time interval must be equal or larger than 0.")
case Literal(_, _) =>
return ValidationFailure("Preceding must be a row interval or time interval literal.")
}
// check following is valid
following match {
case _: CurrentRow | _: CurrentRange | _: UnboundedRow | _: UnboundedRange =>
ValidationSuccess
case Literal(_: Long, DataTypes.INTERVAL_ROWS) =>
ValidationSuccess
case Literal(_: Long, DataTypes.INTERVAL_RANGE) |
Literal(_: Double, DataTypes.INTERVAL_RANGE) =>
ValidationSuccess
case Literal(v: Long, b) if isTimeInterval(b) && v >= 0 =>
ValidationSuccess
case Literal(_, b) if isTimeInterval(b) =>
return ValidationFailure("Following time interval must be equal or larger than 0.")
case Literal(_, _) =>
return ValidationFailure("Following must be a row interval or time interval literal.")
}
// check that preceding and following are of same type
(preceding, following) match {
case (p: Expression, f: Expression) if p.resultType == f.resultType =>
ValidationSuccess
case (_: UnboundedRange, f: Expression) if f.resultType == DataTypes.INTERVAL_MILLIS =>
ValidationSuccess
case (_: CurrentRange, f: Expression) if f.resultType == DataTypes.INTERVAL_MILLIS =>
ValidationSuccess
case (p: Expression, _: UnboundedRange) if p.resultType == DataTypes.INTERVAL_MILLIS =>
ValidationSuccess
case (p: Expression, _: CurrentRange) if p.resultType == DataTypes.INTERVAL_MILLIS =>
ValidationSuccess
case _ =>
return ValidationFailure("Preceding and following must be of same interval type.")
}
if (tableEnv.isInstanceOf[StreamTableEnvironment]) {
validateInputForStream()
} else {
ValidationSuccess
}
}
private[flink] def validateInputForStream(): ValidationResult = {
    // check the preceding/following range; only failures return early,
    // successful checks fall through to the next validation
preceding match {
case Literal(_, DataTypes.INTERVAL_RANGE) =>
return ValidationFailure("Stream table API does not support value range.")
case Literal(v: Long, DataTypes.INTERVAL_ROWS) if v < 0 =>
return ValidationFailure("Stream table API does not support negative preceding.")
case _ =>
ValidationSuccess
}
following match {
case _: UnboundedRow | _: UnboundedRange =>
return ValidationFailure("Stream table API does not support unbounded following.")
case Literal(_, DataTypes.INTERVAL_RANGE) =>
return ValidationFailure("Stream table API does not support value range.")
case Literal(v: Long, DataTypes.INTERVAL_ROWS) if v > 0 =>
return ValidationFailure("Stream table API does not support positive following.")
case _ =>
ValidationSuccess
}
// check lead/lag offset
agg match {
case a: Lead =>
if (a.getOffsetValue > 0) {
return ValidationFailure("Stream table API does not support positive offset for lead.")
}
case a: Lag =>
if (a.getOffsetValue < 0) {
return ValidationFailure("Stream table API does not support negative offset for lag.")
}
case _ =>
ValidationSuccess
}
ValidationSuccess
}
override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
logicalExprVisitor.visit(this)
}
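
// A minimal sketch of the SQL a planned OverCall corresponds to (names are
// hypothetical): with partitionBy = 'a, orderBy = 'rowtime, preceding =
// UNBOUNDED_ROW and following = CURRENT_ROW, toRexNode produces a RexOver
// equivalent to
//
//   COUNT(b) OVER (PARTITION BY a ORDER BY rowtime
//     ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)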
object ScalarFunctionCall {
def apply(
scalarFunction: ScalarFunction,
parameters: Array[Expression]): ScalarFunctionCall =
new ScalarFunctionCall(scalarFunction, parameters)
}
/**
 * Expression for calling a user-defined scalar function.
*
* @param scalarFunction scalar function to be called (might be overloaded)
* @param parameters actual parameters that determine target evaluation method
*/
case class ScalarFunctionCall(
scalarFunction: ScalarFunction,
parameters: Seq[Expression])
extends Expression {
override private[flink] def children: Seq[Expression] = parameters
override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
relBuilder.call(
createScalarSqlFunction(
scalarFunction.functionIdentifier,
scalarFunction.toString,
scalarFunction,
typeFactory),
parameters.map(_.toRexNode): _*)
}
override def toString =
s"${scalarFunction.getClass.getCanonicalName}(${parameters.mkString(", ")})"
override private[flink] def resultType =
getResultTypeOfCTDFunction(
scalarFunction,
parameters.toArray,
      () => extractTypeFromScalarFunc(
        scalarFunction, parameters.map(_.resultType).toArray)).toInternalType
override private[flink] def validateInput(): ValidationResult = {
val signature = children.map(_.resultType)
// look for a signature that matches the input types
if (getEvalUserDefinedMethod(scalarFunction, signature).isEmpty) {
ValidationFailure(s"Given parameters do not match any signature. \n" +
s"Actual: ${signatureToString(signature)} \n" +
s"Expected: ${signaturesToString(scalarFunction, "eval")}")
} else {
ValidationSuccess
}
}
override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
logicalExprVisitor.visit(this)
}
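
// A minimal usage sketch (HashCode is a hypothetical UDF): applying a
// ScalarFunction to expressions in the Scala Table API builds a
// ScalarFunctionCall; validateInput() checks that an eval() overload matches
// the operand types.
//
//   class HashCode extends ScalarFunction {
//     def eval(s: String): Int = s.hashCode
//   }
//   val hashCode = new HashCode
//   table.select('text, hashCode('text))  // -> ScalarFunctionCall(hashCode, Seq('text))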
/**
 * Expression for calling a user-defined table function with actual parameters.
*
* @param functionName function name
* @param tableFunction user-defined table function
* @param parameters actual parameters of function
* @param externalResultType type information of returned table
*/
case class TableFunctionCall(
functionName: String,
tableFunction: TableFunction[_],
parameters: Seq[Expression],
externalResultType: DataType)
extends Expression {
private var aliases: Option[Seq[String]] = None
override private[flink] def children: Seq[Expression] = parameters
override private[flink] def resultType: InternalType = externalResultType.toInternalType
/**
   * Assigns aliases for this table function's returned fields that the following operator
   * can refer to.
   *
   * @param aliasList aliases for this table function's returned fields
* @return this table function call
*/
private[flink] def as(aliasList: Option[Seq[String]]): TableFunctionCall = {
this.aliases = aliasList
this
}
/**
* Converts an API class to a logical node for planning.
*/
private[flink] def toLogicalTableFunctionCall(child: LogicalNode): LogicalTableFunctionCall = {
val originNames = getFieldInfo(resultType)._1
// determine the final field names
val fieldNames = if (aliases.isDefined) {
val aliasList = aliases.get
if (aliasList.length != originNames.length) {
throw new ValidationException(
s"List of column aliases must have same degree as table; " +
s"the returned table of function '$functionName' has ${originNames.length} " +
s"columns (${originNames.mkString(",")}), " +
s"whereas alias list has ${aliasList.length} columns")
} else {
aliasList.toArray
}
} else {
originNames
}
LogicalTableFunctionCall(
functionName,
tableFunction,
parameters,
externalResultType,
fieldNames,
child)
}
override def toString =
s"${tableFunction.getClass.getCanonicalName}(${parameters.mkString(", ")})"
override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
logicalExprVisitor.visit(this)
}
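
// A minimal usage sketch (Split is a hypothetical UDTF; the exact join method
// name depends on the API version): applying a TableFunction yields a
// TableFunctionCall, and as(...) renames the emitted fields before
// toLogicalTableFunctionCall turns the call into a logical node.
//
//   class Split extends TableFunction[String] {
//     def eval(s: String): Unit = s.split(",").foreach(collect)
//   }
//   val split = new Split
//   table.join(split('line) as 'word).select('line, 'word)

/**
 * Expression that fails query execution with the given message when evaluated.
 *
 * @param msg expression producing the exception message; must be of type String
 * @param tp the declared result type of this expression
 */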
case class ThrowException(msg: Expression, tp: InternalType) extends UnaryExpression {
override private[flink] def validateInput(): ValidationResult = {
if (child.resultType == DataTypes.STRING) {
ValidationSuccess
} else {
ValidationFailure(s"ThrowException operator requires String input, " +
s"but $child is of type ${child.resultType}")
}
}
override def toString: String = s"ThrowException($msg)"
override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
relBuilder.call(
new SqlThrowExceptionFunction(
tp,
relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]),
msg.toRexNode)
}
override private[flink] def child = msg
override private[flink] def resultType = tp
override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
logicalExprVisitor.visit(this)
}