blob: d42d7e2cf3dce1f1a77af3c07d9775ab0288102b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan.logical
import java.lang.reflect.Method
import java.util
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.{CorrelationId, JoinRelType}
import org.apache.calcite.rel.logical.LogicalTableFunctionScan
import org.apache.calcite.rex.{RexInputRef, RexNode}
import org.apache.calcite.tools.RelBuilder
import org.apache.flink.api.java.operators.join.JoinType
import org.apache.flink.table.api.functions.TableFunction
import org.apache.flink.table.api.types.{DataType, DataTypes, InternalType, TypeConverters}
import org.apache.flink.table.api.{StreamTableEnvironment, TableEnvironment, UnresolvedException}
import org.apache.flink.table.calcite.{FlinkRelBuilder, FlinkTypeFactory}
import org.apache.flink.table.expressions.ExpressionUtils.isRowCountLiteral
import org.apache.flink.table.expressions._
import org.apache.flink.table.functions.utils.{TableSqlFunction, UserDefinedFunctionUtils}
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
import org.apache.flink.table.plan.schema.TypedFlinkTableFunction
import org.apache.flink.table.sinks.TableSink
import org.apache.flink.table.typeutils.TypeUtils
import org.apache.flink.table.validate.{ValidationFailure, ValidationSuccess}
import scala.collection.JavaConverters._
import scala.collection.JavaConversions._
import scala.collection.mutable
/**
  * Projects a list of named expressions on top of `child`.
  *
  * During resolution every `UnresolvedAlias` entry of the projection list is turned into a named
  * expression: expressions that are already named are kept, invalid ones stay unresolved so that
  * validation can report them, and all others receive a generated name.
  *
  * @param projectList the projected expressions; their names form the output schema
  * @param child       the input node
  */
case class Project(projectList: Seq[NamedExpression], child: LogicalNode) extends UnaryNode {

  override def output: Seq[Attribute] = projectList.map(_.toAttribute)

  override def resolveExpressions(tableEnv: TableEnvironment): LogicalNode = {
    val afterResolve = super.resolveExpressions(tableEnv).asInstanceOf[Project]
    val newProjectList =
      afterResolve.projectList.zipWithIndex.map { case (e, i) =>
        e match {
          case u @ UnresolvedAlias(c) => c match {
            case ne: NamedExpression => ne
            // keep it unresolved so that validation reports the invalid expression
            case expr if !expr.valid => u
            case c @ Cast(ne: NamedExpression, tp) => Alias(c, s"${ne.name}-$tp")
            case gcf: GetCompositeField => Alias(gcf, gcf.aliasName().getOrElse(s"_c$i"))
            // fall back to a positional name
            case other => Alias(other, s"_c$i")
          }
          // NOTE(review): presumably callers always wrap projection entries in UnresolvedAlias
          case _ =>
            throw new RuntimeException("This should never be called and probably points to a bug.")
        }
      }
    // Rebuild from the resolved node's child, consistent with Join.resolveExpressions,
    // which rebuilds from node.left / node.right.
    Project(newProjectList, afterResolve.child)
  }

  /**
    * Validates the projection and rejects duplicate output names introduced by explicit aliases
    * or by plain field forwarding.
    */
  override def validate(tableEnv: TableEnvironment): LogicalNode = {
    val resolvedProject = super.validate(tableEnv).asInstanceOf[Project]
    val names: mutable.Set[String] = mutable.Set()

    def checkName(name: String): Unit = {
      if (names.contains(name)) {
        failValidation(s"Duplicate field name $name.")
      } else {
        names.add(name)
      }
    }

    resolvedProject.projectList.foreach {
      case n: Alias =>
        // explicit name
        checkName(n.name)
      case r: ResolvedFieldReference =>
        // simple field forwarding
        checkName(r.name)
      case _ => // Do nothing
    }
    resolvedProject
  }

  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
    child.construct(relBuilder)
    relBuilder.project(
      projectList.map(_.toRexNode(relBuilder)).asJava,
      projectList.map(_.name).asJava,
      true) // force = true: always create the projection node
  }

  override def accept[T](logicalSqlVisitor: LogicalNodeVisitor[T]): T =
    logicalSqlVisitor.visit(this)
}
/**
  * Placeholder node that renames the leading fields of `child`.
  *
  * It exposes no output schema and cannot be translated directly. Resolving it yields a
  * [[Project]] that aliases the first `aliasList.length` input fields and forwards the
  * remaining ones unchanged.
  *
  * @param aliasList new names, given as unresolved field references
  * @param child     the node whose fields are renamed
  */
