blob: 138834bd69b872c47550ccb3d96aab41366ff983 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan
import org.apache.flink.table.api.types.{RowType}
import org.apache.flink.table.api.{OverWindow, TableEnvironment}
import org.apache.flink.table.expressions._
import org.apache.flink.table.plan.logical.{LogicalNode, Project}
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
object ProjectionTranslator {
/**
* Extracts and deduplicates all aggregation and window property expressions (zero, one, or more)
* from the given expressions.
*
* @param exprs a list of expressions to extract
* @param tableEnv the TableEnvironment
* @return a Tuple2, the first field contains the extracted and deduplicated aggregations,
* and the second field contains the extracted and deduplicated window properties.
*/
def extractAggregationsAndProperties(
exprs: Seq[Expression],
tableEnv: TableEnvironment): (Map[Expression, String], Map[Expression, String]) = {
exprs.foldLeft((Map[Expression, String](), Map[Expression, String]())) {
(x, y) => identifyAggregationsAndProperties(y, tableEnv, x._1, x._2)
}
}
/** Identifies and deduplicates aggregation functions and window properties. */
private def identifyAggregationsAndProperties(
exp: Expression,
tableEnv: TableEnvironment,
aggNames: Map[Expression, String],
propNames: Map[Expression, String]) : (Map[Expression, String], Map[Expression, String]) = {
exp match {
case agg: Aggregation =>
if (aggNames contains agg) {
(aggNames, propNames)
} else {
(aggNames + (agg -> tableEnv.createUniqueAttributeName()), propNames)
}
case prop: WindowProperty =>
if (propNames contains prop) {
(aggNames, propNames)
} else {
(aggNames, propNames + (prop -> tableEnv.createUniqueAttributeName()))
}
case l: LeafExpression =>
(aggNames, propNames)
case u: UnaryExpression =>
identifyAggregationsAndProperties(u.child, tableEnv, aggNames, propNames)
case b: BinaryExpression =>
val l = identifyAggregationsAndProperties(b.left, tableEnv, aggNames, propNames)
identifyAggregationsAndProperties(b.right, tableEnv, l._1, l._2)
// Functions calls
case c @ Call(name, args) =>
args.foldLeft((aggNames, propNames)){
(x, y) => identifyAggregationsAndProperties(y, tableEnv, x._1, x._2)
}
case sfc @ ScalarFunctionCall(clazz, args) =>
args.foldLeft((aggNames, propNames)){
(x, y) => identifyAggregationsAndProperties(y, tableEnv, x._1, x._2)
}
// General expression
case e: Expression =>
e.productIterator.foldLeft((aggNames, propNames)){
(x, y) => y match {
case e: Expression => identifyAggregationsAndProperties(e, tableEnv, x._1, x._2)
case _ => (x._1, x._2)
}
}
}
}
/**
* Replaces expressions with deduplicated aggregations and properties.
*
* @param exprs a list of expressions to replace
* @param tableEnv the TableEnvironment
* @param aggNames the deduplicated aggregations
* @param propNames the deduplicated properties
* @return a list of replaced expressions
*/
def replaceAggregationsAndProperties(
exprs: Seq[Expression],
tableEnv: TableEnvironment,
aggNames: Map[Expression, String],
propNames: Map[Expression, String]): Seq[NamedExpression] = {
val projectedNames = new mutable.HashSet[String]
exprs.map((exp: Expression) => replaceAggregationsAndProperties(exp, tableEnv,
aggNames, propNames, projectedNames))
.map(UnresolvedAlias)
}
private def replaceAggregationsAndProperties(
exp: Expression,
tableEnv: TableEnvironment,
aggNames: Map[Expression, String],
propNames: Map[Expression, String],
projectedNames: mutable.HashSet[String]) : Expression = {
exp match {
case agg: Aggregation =>
val name = aggNames(agg)
if (projectedNames.add(name)) {
UnresolvedFieldReference(name)
} else {
Alias(UnresolvedFieldReference(name), tableEnv.createUniqueAttributeName())
}
case prop: WindowProperty =>
val name = propNames(prop)
if (projectedNames.add(name)) {
UnresolvedFieldReference(name)
} else {
Alias(UnresolvedFieldReference(name), tableEnv.createUniqueAttributeName())
}
case n @ Alias(agg: Aggregation, name, _) =>
val aName = aggNames(agg)
Alias(UnresolvedFieldReference(aName), name)
case n @ Alias(prop: WindowProperty, name, _) =>
val pName = propNames(prop)
Alias(UnresolvedFieldReference(pName), name)
case l: LeafExpression => l
case u: UnaryExpression =>
val c = replaceAggregationsAndProperties(u.child, tableEnv,
aggNames, propNames, projectedNames)
u.makeCopy(Array(c))
case b: BinaryExpression =>
val l = replaceAggregationsAndProperties(b.left, tableEnv,
aggNames, propNames, projectedNames)
val r = replaceAggregationsAndProperties(b.right, tableEnv,
aggNames, propNames, projectedNames)
b.makeCopy(Array(l, r))
// Functions calls
case c @ Call(name, args) =>
val newArgs = args.map((exp: Expression) =>
replaceAggregationsAndProperties(exp, tableEnv, aggNames, propNames, projectedNames))
c.makeCopy(Array(name, newArgs))
case sfc @ ScalarFunctionCall(clazz, args) =>
val newArgs: Seq[Expression] = args
.map((exp: Expression) =>
replaceAggregationsAndProperties(exp, tableEnv, aggNames, propNames, projectedNames))
sfc.makeCopy(Array(clazz, newArgs))
// row constructor
case c @ RowConstructor(args) =>
val newArgs = c.elements
.map((exp: Expression) =>
replaceAggregationsAndProperties(exp, tableEnv, aggNames, propNames, projectedNames))
c.makeCopy(Array(newArgs))
// array constructor
case c @ ArrayConstructor(args) =>
val newArgs = c.elements
.map((exp: Expression) =>
replaceAggregationsAndProperties(exp, tableEnv, aggNames, propNames, projectedNames))
c.makeCopy(Array(newArgs))
// map constructor
case c @ MapConstructor(args) =>
val newArgs = c.elements
.map((exp: Expression) =>
replaceAggregationsAndProperties(exp, tableEnv, aggNames, propNames, projectedNames))
c.makeCopy(Array(newArgs))
// General expression
case e: Expression =>
val newArgs = e.productIterator.map {
case arg: Expression =>
replaceAggregationsAndProperties(arg, tableEnv, aggNames, propNames, projectedNames)
}
e.makeCopy(newArgs.toArray)
}
}
/**
* Expands an UnresolvedFieldReference("*") to parent's full project list.
*/
def expandProjectList(
exprs: Seq[Expression],
parent: LogicalNode,
tableEnv: TableEnvironment)
: Seq[Expression] = {
val projectList = new ListBuffer[Expression]
exprs.foreach {
case n: UnresolvedFieldReference if n.name == "*" =>
projectList ++= parent.output.map(a => UnresolvedFieldReference(a.name))
case Flattening(unresolved) =>
// simulate a simple project to resolve fields using current parent
val project = Project(Seq(UnresolvedAlias(unresolved)), parent).validate(tableEnv)
val resolvedExpr = project
.output
.headOption
.getOrElse(throw new RuntimeException("Could not find resolved composite."))
resolvedExpr.validateInput()
val newProjects = resolvedExpr.resultType match {
case ct: RowType =>
(0 until ct.getArity).map { idx =>
projectList += GetCompositeField(unresolved, ct.getFieldNames()(idx))
}
case _ =>
projectList += unresolved
}
case e: Expression => projectList += e
}
projectList
}
def resolveOverWindows(
exprs: Seq[Expression],
overWindows: Array[OverWindow],
tEnv: TableEnvironment): Seq[Expression] = {
exprs.map(e => replaceOverCall(e, overWindows, tEnv))
}
/**
* Find and replace UnresolvedOverCall with OverCall
*
* @param expr the expression to check
* @return an expression with correct resolved OverCall
*/
private def replaceOverCall(
expr: Expression,
overWindows: Array[OverWindow],
tableEnv: TableEnvironment): Expression = {
expr match {
case u: UnresolvedOverCall =>
val overWindow = overWindows.find(_.alias.equals(u.alias))
if (overWindow.isDefined) {
OverCall(
u.agg,
overWindow.get.partitionBy,
overWindow.get.orderBy,
overWindow.get.preceding,
overWindow.get.following,
tableEnv)
} else {
u
}
case u: UnaryExpression =>
val c = replaceOverCall(u.child, overWindows, tableEnv)
u.makeCopy(Array(c))
case b: BinaryExpression =>
val l = replaceOverCall(b.left, overWindows, tableEnv)
val r = replaceOverCall(b.right, overWindows, tableEnv)
b.makeCopy(Array(l, r))
// Functions calls
case c @ Call(name, args: Seq[Expression]) =>
val newArgs =
args.map(
(exp: Expression) =>
replaceOverCall(exp, overWindows, tableEnv))
c.makeCopy(Array(name, newArgs))
// Scala functions
case sfc @ ScalarFunctionCall(clazz, args: Seq[Expression]) =>
val newArgs: Seq[Expression] =
args.map(
(exp: Expression) =>
replaceOverCall(exp, overWindows, tableEnv))
sfc.makeCopy(Array(clazz, newArgs))
// Array constructor
case c @ ArrayConstructor(args) =>
val newArgs =
c.elements
.map((exp: Expression) => replaceOverCall(exp, overWindows, tableEnv))
c.makeCopy(Array(newArgs))
// Other expressions
case e: Expression => e
}
}
/**
* Extract all field references from the given expressions.
*
* @param exprs a list of expressions to extract
* @return a list of field references extracted from the given expressions
*/
def extractFieldReferences(exprs: Seq[Expression]): Seq[NamedExpression] = {
exprs.foldLeft(Set[NamedExpression]()) {
(fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences)
}.toSeq
}
private def identifyFieldReferences(
expr: Expression,
fieldReferences: Set[NamedExpression]): Set[NamedExpression] = expr match {
case f: UnresolvedFieldReference =>
fieldReferences + UnresolvedAlias(f)
case b: BinaryExpression =>
val l = identifyFieldReferences(b.left, fieldReferences)
identifyFieldReferences(b.right, l)
// Functions calls
case Call(_, args: Seq[Expression]) =>
args.foldLeft(fieldReferences) {
(fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences)
}
case ScalarFunctionCall(_, args: Seq[Expression]) =>
args.foldLeft(fieldReferences) {
(fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences)
}
case AggFunctionCall(_, _, _, args) =>
args.foldLeft(fieldReferences) {
(fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences)
}
// array constructor
case ArrayConstructor(args) =>
args.foldLeft(fieldReferences) {
(fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences)
}
// ignore fields from window property
case _: WindowProperty =>
fieldReferences
// keep this case after all unwanted unary expressions
case u: UnaryExpression =>
identifyFieldReferences(u.child, fieldReferences)
// General expression
case e: Expression =>
e.productIterator.foldLeft(fieldReferences) {
(fieldReferences, expr) => expr match {
case e: Expression => identifyFieldReferences(e, fieldReferences)
case _ => fieldReferences
}
}
}
/**
* Find and replace UDAGG function Call to AggFunctionCall
*
* @param field the expression to check
* @param tableEnv the TableEnvironment
* @return an expression with correct AggFunctionCall type for UDAGG functions
*/
def replaceAggFunctionCall(field: Expression, tableEnv: TableEnvironment): Expression = {
field match {
case l: LeafExpression => l
case u: UnaryExpression =>
val c = replaceAggFunctionCall(u.child, tableEnv)
u.makeCopy(Array(c))
case b: BinaryExpression =>
val l = replaceAggFunctionCall(b.left, tableEnv)
val r = replaceAggFunctionCall(b.right, tableEnv)
b.makeCopy(Array(l, r))
// Functions calls
case c @ Call(name, args) =>
val function = tableEnv.getFunctionCatalog.lookupFunction(name, args)
function match {
case a: AggFunctionCall => a
case a: Aggregation => a
case p: AbstractWindowProperty => p
case _ =>
val newArgs =
args.map(
(exp: Expression) =>
replaceAggFunctionCall(exp, tableEnv))
c.makeCopy(Array(name, newArgs))
}
// Scala functions
case sfc @ ScalarFunctionCall(clazz, args) =>
val newArgs: Seq[Expression] =
args.map(
(exp: Expression) =>
replaceAggFunctionCall(exp, tableEnv))
sfc.makeCopy(Array(clazz, newArgs))
// Array constructor
case c @ ArrayConstructor(args) =>
val newArgs =
c.elements
.map((exp: Expression) => replaceAggFunctionCall(exp, tableEnv))
c.makeCopy(Array(newArgs))
// Other expressions
case e: Expression => e
}
}
}