blob: b9b5a3e7fd56e0ad68ef65c24a43fa40a7878a8a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan.nodes
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rex.{RexCall, RexNode}
import org.apache.calcite.sql.SemiJoinType
import org.apache.flink.api.common.functions.Function
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.table.api.{TableConfig, TableException}
import org.apache.flink.table.codegen.CodeGenUtils.primitiveDefaultValue
import org.apache.flink.table.codegen.GeneratedExpression.{ALWAYS_NULL, NO_CODE}
import org.apache.flink.table.codegen._
import org.apache.flink.table.functions.utils.TableSqlFunction
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.table.runtime.TableFunctionCollector
import org.apache.flink.types.Row
import scala.collection.JavaConverters._
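/**
 * Shared helpers for joining an input with a user-defined table function (correlate):
 * code generation for the function and its collector, plus operator description strings.
 */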
trait CommonCorrelate {
/**
 * Generates the flat map function that runs the user-defined table function.
 *
 * The generated body sets the table function collector on the result term of the generated
 * call and then runs the call's code. For a left outer join it additionally emits one row
 * built from the input fields followed by null-valued table function fields whenever the
 * table function collector reports that nothing was collected.
 *
 * @param config           table configuration handed to the code generator
 * @param inputSchema      schema of the input; its type information is the first generator input
 * @param udtfTypeInfo     type information of the table function result (second generator input)
 * @param returnSchema     schema providing type information and field names of the emitted rows
 * @param joinType         `SemiJoinType.INNER` or `SemiJoinType.LEFT`
 * @param rexCall          the table function call to generate code for
 * @param pojoFieldMapping optional field mapping forwarded to the code generator
 * @param ruleDescription  description forwarded to the generator when the function is generated
 * @param functionClass    class of the function to generate
 * @tparam T type of the generated function
 * @return the generated function producing [[Row]]s
 * @throws TableException if `joinType` is neither INNER nor LEFT
 */
private[flink] def generateFunction[T <: Function](
    config: TableConfig,
    inputSchema: RowSchema,
    udtfTypeInfo: TypeInformation[Any],
    returnSchema: RowSchema,
    joinType: SemiJoinType,
    rexCall: RexCall,
    pojoFieldMapping: Option[Array[Int]],
    ruleDescription: String,
    functionClass: Class[T]): GeneratedFunction[T, Row] = {

  val codeGen = new FunctionCodeGenerator(
    config,
    false,
    inputSchema.typeInfo,
    Some(udtfTypeInfo),
    None,
    pojoFieldMapping)

  val (leftAccessExprs, rightAccessExprs) = codeGen.generateCorrelateAccessExprs

  val udtfCollectorTerm = codeGen
    .addReusableConstructor(classOf[TableFunctionCollector[_]])
    .head

  val udtfCall = codeGen.generateExpression(rexCall)

  // wire the collector into the table function, then invoke it
  val invokeCode =
    s"""
       |${udtfCall.resultTerm}.setCollector($udtfCollectorTerm);
       |${udtfCall.code}
       |""".stripMargin

  // extra code appended after the invocation, depending on the join type
  val outerJoinCode = joinType match {
    case SemiJoinType.LEFT =>
      // left outer join: if the table function returned no row,
      // emit the input row with all table function fields set to null
      val rightNullExprs = rightAccessExprs.map { expr =>
        GeneratedExpression(
          primitiveDefaultValue(expr.resultType),
          ALWAYS_NULL,
          NO_CODE,
          expr.resultType)
      }
      val paddedRowExpr = codeGen.generateResultExpression(
        leftAccessExprs ++ rightNullExprs,
        returnSchema.typeInfo,
        returnSchema.fieldNames)
      s"""
         |boolean hasOutput = $udtfCollectorTerm.isCollected();
         |if (!hasOutput) {
         | ${paddedRowExpr.code}
         | ${codeGen.collectorTerm}.collect(${paddedRowExpr.resultTerm});
         |}
         |""".stripMargin

    case SemiJoinType.INNER =>
      // inner join needs nothing beyond the invocation itself
      ""

    case _ =>
      throw new TableException(s"Unsupported SemiJoinType: $joinType for correlate join.")
  }

  codeGen.generateFunction(
    ruleDescription,
    functionClass,
    invokeCode + outerJoinCode,
    returnSchema.typeInfo)
}
/**
 * Generates the table function collector.
 *
 * The collector code builds a result row from the input fields followed by the table
 * function fields and passes it to `getCollector().collect(...)`. When a condition is
 * given, the row is only built and collected if the generated condition evaluates to true.
 *
 * @param config           table configuration handed to the code generators
 * @param inputSchema      schema of the input; its type information is the first generator input
 * @param udtfTypeInfo     type information of the table function result
 * @param returnSchema     schema providing type information and field names of the emitted rows
 * @param condition        optional filter expression guarding the emission of a row
 * @param pojoFieldMapping optional field mapping forwarded to the code generators
 * @return the generated collector
 */
private[flink] def generateCollector(
    config: TableConfig,
    inputSchema: RowSchema,
    udtfTypeInfo: TypeInformation[Any],
    returnSchema: RowSchema,
    condition: Option[RexNode],
    pojoFieldMapping: Option[Array[Int]]): GeneratedCollector = {

  val collectorGen = new CollectorCodeGenerator(
    config,
    false,
    inputSchema.typeInfo,
    Some(udtfTypeInfo),
    None,
    pojoFieldMapping)

  val (leftAccessExprs, rightAccessExprs) = collectorGen.generateCorrelateAccessExprs

  val joinedRowExpr = collectorGen.generateResultExpression(
    leftAccessExprs ++ rightAccessExprs,
    returnSchema.typeInfo,
    returnSchema.fieldNames)

  // Always created, even without a condition, because it is handed to
  // generateTableFunctionCollector below.
  val conditionGen = new FunctionCodeGenerator(
    config,
    false,
    udtfTypeInfo,
    None,
    pojoFieldMapping)

  val collectBody = condition match {
    case None =>
      s"""
         |${joinedRowExpr.code}
         |getCollector().collect(${joinedRowExpr.resultTerm});
         |""".stripMargin

    case Some(predicate) =>
      // NOTE(review): presumably the table function row is exposed to the collector under the
      // second input term, so the filter's only input is redirected to it - confirm.
      conditionGen.input1Term = conditionGen.input2Term
      // Generating the filter expression might add init statements and member fields;
      // these have to be merged with the statements of the collector code generator,
      // which is why conditionGen is passed along when the collector is generated.
      val predicateExpr = conditionGen.generateExpression(predicate)
      s"""
         |${conditionGen.reuseInputUnboxingCode()}
         |${predicateExpr.code}
         |if (${predicateExpr.resultTerm}) {
         | ${joinedRowExpr.code}
         | getCollector().collect(${joinedRowExpr.resultTerm});
         |}
         |""".stripMargin
  }

  collectorGen.generateTableFunctionCollector(
    "TableFunctionCollector",
    collectBody,
    udtfTypeInfo,
    conditionGen)
}
/**
 * Renders the field names of the given row type as a comma-separated list.
 *
 * @param rowType row type whose field names are listed
 * @return the field names joined with ", "
 */
private[flink] def selectToString(rowType: RelDataType): String = {
  val fieldNames = rowType.getFieldNames.asScala
  fieldNames.mkString(", ")
}
/**
 * Builds the operator name of a correlate, of the form
 * `correlate: <correlate description>, select: <field names>`.
 *
 * @param inputType   row type of the input, forwarded to `correlateToString`
 * @param rexCall     the table function call, forwarded to `correlateToString`
 * @param sqlFunction the table function, forwarded to `correlateToString`
 * @param rowType     row type whose field names form the "select" part
 * @param expression  callback rendering a [[RexNode]] as a string, forwarded to
 *                    `correlateToString`
 * @return the operator name
 */
private[flink] def correlateOpName(
    inputType: RelDataType,
    rexCall: RexCall,
    sqlFunction: TableSqlFunction,
    rowType: RelDataType,
    expression: (RexNode, List[String], Option[List[RexNode]]) => String): String = {
  val correlateDescription = correlateToString(inputType, rexCall, sqlFunction, expression)
  val selectedFields = selectToString(rowType)
  s"correlate: $correlateDescription, select: $selectedFields"
}
/**
 * Describes the table function call as `table(<function>(<operands>))`.
 *
 * @param inputType   row type of the input; its field names are passed to `expression`
 * @param rexCall     the table function call whose operands are rendered
 * @param sqlFunction the table function; its `toString` is used as the function name
 * @param expression  callback rendering a [[RexNode]] as a string; it is invoked for every
 *                    operand with the input field names and `None`
 * @return the description of the table function call
 */
private[flink] def correlateToString(
    inputType: RelDataType,
    rexCall: RexCall,
    sqlFunction: TableSqlFunction,
    expression: (RexNode, List[String], Option[List[RexNode]]) => String): String = {
  val inputFieldNames: List[String] = inputType.getFieldNames.asScala.toList
  val functionName = sqlFunction.toString
  val renderedOperands = rexCall.getOperands.asScala
    .map(operand => expression(operand, inputFieldNames, None))
    .mkString(", ")
  s"table($functionName($renderedOperands))"
}
}