case class AliasNode(aliasList: Seq[Expression], child: LogicalNode) extends UnaryNode {

  override def output: Seq[Attribute] =
    throw UnresolvedException("Invalid call to output on AliasNode")

  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder =
    throw UnresolvedException("Invalid call to toRelNode on AliasNode")

  override def resolveExpressions(tableEnv: TableEnvironment): LogicalNode = {
    val inputFields = child.output
    if (aliasList.length > inputFields.length) {
      failValidation("Aliasing more fields than we actually have")
    } else if (aliasList.exists(!_.isInstanceOf[UnresolvedFieldReference])) {
      failValidation("Alias only accept name expressions as arguments")
    } else {
      // every entry is an UnresolvedFieldReference at this point
      val aliasNames = aliasList.collect { case UnresolvedFieldReference(n) => n }
      if (aliasNames.contains("*")) {
        failValidation("Alias can not accept '*' as name")
      } else {
        val renamed = aliasNames.zip(inputFields).map { case (n, field) => Alias(field, n) }
        Project(renamed ++ inputFields.drop(aliasNames.length), child)
      }
    }
  }

  override def accept[T](logicalSqlVisitor: LogicalNodeVisitor[T]): T =
    logicalSqlVisitor.visit(this)
}
/**
  * Removes duplicate rows from `child`. The schema is passed through unchanged.
  *
  * @param child the input node
  */
case class Distinct(child: LogicalNode) extends UnaryNode {

  override def output: Seq[Attribute] = child.output

  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
    // push the input first, then put the distinct operator on top of it
    child.construct(relBuilder)
    relBuilder.distinct()
  }

  override def accept[T](logicalSqlVisitor: LogicalNodeVisitor[T]): T =
    logicalSqlVisitor.visit(this)
}
/**
  * Orders the rows of `child` by the given orderings. Rejected on stream environments.
  *
  * @param order the sort keys with their directions
  * @param child the input node
  */
case class Sort(order: Seq[Ordering], child: LogicalNode) extends UnaryNode {

  override def output: Seq[Attribute] = child.output

  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
    child.construct(relBuilder)
    val sortKeys = order.map(_.toRexNode(relBuilder))
    relBuilder.sort(sortKeys.asJava)
  }

  override def validate(tableEnv: TableEnvironment): LogicalNode = tableEnv match {
    case _: StreamTableEnvironment =>
      failValidation(s"Sort on stream tables is currently not supported.")
    case _ =>
      super.validate(tableEnv)
  }

  override def accept[T](logicalSqlVisitor: LogicalNodeVisitor[T]): T =
    logicalSqlVisitor.visit(this)
}
/**
  * Skips `offset` rows of `child` and returns at most `fetch` rows.
  *
  * @param offset number of leading rows to skip; must not be negative
  * @param fetch  number of rows to return; the default `-1` is handed to Calcite's
  *               `RelBuilder.limit` unchanged (presumably meaning "no limit")
  * @param child  the input node
  */
case class Limit(offset: Int, fetch: Int = -1, child: LogicalNode) extends UnaryNode {

  override def output: Seq[Attribute] = child.output

  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
    child.construct(relBuilder)
    relBuilder.limit(offset, fetch)
  }

  override def validate(tableEnv: TableEnvironment): LogicalNode = {
    if (tableEnv.isInstanceOf[StreamTableEnvironment]) {
      failValidation(s"Limit on stream tables is currently not supported.")
    } else if (offset < 0) {
      failValidation(s"Offset should be greater than or equal to zero.")
    } else {
      super.validate(tableEnv)
    }
  }

  override def accept[T](logicalSqlVisitor: LogicalNodeVisitor[T]): T =
    logicalSqlVisitor.visit(this)
}
/**
  * Keeps only the rows of `child` for which `condition` holds.
  *
  * @param condition predicate; validation requires it to be of boolean type
  * @param child     the input node
  */
case class Filter(condition: Expression, child: LogicalNode) extends UnaryNode {

  override def output: Seq[Attribute] = child.output

  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
    child.construct(relBuilder)
    val predicate = condition.toRexNode(relBuilder)
    relBuilder.filter(predicate)
  }

  override def validate(tableEnv: TableEnvironment): LogicalNode = {
    val validated = super.validate(tableEnv).asInstanceOf[Filter]
    val predicateType = validated.condition.resultType
    if (predicateType != DataTypes.BOOLEAN) {
      failValidation(s"Filter operator requires a boolean expression as input," +
        s" but ${validated.condition} is of type $predicateType")
    }
    validated
  }

  override def accept[T](logicalSqlVisitor: LogicalNodeVisitor[T]): T =
    logicalSqlVisitor.visit(this)
}
/**
  * Grouped (non-windowed) aggregation over `child`.
  *
  * The output schema is the grouping expressions followed by the aggregate expressions; grouping
  * expressions that are not named get their `toString` as name.
  *
  * @param groupingExpressions  expressions to group by
  * @param aggregateExpressions aggregate calls; `construct` only accepts entries of the form
  *                             `Alias(Aggregation, name, _)`
  * @param child                the input node
  */
