blob: 762f7bc4c82fd9edae761b58e94b3ca06778f411 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.codegen
import java.util.{ArrayList => JArrayList, Collection => JCollection}
import org.apache.calcite.rex.{RexLiteral, RexNode, RexProgram}
import org.apache.flink.api.common.functions.FlatMapFunction
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.streaming.api.functions.async.{AsyncFunction, ResultFuture}
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.{TableConfig, TableConfigOptions}
import org.apache.flink.table.api.functions.{AsyncTableFunction, TableFunction}
import org.apache.flink.table.api.types.{DataTypes, InternalType, RowType}
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.CodeGenUtils._
import org.apache.flink.table.dataformat.{BaseRow, GenericRow, JoinedRow}
import org.apache.flink.table.plan.schema.BaseRowSchema
import org.apache.flink.table.runtime.collector.TableFunctionCollector
import org.apache.flink.table.runtime.conversion.DataStructureConverters.RowConverter
import org.apache.flink.types.Row
import org.apache.flink.util.{Collector, Preconditions}
object TemporalJoinCodeGenerator {
/**
 * Generates the code of a synchronous lookup function that calls `eval` on the given
 * [[TableFunction]] for each input row.
 *
 * The `eval` arguments are produced by `prepareParameters`. When `tableReturnClass` is
 * [[Row]], a [[RowToBaseRowCollector]] is registered as a reusable object and placed between
 * the lookup function and the default collector; otherwise the lookup function is wired to
 * the default collector directly.
 *
 * @param config                 table configuration handed to the code generator
 * @param typeFactory            type factory, forwarded to `prepareParameters`
 * @param inputType              type of the incoming row
 * @param returnType             type handed to the function generator as the result type
 * @param tableReturnTypeInfo    type information of the rows emitted by the lookup function;
 *                               cast to [[RowTypeInfo]] when `tableReturnClass` is [[Row]]
 * @param tableReturnClass       class of the rows emitted by the lookup function
 * @param lookupKeyInOrder       lookup key indexes in the order they are passed to `eval`
 * @param lookupKeysFromInput    lookup key index -> input field index
 * @param lookupKeysFromConstant lookup key index -> constant value
 * @param lookupFunction         the table function to invoke
 * @param enableObjectReuse      forwarded to `prepareParameters` as its `fieldCopy` flag
 * @return the generated flat-map function
 */
def generateLookupFunction(
    config: TableConfig,
    typeFactory: FlinkTypeFactory,
    inputType: InternalType,
    returnType: InternalType,
    tableReturnTypeInfo: TypeInformation[_],
    tableReturnClass: Class[_],
    lookupKeyInOrder: Array[Int],
    lookupKeysFromInput: Map[Int, Int], // lookup key index -> input field index
    lookupKeysFromConstant: Map[Int, RexLiteral], // lookup key index -> constant value
    lookupFunction: TableFunction[_],
    enableObjectReuse: Boolean)
  : GeneratedFunction[FlatMapFunction[BaseRow, BaseRow], BaseRow] = {

  val ctx = CodeGeneratorContext(config)
  val outTerm = CodeGeneratorContext.DEFAULT_COLLECTOR_TERM

  val (prepareCode, parameters) = prepareParameters(
    ctx,
    config,
    typeFactory,
    inputType,
    lookupKeyInOrder,
    lookupKeysFromInput,
    lookupKeysFromConstant,
    fieldCopy = enableObjectReuse)

  val fnTerm = ctx.addReusableFunction(lookupFunction)

  // decide which collector the lookup function writes to
  val wiringCode =
    if (tableReturnClass != classOf[Row]) {
      s"$fnTerm.setCollector($outTerm);"
    } else {
      // external rows are routed through a converting collector first
      val rowCollector =
        new RowToBaseRowCollector(tableReturnTypeInfo.asInstanceOf[RowTypeInfo])
      val rowCollectorTerm = ctx.addReusableObject(rowCollector, "collector")
      s"""
         |$rowCollectorTerm.setCollector($outTerm);
         |$fnTerm.setCollector($rowCollectorTerm);
         |""".stripMargin
    }

  val functionBody =
    s"""
       |$prepareCode
       |$wiringCode
       |$fnTerm.eval($parameters);
       |""".stripMargin

  FunctionCodeGenerator.generateFunction(
    ctx,
    "LookupFunction",
    classOf[FlatMapFunction[BaseRow, BaseRow]],
    functionBody,
    returnType,
    inputType,
    config)
}
/**
 * Generates the code of an async lookup function that calls `eval` on the given
 * [[AsyncTableFunction]] for each input row.
 *
 * The first `eval` argument is the result future, followed by the lookup keys produced by
 * `prepareParameters`. When `tableReturnClass` is [[Row]], a [[RowToBaseRowResultFuture]] is
 * registered as a reusable object, pointed at the default collector term and passed to
 * `eval`; otherwise the default collector term itself is passed.
 *
 * @param config                 table configuration handed to the code generator
 * @param typeFactory            type factory, forwarded to `prepareParameters`
 * @param inputType              type of the incoming row
 * @param returnType             type handed to the function generator as the result type
 * @param tableReturnTypeInfo    type information of the rows completed by the lookup function;
 *                               cast to [[RowTypeInfo]] when `tableReturnClass` is [[Row]]
 * @param tableReturnClass       class of the rows completed by the lookup function
 * @param lookupKeyInOrder       lookup key indexes in the order they are passed to `eval`
 * @param lookupKeysFromInput    lookup key index -> input field index
 * @param lookupKeysFromConstant lookup key index -> constant value
 * @param asyncLookupFunction    the async table function to invoke
 * @return the generated async function
 */
def generateAsyncLookupFunction(
    config: TableConfig,
    typeFactory: FlinkTypeFactory,
    inputType: InternalType,
    returnType: InternalType,
    tableReturnTypeInfo: TypeInformation[_],
    tableReturnClass: Class[_],
    lookupKeyInOrder: Array[Int],
    lookupKeysFromInput: Map[Int, Int], // lookup key index -> input field index
    lookupKeysFromConstant: Map[Int, RexLiteral],
    asyncLookupFunction: AsyncTableFunction[_])
  : GeneratedFunction[AsyncFunction[BaseRow, BaseRow], BaseRow] = {

  val ctx = CodeGeneratorContext(config)
  val (prepareCode, parameters) = prepareParameters(
    ctx,
    config,
    typeFactory,
    inputType,
    lookupKeyInOrder,
    lookupKeysFromInput,
    lookupKeysFromConstant,
    fieldCopy = true) // always copy input field because of async buffer

  val lookupFunctionTerm = ctx.addReusableFunction(asyncLookupFunction)

  // Both the term handed to eval() and the code that prepares it depend on the same
  // decision, so they are computed together instead of mutating a null-initialised var.
  val (futureTerm, setFutureCode) =
    if (tableReturnClass == classOf[Row]) {
      val converterFuture =
        new RowToBaseRowResultFuture(tableReturnTypeInfo.asInstanceOf[RowTypeInfo])
      val converterTerm = ctx.addReusableObject(converterFuture, "future")
      (
        converterTerm,
        s"$converterTerm.setFuture(${CodeGeneratorContext.DEFAULT_COLLECTOR_TERM});")
    } else {
      (CodeGeneratorContext.DEFAULT_COLLECTOR_TERM, "")
    }

  val body =
    s"""
       |$prepareCode
       |$setFutureCode
       |$lookupFunctionTerm.eval($futureTerm, $parameters);
       |""".stripMargin

  FunctionCodeGenerator.generateFunction(
    ctx,
    "LookupFunction",
    classOf[AsyncFunction[BaseRow, BaseRow]],
    body,
    returnType,
    inputType,
    config)
}
/**
 * Generates the code that materialises the lookup keys as boxed local variables, together
 * with the comma-separated list of those variable names.
 *
 * Each key in `lookupKeyInOrder` is resolved against `lookupKeysFromInput` first and
 * `lookupKeysFromConstant` second. A key found in neither raises a [[CodeGenException]].
 *
 * @return a pair of (declaration code, argument list)
 */
private def prepareParameters(
    ctx: CodeGeneratorContext,
    config: TableConfig,
    typeFactory: FlinkTypeFactory,
    inputType: InternalType,
    lookupKeyInOrder: Array[Int],
    lookupKeysFromInput: Map[Int, Int], // lookup key index -> input field index
    lookupKeysFromConstant: Map[Int, RexLiteral],
    fieldCopy: Boolean): (String, String) = {

  // the number of lookup keys must equal the input-backed plus the constant-backed ones
  val suppliedKeyCount = lookupKeysFromInput.size + lookupKeysFromConstant.size
  Preconditions.checkArgument(lookupKeyInOrder.length == suppliedKeyCount)

  val keyExprs = lookupKeyInOrder.map { keyIndex =>
    lookupKeysFromInput.get(keyIndex) match {
      case Some(inputFieldIndex) =>
        generateInputAccess(
          ctx,
          inputType,
          CodeGeneratorContext.DEFAULT_INPUT1_TERM,
          inputFieldIndex,
          nullableInput = false,
          config.getNullCheck,
          fieldCopy)

      case None =>
        lookupKeysFromConstant.get(keyIndex) match {
          case Some(constant) =>
            val constantType = FlinkTypeFactory.toInternalType(constant.getType)
            val constantValue = constant.getValue3
            generateLiteral(
              ctx, constant.getType, constantType, constantValue, config.getNullCheck)

          case None =>
            throw new CodeGenException("This should never happen!")
        }
    }
  }

  // one boxed local per key; it stays null when the key expression is null
  val (declarations, argTerms) = keyExprs.map { expr =>
    val boxedType = boxedTypeTermForType(expr.resultType)
    val argTerm = newName("arg")
    val declaration =
      s"""
         |$boxedType $argTerm = null;
         |if (!${expr.nullTerm}) {
         | $argTerm = ${expr.resultTerm};
         |}
         |""".stripMargin
    (declaration, argTerm)
  }.unzip

  (declarations.mkString("\n"), argTerms.mkString(", "))
}
/**
 * Generates the async collector ([[ResultFuture]]) used by the async temporal join.
 *
 * The generated code converts the table-side row and completes the underlying collector
 * with a singleton collection holding it. When a join condition is present, that happens
 * only if the condition evaluates to true; otherwise the collector is completed with an
 * empty list.
 *
 * @param config        table configuration handed to the code generator
 * @param inputType     row type bound to the first input term
 * @param tableType     row type bound to the second input term
 * @param joinCondition optional condition evaluated against both inputs
 * @return the generated collector
 */
def generateAsyncCollector(
    config: TableConfig,
    inputType: RowType,
    tableType: RowType,
    joinCondition: Option[RexNode]): GeneratedCollector = {

  val leftTerm = CodeGeneratorContext.DEFAULT_INPUT1_TERM
  val rightTerm = CodeGeneratorContext.DEFAULT_INPUT2_TERM
  val ctx = CodeGeneratorContext(config)

  val convertedRow = new ExprCodeGenerator(ctx, false, config.getNullCheck)
    .bindInput(tableType, inputTerm = rightTerm)
    .generateConverterResultExpression(tableType, classOf[GenericRow])

  val completeWithRow =
    s"""
       |${convertedRow.code}
       |getCollector().complete(java.util.Collections.singleton(${convertedRow.resultTerm}));
       |""".stripMargin

  val collectorCode = joinCondition match {
    case None =>
      completeWithRow

    case Some(conditionNode) =>
      val conditionExpr = new ExprCodeGenerator(ctx, false, config.getNullCheck)
        .bindInput(inputType, leftTerm)
        .bindSecondInput(tableType, rightTerm)
        .generateExpression(conditionNode)
      s"""
         |${conditionExpr.code}
         |if (${conditionExpr.resultTerm}) {
         | $completeWithRow
         |} else {
         | getCollector().complete(java.util.Collections.emptyList());
         |}
         |""".stripMargin
  }

  CollectorCodeGenerator.generateTableAsyncCollector(
    ctx,
    "TableAsyncCollector",
    collectorCode,
    inputType,
    tableType,
    config)
}
/**
 * Generates the collector ([[Collector]]) used by the synchronous temporal join.
 *
 * Differs from CommonCorrelate.generateCollector, which has no real condition because of
 * FLINK-7865: here the outer join type has to be handled when a real condition filters the
 * result. For that reason the `super.collect(record)` call is part of the generated body and
 * is therefore only executed for rows that pass the condition.
 *
 * @param ctx              code generator context that receives the reusable members
 * @param config           table configuration handed to the code generator
 * @param inputType        row type bound to the first input term
 * @param udtfTypeInfo     row type of the collected record, bound to the second input term
 * @param resultType       type of the joined output record
 * @param condition        optional condition evaluated against both inputs
 * @param pojoFieldMapping optional field mapping applied when binding the collected record
 * @param retainHeader     whether the header of the input row is copied onto the joined row
 * @return the generated collector
 */
def generateCollector(
    ctx: CodeGeneratorContext,
    config: TableConfig,
    inputType: RowType,
    udtfTypeInfo: RowType,
    resultType: RowType,
    condition: Option[RexNode],
    pojoFieldMapping: Option[Array[Int]],
    retainHeader: Boolean = true): GeneratedCollector = {

  val leftTerm = CodeGeneratorContext.DEFAULT_INPUT1_TERM
  val rightTerm = CodeGeneratorContext.DEFAULT_INPUT2_TERM

  val convertedRow = new ExprCodeGenerator(ctx, false, config.getNullCheck)
    .bindInput(udtfTypeInfo, inputTerm = rightTerm, inputFieldMapping = pojoFieldMapping)
    .generateConverterResultExpression(udtfTypeInfo, classOf[GenericRow])

  val joinedTerm = newName("joinedRow")
  ctx.addOutputRecord(resultType, classOf[JoinedRow], joinedTerm)

  val headerCode =
    if (retainHeader) s"$joinedTerm.setHeader($leftTerm.getHeader());" else ""

  val emitJoinedRow =
    s"""
       |${convertedRow.code}
       |$joinedTerm.replace($leftTerm, ${convertedRow.resultTerm});
       |$headerCode
       |getCollector().collect($joinedTerm);
       |super.collect(record);
       |""".stripMargin

  val collectorCode = condition match {
    case None =>
      emitJoinedRow

    case Some(conditionNode) =>
      val conditionExpr = new ExprCodeGenerator(ctx, false, config.getNullCheck)
        .bindInput(inputType, leftTerm)
        .bindSecondInput(udtfTypeInfo, rightTerm, pojoFieldMapping)
        .generateExpression(conditionNode)
      s"""
         |${conditionExpr.code}
         |if (${conditionExpr.resultTerm}) {
         | $emitJoinedRow
         |}
         |""".stripMargin
  }

  generateTableFunctionCollectorForJoinTable(
    ctx,
    "JoinTableFuncCollector",
    collectorCode,
    inputType,
    udtfTypeInfo,
    config,
    inputTerm = leftTerm,
    collectedTerm = rightTerm)
}
/**
* Builds the Java source of a [[TableFunctionCollector]] subclass whose `collect` method
* casts the current input and the collected record to their boxed row types, runs the
* reusable input unboxing code and then executes `bodyCode`.
*
* The only difference against CollectorCodeGenerator.generateTableFunctionCollector is that
* this template does not emit a "super.collect" call itself: that call is bound to the
* collecting of the joined row inside `bodyCode`.
*
* @param ctx           code generator context providing the reusable member, field, init and
*                      input unboxing code
* @param name          prefix handed to `newName` to derive the generated class name
* @param bodyCode      code placed at the end of the generated `collect` method
* @param inputType     type whose boxed type term is used for the input variable
* @param collectedType type whose boxed type term is used for the collected record variable
* @param config        table configuration; supplies SQL_CODEGEN_LENGTH_MAX for splitting
* @param inputTerm     name of the input variable in the generated code
* @param collectedTerm name of the collected record variable in the generated code
* @return the generated collector, made of the class name and its source code
*/
private def generateTableFunctionCollectorForJoinTable(
ctx: CodeGeneratorContext,
name: String,
bodyCode: String,
inputType: RowType,
collectedType: RowType,
config: TableConfig,
inputTerm: String = CodeGeneratorContext.DEFAULT_INPUT1_TERM,
collectedTerm: String = CodeGeneratorContext.DEFAULT_INPUT2_TERM)
: GeneratedCollector = {
val className = newName(name)
val input1TypeClass = boxedTypeTermForType(inputType)
val input2TypeClass = boxedTypeTermForType(collectedType)
// Hand the reusable input unboxing snippets to generateSplitFunctionCalls together with the
// configured SQL_CODEGEN_LENGTH_MAX; presumably the code is split into helper methods only
// when it exceeds that length -- TODO confirm against generateSplitFunctionCalls.
val unboxingCodeSplit = generateSplitFunctionCalls(
ctx.reusableInputUnboxingExprs.values.map(_.code).toSeq,
config.getConf.getInteger(TableConfigOptions.SQL_CODEGEN_LENGTH_MAX),
"inputUnbox",
"private final void",
ctx.reuseFieldCode().length,
defineParams = s"$input1TypeClass $inputTerm, $input2TypeClass $collectedTerm",
callingParams = s"$inputTerm, $collectedTerm"
)
val funcCode = if (unboxingCodeSplit.isSplit) {
// Split variant: the reusable field code is emitted at class level and collect() invokes
// the generated helper methods, whose definitions follow the collect() method.
s"""
public class $className extends ${classOf[TableFunctionCollector[_]].getCanonicalName} {
${ctx.reuseMemberCode()}
${ctx.reuseFieldCode()}
public $className() throws Exception {
${ctx.reuseInitCode()}
}
@Override
public void collect(Object record) throws Exception {
$input1TypeClass $inputTerm = ($input1TypeClass) getInput();
$input2TypeClass $collectedTerm = ($input2TypeClass) record;
${unboxingCodeSplit.callings.mkString("\n")}
$bodyCode
}
${unboxingCodeSplit.definitions.zip(unboxingCodeSplit.bodies) map {
case (define, body) =>
s"""
|$define throws Exception {
| $body
|}
""".stripMargin
} mkString "\n"
}
@Override
public void close() {
}
}
""".stripMargin
} else {
// Non-split variant: the reusable field code and the input unboxing code are both emitted
// inline inside collect(), ahead of the body code.
s"""
public class $className extends ${classOf[TableFunctionCollector[_]].getCanonicalName} {
${ctx.reuseMemberCode()}
public $className() throws Exception {
${ctx.reuseInitCode()}
}
@Override
public void collect(Object record) throws Exception {
$input1TypeClass $inputTerm = ($input1TypeClass) getInput();
$input2TypeClass $collectedTerm = ($input2TypeClass) record;
${ctx.reuseFieldCode()}
${ctx.reuseInputUnboxingCode()}
$bodyCode
}
@Override
public void close() {
}
}
""".stripMargin
}
GeneratedCollector(className, funcCode)
}
/**
 * Generates the calc flat-map function for the temporal join, which is used to
 * project/filter the dimension table results.
 *
 * @param config            table configuration handed to the code generator
 * @param calcProgram       the calc program to generate code for; must be defined
 * @param tableSourceSchema schema whose internal type is used as the function input type
 * @return the generated flat-map function
 * @throws NoSuchElementException if `calcProgram` is empty
 */
def generateCalcMapFunction(
    config: TableConfig,
    calcProgram: Option[RexProgram],
    tableSourceSchema: BaseRowSchema)
  : GeneratedFunction[FlatMapFunction[BaseRow, BaseRow], BaseRow] = {

  // Same exception type as the former `calcProgram.get`, but with a message that says
  // what is missing instead of the opaque "None.get".
  val program = calcProgram.getOrElse(
    throw new NoSuchElementException(
      "Cannot generate the temporal join calc function: no calc program was provided."))

  // getCondition returns null when the program has no filter; Option(...) maps that to None
  val condition = Option(program.getCondition).map(ref => program.expandLocalRef(ref))

  CalcCodeGenerator.generateFunction(
    tableSourceSchema.internalType,
    "TableCalcMapFunction",
    FlinkTypeFactory.toInternalRowType(program.getOutputRowType),
    classOf[GenericRow],
    program,
    condition,
    config,
    classOf[FlatMapFunction[BaseRow, BaseRow]])
}
// ----------------------------------------------------------------------------------------
// Utility Classes
// ----------------------------------------------------------------------------------------
/**
 * A [[TableFunctionCollector]] for external [[Row]] records that converts every collected
 * row with a [[RowConverter]] and forwards the converted row to the wrapped collector.
 *
 * @param rowTypeInfo type information of the collected rows; its internal type is used to
 *                    build the converter
 */
class RowToBaseRowCollector(rowTypeInfo: RowTypeInfo)
  extends TableFunctionCollector[Row] with Serializable {

  private val converter =
    RowConverter(rowTypeInfo.toInternalType.asInstanceOf[RowType])

  /** Converts the given row and hands it to the wrapped collector. */
  override def collect(record: Row): Unit = {
    super.collect(record)
    val internalRow = converter.toInternalImpl(record)
    val downstream = getCollector.asInstanceOf[Collector[BaseRow]]
    downstream.collect(internalRow)
  }

  /** Resets this collector and then the wrapped table function collector. */
  override def reset(): Unit = {
    super.reset()
    val downstream = getCollector.asInstanceOf[TableFunctionCollector[_]]
    downstream.reset()
  }

  /** Closes the wrapped collector. */
  override def close(): Unit = {
    getCollector.close()
  }
}
/**
 * A [[ResultFuture]] for external [[Row]] records that converts every completed row with a
 * [[RowConverter]] and completes the delegate future set through `setFuture`.
 *
 * @param rowTypeInfo type information of the completed rows; its internal type is used to
 *                    build the converter
 */
class RowToBaseRowResultFuture(rowTypeInfo: RowTypeInfo)
  extends ResultFuture[Row] with Serializable {

  private val converter =
    RowConverter(rowTypeInfo.toInternalType.asInstanceOf[RowType])

  private var future: ResultFuture[BaseRow] = _

  /** Sets the delegate future that receives the converted results. */
  def setFuture(future: ResultFuture[BaseRow]): Unit = {
    this.future = future
  }

  /**
   * Completes the delegate with the converted rows; a null result is passed on as null.
   */
  override def complete(result: JCollection[Row]): Unit = result match {
    case null =>
      future.complete(null)

    case rows =>
      val converted = new JArrayList[BaseRow]
      val remaining = rows.iterator()
      while (remaining.hasNext) {
        val internalRow = converter.toInternalImpl(remaining.next())
        converted.add(internalRow)
      }
      future.complete(converted)
  }

  /** Propagates the failure to the delegate future unchanged. */
  override def completeExceptionally(error: Throwable): Unit = {
    future.completeExceptionally(error)
  }
}
}