case class Aggregate(
    groupingExpressions: Seq[Expression],
    aggregateExpressions: Seq[NamedExpression],
    child: LogicalNode) extends UnaryNode {
  override def output: Seq[Attribute] = {
    (groupingExpressions ++ aggregateExpressions) map {
      case ne: NamedExpression => ne.toAttribute
      case e => Alias(e, e.toString).toAttribute
    }
  }
  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
    child.construct(relBuilder)
    relBuilder.aggregate(
      relBuilder.groupKey(groupingExpressions.map(_.toRexNode(relBuilder)).asJava),
      aggregateExpressions.map {
        // anything other than an aliased Aggregation is treated as a planner bug
        case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder)
        case _ => throw new RuntimeException("This should never happen.")
      }.asJava)
  }
  /**
    * Validates the children/expressions, then checks every aggregate expression (see
    * `validateAggregateExpression`) and that every grouping expression has a key type.
    */
  override def validate(tableEnv: TableEnvironment): LogicalNode = {
    // NOTE(review): not referenced explicitly below; presumably consumed as an implicit
    // parameter (e.g. by getSqlAggFunction) — confirm before removing.
    implicit val relBuilder: RelBuilder = tableEnv.getRelBuilder
    val resolvedAggregate = super.validate(tableEnv).asInstanceOf[Aggregate]
    val groupingExprs = resolvedAggregate.groupingExpressions
    val aggregateExprs = resolvedAggregate.aggregateExpressions
    aggregateExprs.foreach(validateAggregateExpression)
    groupingExprs.foreach(validateGroupingExpression)
    // Recursively checks an expression of the aggregate list: distinct must wrap exactly one
    // (non-distinct) aggregation, aggregations must not require OVER nor contain nested
    // aggregations, and plain attributes must appear in the grouping expressions.
    def validateAggregateExpression(expr: Expression): Unit = expr match {
      case distinctExpr: DistinctAgg =>
        distinctExpr.child match {
          case _: DistinctAgg => failValidation(
            "Chained distinct operators are not supported!")
          case aggExpr: Aggregation => validateAggregateExpression(aggExpr)
          case _ => failValidation(
            "Distinct operator can only be applied to aggregation expressions!")
        }
      // check aggregate function
      case aggExpr: Aggregation
        if aggExpr.getSqlAggFunction.requiresOver =>
        failValidation(s"OVER clause is necessary for window functions: [${aggExpr.getClass}].")
      // check no nested aggregation exists.
      case aggExpr: Aggregation =>
        aggExpr.children.foreach { child =>
          child.preOrderVisit {
            case agg: Aggregation =>
              failValidation(
                "It's not allowed to use an aggregate function as " +
                "input of another aggregate function")
            case _ => // OK
          }
        }
      case a: Attribute if !groupingExprs.exists(_.checkEquals(a)) =>
        failValidation(
          s"expression '$a' is invalid because it is neither" +
          " present in group by nor an aggregate function")
      // the expression itself is a grouping expression
      case e if groupingExprs.exists(_.checkEquals(e)) => // OK
      // otherwise descend into the sub-expressions
      case e => e.children.foreach(validateAggregateExpression)
    }
    // Grouping keys must map to a type info whose isKeyType is true.
    def validateGroupingExpression(expr: Expression): Unit = {
      if (!TypeConverters.createExternalTypeInfoFromDataType(expr.resultType).isKeyType) {
        failValidation(
          s"expression $expr cannot be used as a grouping expression " +
          "because it's not a valid key type which must be hashable and comparable")
      }
    }
    resolvedAggregate
  }
  override def accept[T](logicalSqlVisitor: LogicalNodeVisitor[T]): T =
    logicalSqlVisitor.visit(this)
}
/**
  * Set difference `left - right`; `all` keeps duplicates. Both inputs must have the same
  * column types (names may differ). The output schema is taken from the left input.
  * Rejected on stream environments.
  */
case class Minus(left: LogicalNode, right: LogicalNode, all: Boolean) extends BinaryNode {

  override def output: Seq[Attribute] = left.output

  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
    left.construct(relBuilder)
    right.construct(relBuilder)
    relBuilder.minus(all)
  }

  override def validate(tableEnv: TableEnvironment): LogicalNode = {
    if (tableEnv.isInstanceOf[StreamTableEnvironment]) {
      failValidation(s"Minus on stream tables is currently not supported.")
    }
    val resolvedMinus = super.validate(tableEnv).asInstanceOf[Minus]

    val leftFields = left.output
    val rightFields = right.output
    if (leftFields.length != rightFields.length) {
      failValidation(s"Minus two table of different column sizes:" +
        s" ${leftFields.size} and ${rightFields.size}")
    }

    // renders a schema as "(name,type), (name,type), ..."
    def describe(fields: Seq[Attribute]): String =
      fields.map(a => (a.name, a.resultType)).mkString(", ")

    // the arities are equal here, so comparing the type lists compares column by column
    if (leftFields.map(_.resultType) != rightFields.map(_.resultType)) {
      failValidation(s"Minus two table of different schema:" +
        s" [${describe(leftFields)}] and" +
        s" [${describe(rightFields)}]")
    }
    resolvedMinus
  }

  override def accept[T](logicalSqlVisitor: LogicalNodeVisitor[T]): T =
    logicalSqlVisitor.visit(this)
}
/**
  * Set union of two inputs; `all` keeps duplicates. Both inputs must have the same column types
  * (names may differ). The output schema is taken from the left input. On stream environments
  * only the `all` variant is accepted.
  */
case class Union(left: LogicalNode, right: LogicalNode, all: Boolean) extends BinaryNode {

  override def output: Seq[Attribute] = left.output

  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
    left.construct(relBuilder)
    right.construct(relBuilder)
    relBuilder.union(all)
  }

  override def validate(tableEnv: TableEnvironment): LogicalNode = {
    val distinctUnionOnStream = tableEnv.isInstanceOf[StreamTableEnvironment] && !all
    if (distinctUnionOnStream) {
      failValidation(s"Union on stream tables is currently not supported.")
    }
    val resolvedUnion = super.validate(tableEnv).asInstanceOf[Union]

    val leftFields = left.output
    val rightFields = right.output
    if (leftFields.length != rightFields.length) {
      failValidation(s"Union two tables of different column sizes:" +
        s" ${leftFields.size} and ${rightFields.size}")
    }

    // renders a schema as "(name,type), (name,type), ..."
    def describe(fields: Seq[Attribute]): String =
      fields.map(a => (a.name, a.resultType)).mkString(", ")

    // the arities are equal here, so comparing the type lists compares column by column
    if (leftFields.map(_.resultType) != rightFields.map(_.resultType)) {
      failValidation(s"Union two tables of different schema:" +
        s" [${describe(leftFields)}] and" +
        s" [${describe(rightFields)}]")
    }
    resolvedUnion
  }

  override def accept[T](logicalSqlVisitor: LogicalNodeVisitor[T]): T =
    logicalSqlVisitor.visit(this)
}
/**
  * Set intersection of two inputs; `all` keeps duplicates. Both inputs must have the same
  * column types, while column names are allowed to differ. The output schema is taken from the
  * left input. Rejected on stream environments.
  */
case class Intersect(left: LogicalNode, right: LogicalNode, all: Boolean) extends BinaryNode {

  override def output: Seq[Attribute] = left.output

  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
    left.construct(relBuilder)
    right.construct(relBuilder)
    relBuilder.intersect(all)
  }

  override def validate(tableEnv: TableEnvironment): LogicalNode = {
    if (tableEnv.isInstanceOf[StreamTableEnvironment]) {
      failValidation(s"Intersect on stream tables is currently not supported.")
    }
    val resolvedIntersect = super.validate(tableEnv).asInstanceOf[Intersect]

    val leftFields = left.output
    val rightFields = right.output
    if (leftFields.length != rightFields.length) {
      failValidation(s"Intersect two tables of different column sizes:" +
        s" ${leftFields.size} and ${rightFields.size}")
    }

    // renders a schema as "(name,type), (name,type), ..."
    def describe(fields: Seq[Attribute]): String =
      fields.map(a => (a.name, a.resultType)).mkString(", ")

    // only the types are compared, so different column names between tables are allowed
    if (leftFields.map(_.resultType) != rightFields.map(_.resultType)) {
      failValidation(s"Intersect two tables of different schema:" +
        s" [${describe(leftFields)}] and" +
        s" [${describe(rightFields)}]")
    }
    resolvedIntersect
  }

  override def accept[T](logicalSqlVisitor: LogicalNodeVisitor[T]): T =
    logicalSqlVisitor.visit(this)
}
/**
  * Reference, by name, to a field of one of the two inputs of a join.
  *
  * The name is looked up in the left input first; only if it is absent there is the right
  * input used. `indexInInput` is the position within the owning input and `indexInJoin` is the
  * position within the concatenated (left ++ right) join row.
  */
case class JoinFieldReference(
    name: String,
    resultType: InternalType,
    left: LogicalNode,
    right: LogicalNode) extends Attribute {

  val isFromLeftInput: Boolean = left.output.exists(_.name == name)

  val (indexInInput, indexInJoin) =
    if (isFromLeftInput) {
      val pos = left.output.indexWhere(_.name == name)
      (pos, pos)
    } else {
      // right-side fields come after all left-side fields in the join row
      val pos = right.output.indexWhere(_.name == name)
      (pos, pos + left.output.length)
    }

  override def toString: String = s"'$name"

  override def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
    // two inputs are on the builder's stack; pick the one that owns this field
    val inputOrdinal = if (isFromLeftInput) 0 else 1
    val fieldType = relBuilder.field(2, inputOrdinal, name).getType
    // address the field by its offset in the joined row
    new RexInputRef(indexInJoin, fieldType)
  }

  override def withName(newName: String): Attribute =
    if (newName == name) this else copy(name = newName)

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    s"`$name`".asInstanceOf[T]
}
/**
  * Join of two inputs.
  *
  * Field references in the condition are rewritten to [[JoinFieldReference]]s during resolution.
  * Validation requires a boolean condition, disjoint field names between the two inputs, and a
  * condition built only from AND/OR, binary comparisons and boolean literals.
  *
  * @param left       the left input
  * @param right      the right input
  * @param joinType   the join type (inner / left / right / full outer)
  * @param condition  optional join condition; `None` is translated to literal `true`
  * @param correlated whether the right side is correlated to the left side (a correlation id is
  *                   then registered when the join is constructed)
  */
case class Join(
    left: LogicalNode,
    right: LogicalNode,
    joinType: JoinType,
    condition: Option[Expression],
    correlated: Boolean) extends BinaryNode {

  override def output: Seq[Attribute] = {
    left.output ++ right.output
  }

  override def resolveExpressions(tableEnv: TableEnvironment): LogicalNode = {
    val node = super.resolveExpressions(tableEnv).asInstanceOf[Join]
    // turn plain field references into references that know which join input they come from
    val partialFunction: PartialFunction[Expression, Expression] = {
      case field: ResolvedFieldReference => JoinFieldReference(
        field.name,
        field.resultType,
        left,
        right)
    }
    val resolvedCondition = node.condition.map(_.postOrderTransform(partialFunction))
    Join(node.left, node.right, node.joinType, resolvedCondition, correlated)
  }

  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
    left.construct(relBuilder)
    right.construct(relBuilder)
    val corSet = mutable.Set[CorrelationId]()
    if (correlated) {
      corSet += relBuilder.peek().getCluster.createCorrel()
    }
    relBuilder.join(
      convertJoinType(joinType),
      condition.map(_.toRexNode(relBuilder)).getOrElse(relBuilder.literal(true)),
      corSet.asJava)
  }

  /** Maps the Flink join type onto the corresponding Calcite join type. */
  private def convertJoinType(joinType: JoinType) = joinType match {
    case JoinType.INNER => JoinRelType.INNER
    case JoinType.LEFT_OUTER => JoinRelType.LEFT
    case JoinType.RIGHT_OUTER => JoinRelType.RIGHT
    case JoinType.FULL_OUTER => JoinRelType.FULL
  }

  /** Field names that occur in both inputs. */
  private def ambiguousName: Set[String] =
    left.output.map(_.name).toSet.intersect(right.output.map(_.name).toSet)

  override def validate(tableEnv: TableEnvironment): LogicalNode = {
    val resolvedJoin = super.validate(tableEnv).asInstanceOf[Join]
    // Report the offending condition and its actual type (not the join type).
    resolvedJoin.condition match {
      case Some(cond) if cond.resultType != DataTypes.BOOLEAN =>
        failValidation(s"Join operator requires a boolean expression as join condition, " +
          s"but $cond is of type ${cond.resultType}")
      case _ => // OK
    }
    val duplicateNames = ambiguousName
    if (duplicateNames.nonEmpty) {
      failValidation(s"join relations with ambiguous names: ${duplicateNames.mkString(", ")}")
    }
    resolvedJoin.condition.foreach(testJoinCondition)
    resolvedJoin
  }

  /**
    * Checks the shape of a join condition: it may only consist of AND/OR combinations of binary
    * comparisons (equi- and non-equi-comparisons alike) and boolean literals. For a non-inner
    * correlated join with a table function, the predicate must be literal `true`.
    */
  private def testJoinCondition(expression: Expression): Unit = {
    // Whether the predicate is literal true.
    val alwaysTrue = expression match {
      case x: Literal if x.value.equals(true) => true
      case _ => false
    }

    def validateConditions(exp: Expression): Unit = exp match {
      case x: And => x.children.foreach(validateConditions)
      case x: Or => x.children.foreach(validateConditions)
      // any binary comparison (including EqualTo) is a valid leaf
      case _: BinaryComparison =>
      // The boolean literal should be a valid condition type.
      case x: Literal if x.resultType == DataTypes.BOOLEAN =>
      case x => failValidation(
        s"Unsupported condition type: ${x.getClass.getSimpleName}. Condition: $x")
    }

    validateConditions(expression)

    // Due to a bug in Apache Calcite (see CALCITE-2004 and FLINK-7865) we cannot accept join
    // predicates except literal true for TableFunction left outer join.
    if (correlated && right.isInstanceOf[LogicalTableFunctionCall] && joinType != JoinType.INNER) {
      if (!alwaysTrue) failValidation("TableFunction left outer join predicate can only be " +
        "empty or literal true.")
    }
  }

  override def accept[T](logicalSqlVisitor: LogicalNodeVisitor[T]): T =
    logicalSqlVisitor.visit(this)
}
/**
  * Leaf node that scans a table registered in the catalog.
  *
  * @param tablePath the path of the table in the catalog
  * @param rowType   the table's row type; each field becomes a resolved output attribute
  */
case class CatalogNode(
    tablePath: Seq[String],
    rowType: RelDataType) extends LeafNode {

  val output: Seq[Attribute] =
    for (field <- rowType.getFieldList.asScala)
      yield ResolvedFieldReference(field.getName, FlinkTypeFactory.toInternalType(field.getType))

  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder =
    relBuilder.scan(tablePath.asJava)

  // the schema comes straight from the catalog, so there is nothing left to validate
  override def validate(tableEnv: TableEnvironment): LogicalNode = this

  override def accept[T](logicalSqlVisitor: LogicalNodeVisitor[T]): T =
    logicalSqlVisitor.visit(this)
}
/**
  * Writes the rows of `child` into a table sink. The schema of `child` is passed through.
  * Construction requires the builder to be a `FlinkRelBuilder`.
  *
  * @param child    the node whose rows are emitted
  * @param sink     the target table sink
  * @param sinkName name under which the sink is registered in the plan
  */
case class SinkNode(child: LogicalNode, sink: TableSink[_], sinkName: String) extends UnaryNode {

  override def output: Seq[Attribute] = child.output

  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
    val builder = relBuilder.asInstanceOf[FlinkRelBuilder]
    child.construct(builder)
    builder.sink(sink, sinkName)
  }

  override def accept[T](logicalSqlVisitor: LogicalNodeVisitor[T]): T =
    logicalSqlVisitor.visit(this)
}
/**
  * Wrapper for valid logical plans generated from SQL String.
  *
  * The wrapped Calcite plan is pushed onto the builder as is; its row type provides the output
  * attributes.
  *
  * @param relNode the already validated Calcite plan
  */
case class LogicalRelNode(
    relNode: RelNode) extends LeafNode {

  val output: Seq[Attribute] =
    for (field <- relNode.getRowType.getFieldList.asScala)
      yield ResolvedFieldReference(field.getName, FlinkTypeFactory.toInternalType(field.getType))

  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder =
    relBuilder.push(relNode)

  // the wrapped plan is considered valid already
  override def validate(tableEnv: TableEnvironment): LogicalNode = this

  override def accept[T](logicalSqlVisitor: LogicalNodeVisitor[T]): T =
    logicalSqlVisitor.visit(this)
}
/**
  * Group-window aggregation over `child`.
  *
  * The output schema is the grouping expressions, then the aggregate expressions, then the
  * window property expressions; grouping expressions that are not named get their `toString`
  * as name.
  *
  * @param groupingExpressions  expressions to group by (in addition to the window)
  * @param window               the group window definition
  * @param propertyExpressions  window properties; `construct` only accepts entries of the form
  *                             `Alias(WindowProperty, name, _)`
  * @param aggregateExpressions aggregate calls; `construct` only accepts entries of the form
  *                             `Alias(Aggregation, name, _)`
  * @param child                the input node
  */
case class WindowAggregate(
    groupingExpressions: Seq[Expression],
    window: LogicalWindow,
    propertyExpressions: Seq[NamedExpression],
    aggregateExpressions: Seq[NamedExpression],
    child: LogicalNode)
  extends UnaryNode {
  override def output: Seq[Attribute] = {
    (groupingExpressions ++ aggregateExpressions ++ propertyExpressions) map {
      case ne: NamedExpression => ne.toAttribute
      case e => Alias(e, e.toString).toAttribute
    }
  }
  // resolve references of this operator's parameters
  // A name equal to the window's alias resolves to a WindowReference (and must not also be an
  // input field); every other name is resolved against the input as usual.
  override def resolveReference(
      tableEnv: TableEnvironment,
      name: String)
    : Option[NamedExpression] = {
    def resolveAlias(alias: String) = {
      // check if reference can already be resolved by input fields
      val found = super.resolveReference(tableEnv, name)
      if (found.isDefined) {
        failValidation(s"Reference $name is ambiguous.")
      } else {
        // resolve type of window reference
        // (the type is taken from the window's time attribute, if that is a resolvable field)
        val resolvedType = window.timeAttribute match {
          case UnresolvedFieldReference(n) =>
            super.resolveReference(tableEnv, n) match {
              case Some(ResolvedFieldReference(_, tpe)) => Some(tpe)
              case _ => None
            }
          case _ => None
        }
        // let validation phase throw an error if type could not be resolved
        Some(WindowReference(name, resolvedType))
      }
    }
    window.aliasAttribute match {
      // resolve reference to this window's name
      case UnresolvedFieldReference(alias) if name == alias =>
        resolveAlias(alias)
      case _ =>
        // resolve references as usual
        super.resolveReference(tableEnv, name)
    }
  }
  // Requires a FlinkRelBuilder, whose window-aware aggregate variant is used.
  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
    val flinkRelBuilder = relBuilder.asInstanceOf[FlinkRelBuilder]
    child.construct(flinkRelBuilder)
    flinkRelBuilder.aggregate(
      window,
      relBuilder.groupKey(groupingExpressions.map(_.toRexNode(relBuilder)).asJava),
      propertyExpressions.map {
        case Alias(prop: WindowProperty, name, _) => prop.toNamedWindowProperty(name)
        case _ => throw new RuntimeException("This should never happen.")
      },
      aggregateExpressions.map {
        case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder)
        case _ => throw new RuntimeException("This should never happen.")
      }.asJava)
  }
  /**
    * Validates the aggregate and grouping expressions (same rules as for [[Aggregate]] except
    * that distinct aggregates get no dedicated check), the window itself, and that window
    * start/end are not requested for row-count tumbling or sliding windows.
    */
  override def validate(tableEnv: TableEnvironment): LogicalNode = {
    // NOTE(review): not referenced explicitly below; presumably consumed as an implicit
    // parameter (e.g. by getSqlAggFunction) — confirm before removing.
    implicit val relBuilder: RelBuilder = tableEnv.getRelBuilder
    val resolvedWindowAggregate = super.validate(tableEnv).asInstanceOf[WindowAggregate]
    val groupingExprs = resolvedWindowAggregate.groupingExpressions
    val aggregateExprs = resolvedWindowAggregate.aggregateExpressions
    aggregateExprs.foreach(validateAggregateExpression)
    groupingExprs.foreach(validateGroupingExpression)
    // Recursively checks an expression of the aggregate list: aggregations must not require
    // OVER nor contain nested aggregations, and plain attributes must be grouping expressions.
    def validateAggregateExpression(expr: Expression): Unit = expr match {
      // check aggregate function
      case aggExpr: Aggregation
        if aggExpr.getSqlAggFunction.requiresOver =>
        failValidation(s"OVER clause is necessary for window functions: [${aggExpr.getClass}].")
      // check no nested aggregation exists.
      case aggExpr: Aggregation =>
        aggExpr.children.foreach { child =>
          child.preOrderVisit {
            case agg: Aggregation =>
              failValidation(
                "It's not allowed to use an aggregate function as " +
                "input of another aggregate function")
            case _ => // ok
          }
        }
      case a: Attribute if !groupingExprs.exists(_.checkEquals(a)) =>
        failValidation(
          s"Expression '$a' is invalid because it is neither" +
          " present in group by nor an aggregate function")
      // the expression itself is a grouping expression
      case e if groupingExprs.exists(_.checkEquals(e)) => // ok
      // otherwise descend into the sub-expressions
      case e => e.children.foreach(validateAggregateExpression)
    }
    // Grouping keys must map to a type info whose isKeyType is true.
    def validateGroupingExpression(expr: Expression): Unit = {
      if (!TypeConverters.createExternalTypeInfoFromDataType(expr.resultType).isKeyType) {
        failValidation(
          s"Expression $expr cannot be used as a grouping expression " +
          "because it's not a valid key type which must be hashable and comparable")
      }
    }
    // validate window
    resolvedWindowAggregate.window.validate(tableEnv) match {
      case ValidationFailure(msg) =>
        failValidation(s"$window is invalid: $msg")
      case ValidationSuccess => // ok
    }
    // validate property
    // window properties cannot be selected for row-count based tumbling/sliding windows
    if (propertyExpressions.nonEmpty) {
      resolvedWindowAggregate.window match {
        case TumblingGroupWindow(_, _, size) if isRowCountLiteral(size) =>
          failValidation("Window start and Window end cannot be selected " +
            "for a row-count Tumbling window.")
        case SlidingGroupWindow(_, _, size, _) if isRowCountLiteral(size) =>
          failValidation("Window start and Window end cannot be selected " +
            "for a row-count Sliding window.")
        case _ => // ok
      }
    }
    resolvedWindowAggregate
  }
  override def accept[T](logicalSqlVisitor: LogicalNodeVisitor[T]): T =
    logicalSqlVisitor.visit(this)
}
/**
  * Validation-only node that marks `child` as a temporal table.
  *
  * It forwards the schema of `child` but can never be translated: `construct` always throws.
  *
  * @param timeAttribute the time attribute of the temporal table
  * @param primaryKey    the primary key of the temporal table
  * @param child         the underlying node
  */
case class TemporalTable(
    timeAttribute: Expression,
    primaryKey: Expression,
    child: LogicalNode)
  extends UnaryNode {

  override def output: Seq[Attribute] = child.output

  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder =
    throw new UnsupportedOperationException(
      "This should never be called. This node is supposed to be used only for validation")

  override def accept[T](logicalSqlVisitor: LogicalNodeVisitor[T]): T =
    logicalSqlVisitor.visit(this)
}
/**
  * LogicalNode for calling a user-defined table functions.
  *
  * Parameter references are resolved against the fields of `input`. Validation looks up an
  * `eval` method matching the parameter types and remembers it.
  *
  * @param functionName       function name
  * @param tableFunction      table function to be called (might be overloaded)
  * @param parameters         actual parameters
  * @param externalResultType declared result type of the table function; the output field
  *                           names, indexes and types are derived from it
  * @param fieldNames         output field names; if empty, generated names are used
  * @param input              input node of table function
  */
case class LogicalTableFunctionCall(
    functionName: String,
    tableFunction: TableFunction[_],
    parameters: Seq[Expression],
    externalResultType: DataType,
    fieldNames: Array[String],
    input: LogicalNode)
  extends LeafNode {

  private val (generatedNames, fieldIndexes, fieldTypes) = getFieldInfo(externalResultType)

  private var evalMethod: Method = _

  override def output: Seq[Attribute] = {
    // explicit names win; fall back to the names derived from the result type
    val names = if (fieldNames.isEmpty) generatedNames else fieldNames
    names.zip(fieldTypes).map { case (n, t) => ResolvedFieldReference(n, t) }
  }

  override def resolveReference(
      tableEnv: TableEnvironment,
      name: String): Option[NamedExpression] = {
    // parameters refer to fields of the input node
    val inputFields = input.output
    inputFields.filter(_.name.equals(name)) match {
      case Seq(only) =>
        Some(only.withName(name))
      case Seq() =>
        val from = inputFields.map(_.name).mkString(", ")
        failValidation(s"Cannot resolve [$name] given input [$from].")
      case _ =>
        failValidation(s"Reference $name is ambiguous.")
    }
  }

  override def validate(tableEnv: TableEnvironment): LogicalNode = {
    val node = super.validate(tableEnv).asInstanceOf[LogicalTableFunctionCall]
    val functionClass = tableFunction.getClass

    // the function must not be a Scala object and must be instantiable
    checkNotSingleton(functionClass)
    checkForInstantiation(functionClass)

    // look for an eval method whose signature matches the actual parameter types
    val signature = node.parameters.map(_.resultType)
    getEvalUserDefinedMethod(tableFunction, signature) match {
      case Some(method) =>
        node.evalMethod = method
      case None =>
        failValidation(
          s"Given parameters of function '$functionName' do not match any signature. \n" +
            s"Actual: ${signatureToString(signature)} \n" +
            s"Expected: ${signaturesToString(tableFunction, "eval")}")
    }
    node
  }

  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
    val typedFunction = new TypedFlinkTableFunction(tableFunction, externalResultType)
    val flinkTypeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
    val sqlFunction = new TableSqlFunction(
      tableFunction.functionIdentifier,
      tableFunction.toString,
      tableFunction,
      externalResultType,
      flinkTypeFactory,
      typedFunction)

    val cluster = relBuilder.peek().getCluster
    val invocation = relBuilder.call(sqlFunction, parameters.map(_.toRexNode(relBuilder)).asJava)
    val elementType = typedFunction.getElementType(null)
    val outputNames = if (fieldNames.isEmpty) generatedNames else fieldNames
    val rowType = UserDefinedFunctionUtils.buildRelDataType(
      relBuilder.getTypeFactory,
      externalResultType.toInternalType,
      outputNames,
      fieldIndexes)

    // the scan has no relational inputs; its rows come from the function invocation
    val scan = LogicalTableFunctionScan.create(
      cluster,
      new util.ArrayList[RelNode](),
      invocation,
      elementType,
      rowType,
      null)
    relBuilder.push(scan)
  }

  override def accept[T](logicalSqlVisitor: LogicalNodeVisitor[T]): T =
    logicalSqlVisitor.visit(this)
}