blob: 7319bc84183673ecbeb8c8a4727471a1e08e201e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.codegen
import java.lang.{Long => JLong}
import org.apache.calcite.rel.RelCollation
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.rex._
import org.apache.calcite.sql.fun.SqlStdOperatorTable._
import org.apache.calcite.sql.SqlAggFunction
import org.apache.calcite.tools.RelBuilder
import org.apache.flink.api.common.functions.Function
import org.apache.flink.cep.pattern.conditions.{IterativeCondition, RichIterativeCondition}
import org.apache.flink.cep._
import org.apache.flink.configuration.Configuration
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.{TableConfig, TableConfigOptions, TableException}
import org.apache.flink.table.api.types._
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.agg.AggsHandlerCodeGenerator
import org.apache.flink.table.codegen.CodeGenUtils._
import org.apache.flink.table.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
import org.apache.flink.table.codegen.Indenter.toISC
import org.apache.flink.table.codegen.MatchCodeGenerator._
import org.apache.flink.table.plan.schema.BaseRowSchema
import org.apache.flink.table.plan.util.AggregateUtil
import org.apache.flink.table.dataformat.{BaseRow, GenericRow}
import org.apache.flink.table.functions.sql.ProctimeSqlFunction
import org.apache.flink.table.plan.util.MatchUtil.AggregationPatternVariableFinder
import org.apache.flink.table.runtime.conversion.DataStructureConverters.genToInternal
import org.apache.flink.table.runtime.functions.{AggsHandleFunction, ExecutionContextImpl}
import org.apache.flink.table.typeutils.TypeUtils
import org.apache.flink.table.utils.EncodingUtils
import org.apache.flink.util.Collector
import org.apache.flink.util.MathUtils.checkedDownCast
import _root_.scala.collection.JavaConversions._
import _root_.scala.collection.JavaConverters._
import _root_.scala.collection.mutable
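/**
* A code generator for generating CEP related functions.
*
* @param ctx the context of the code generator
* @param relBuilder builder whose type factory is used when generating aggregate functions
* @param nullableInput input(s) can be null.
* @param nullCheck whether to do null check
* @param patternNames sorted sequence of pattern variables
* @param currentPattern when generating a condition, the name of the pattern to which the
* condition will be applied
* @param collectorTerm name of the collector term used by the generated functions to emit rows
*/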
class MatchCodeGenerator(
ctx: CodeGeneratorContext,
relBuilder: RelBuilder,
nullableInput: Boolean,
nullCheck: Boolean,
patternNames: Seq[String],
currentPattern: Option[String] = None,
collectorTerm: String = CodeGeneratorContext.DEFAULT_COLLECTOR_TERM)
extends ExprCodeGenerator(ctx, nullableInput, nullCheck) {
/**
 * Generated Java code that materializes the events of one pattern variable.
 *
 * @param resultTerm name of the generated `java.util.List` member holding the events
 * @param code       Java code that initializes (and possibly populates) the list
 * @param filled     false when `code` only creates an empty list whose rows are appended later,
 *                   one by one, while iterating over the match
 */
private case class GeneratedPatternList(
    resultTerm: String,
    code: String,
    filled: Boolean = true)

/** Like [[GeneratedPatternList]], but the list holds pattern names (classifiers). */
private case class GeneratedClassifierList(
    resultTerm: String,
    code: String,
    filled: Boolean = true)

/**
 * Event lists keyed by (running, pattern variable name). Each list gets a unique member name and
 * is treated as an input that the field-access code reads from.
 */
private val reusablePatternLists: mutable.HashMap[(Boolean, String), GeneratedPatternList] =
  mutable.HashMap.empty

/** Classifier lists keyed by the RUNNING flag. */
private val reusableClassiferList: mutable.HashMap[Boolean, GeneratedClassifierList] =
  mutable.HashMap.empty

/** Unboxing expressions registered while RUNNING semantics are active. */
private val reusableInputUnboxingExprs: mutable.Map[(String, Int), GeneratedExpression] =
  mutable.HashMap.empty

/** Per-record statements registered while RUNNING semantics are active (insertion ordered). */
private val reusablePerRecordStatements: mutable.LinkedHashSet[String] =
  mutable.LinkedHashSet.empty

/**
 * Already generated aggregation accesses, keyed by [[RexNode#toString]] plus the RUNNING flag.
 * These expressions need to be accessible from splits, if any exist.
 */
private val reusableAggregationExpr = mutable.HashMap.empty[String, GeneratedExpression]

// Indexing context of pattern variable references: the referenced row is `offset` rows away
// from the beginning of the event list when `first` is true, and from its end otherwise.
private var offset = 0
private var first = false

// Whether ONE ROW PER MATCH (true) or ALL ROWS PER MATCH (false) code is being generated.
private var oneRowPerMatch = false

// Whether expressions are currently generated with RUNNING semantics.
private var running = false

/**
 * Flag telling whether expressions are generated inside an aggregate's input transformation,
 * which determines how the input row is accessed.
 */
private var isWithinAggExprState = false

/** Aggregate builders keyed by (running, pattern variable name). */
private val aggregatesPerVariable = mutable.HashMap.empty[(Boolean, String), AggBuilder]

/** Name of the parameter of the function that turns an input row into an aggregate input row. */
private val inputAggRowTerm = "inAgg"

/** Name of the member holding the row the partition keys are read from. */
private val keyRowTerm = "keyRow"

/** Name of the generated `String[]` member holding all pattern names. */
private val patternNamesTerm = newName("patternNames")

private lazy val eventTypeTerm = boxedTypeTermForType(input1Type)
/**
* Sets the new reference variable indexing context. This should be used when resolving logical
* offsets, i.e. FIRST/LAST (PREV/NEXT use offset 0 counted from the end).
*
* @param first true if indexing from the beginning, false otherwise
* @param offset offset from either beginning or the end
*/
private def updateOffsets(first: Boolean, offset: Int): Unit = {
  this.offset = offset
  this.first = first
}

/** Restores the default indexing context: the last row (offset 0 counted from the end). */
private def resetOffsets(): Unit = updateOffsets(first = false, offset = 0)

/** Switches RUNNING semantics on or off for the expressions generated next. */
private def setRunning(running: Boolean): Unit = this.running = running

/** Switches RUNNING semantics off again. */
private def resetRunning(): Unit = setRunning(running = false)

/** Records whether ONE ROW PER MATCH (true) or ALL ROWS PER MATCH (false) is generated. */
private def setOneRowPerMatch(oneRowPerMatch: Boolean): Unit =
  this.oneRowPerMatch = oneRowPerMatch
/** Code of all generated pattern-event lists, joined by newlines. */
private def reusePatternLists(): String =
  reusablePatternLists.valuesIterator.map(_.code).mkString("\n")

/** Code of all generated classifier lists, joined by newlines. */
private def reuseClassifierLists(): String =
  reusableClassiferList.valuesIterator.map(_.code).mkString("\n")

/** Unboxing code registered under RUNNING semantics, joined by newlines. */
private def reuseInputUnboxingCode(): String =
  reusableInputUnboxingExprs.valuesIterator.map(_.code).mkString("\n")

/** Per-record statements registered under RUNNING semantics, in insertion order. */
private def reusePerRecordCode(): String = reusablePerRecordStatements.mkString("\n")
/**
 * Registers a `String[]` member named `patternNamesTerm` that contains every pattern name, in
 * `patternNames` order, as escaped Java string literals.
 */
private def addReusablePatternNames(): Unit = {
  val quotedNames = patternNames
    .map(patternName => "\"" + EncodingUtils.escapeJava(patternName) + "\"")
    .mkString(", ")
  ctx.addReusableMember(s"private String[] $patternNamesTerm = new String[] { $quotedNames };")
}
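/**
* Generates a [[org.apache.flink.api.common.functions.Function]] that can be passed to Java
* compiler.
*
* @param name Class name of the Function. Need not be unique (a unique name is derived from it)
* but has to be a valid Java class identifier.
* @param config table configuration; provides the maximum generated code length used to decide
* whether the input-unboxing code is split into separate methods
* @param clazz Flink Function to be generated; one of RichIterativeCondition,
* RichPatternSelectFunction, RichPatternFlatSelectFunction, RichPatternTimeoutFunction or
* RichPatternFlatTimeoutFunction.
* @param bodyCode code contents of the SAM (Single Abstract Method). Inputs, collector, or
* output record can be accessed via the given term methods.
* @tparam F Flink Function to be generated.
* @tparam T Return type of the Flink Function.
* @return the generated class wrapping the function code, matching `clazz`
* @throws CodeGenException if `clazz` is not one of the supported function classes
*/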
def generateMatchFunction[F <: Function, T <: Any](
    name: String,
    config: TableConfig,
    clazz: Class[F],
    bodyCode: String)
  : GeneratedClass[_] = {
  val funcName = newName(name)
  val collectorTypeTerm = classOf[Collector[Any]].getCanonicalName
  val configurationTypeTerm = classOf[Configuration].getCanonicalName
  val inputTypeTerm = boxedTypeTermForType(input1Type)
  // Map from pattern name to the events matched by it. The same type is used in every
  // pattern-based signature (select, flat select and both timeout variants) so that the SAM
  // signature agrees with the parameters of the split unboxing methods it calls.
  val patternMapTypeTerm = s"java.util.Map<String, java.util.List<$inputTypeTerm>>"

  // Per supported function class: (base class, SAM signature, statements preparing the input,
  // parameters of the split input-unboxing methods, arguments passed to those methods).
  val (functionClass, signature, inputStatements, splitDefineParams, splitCallingParams) =
    if (clazz == classOf[RichIterativeCondition[_]]) {
      val contextType = classOf[IterativeCondition.Context[_]].getCanonicalName
      (classOf[RichIterativeCondition[_]],
        s"boolean filter(Object _in1, $contextType $contextTerm)",
        List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;"),
        s"$inputTypeTerm $input1Term, $contextType $contextTerm",
        s"$input1Term, $contextTerm")
    } else if (clazz == classOf[RichPatternSelectFunction[_, _]]) {
      (classOf[RichPatternSelectFunction[_, _]],
        s"Object select($patternMapTypeTerm $input1Term)",
        Nil,
        s"$patternMapTypeTerm $input1Term",
        input1Term)
    } else if (clazz == classOf[RichPatternFlatSelectFunction[_, _]]) {
      (classOf[RichPatternFlatSelectFunction[_, _]],
        s"void flatSelect($patternMapTypeTerm $input1Term, $collectorTypeTerm $collectorTerm)",
        Nil,
        s"$patternMapTypeTerm $input1Term, $collectorTypeTerm $collectorTerm",
        s"$input1Term, $collectorTerm")
    } else if (clazz == classOf[RichPatternTimeoutFunction[_, _]]) {
      (classOf[RichPatternTimeoutFunction[_, _]],
        s"Object timeout($patternMapTypeTerm $input1Term, long timeoutTimestamp)",
        Nil,
        s"$patternMapTypeTerm $input1Term, long timeoutTimestamp",
        s"$input1Term, timeoutTimestamp")
    } else if (clazz == classOf[RichPatternFlatTimeoutFunction[_, _]]) {
      (classOf[RichPatternFlatTimeoutFunction[_, _]],
        s"void timeout($patternMapTypeTerm $input1Term, long timeoutTimestamp, " +
          s"$collectorTypeTerm $collectorTerm)",
        Nil,
        s"$patternMapTypeTerm $input1Term, $collectorTypeTerm $collectorTerm",
        s"$input1Term, $collectorTerm")
    } else {
      throw new CodeGenException("Unsupported Function.")
    }

  // `generateSplitFunctionCalls` decides, based on the configured maximum code length, whether
  // the input-unboxing code is moved into separate methods that receive the function's inputs.
  val unboxingCodeSplit = generateSplitFunctionCalls(
    ctx.reusableInputUnboxingExprs.values.map(_.code).toSeq,
    config.getConf.getInteger(TableConfigOptions.SQL_CODEGEN_LENGTH_MAX),
    "inputUnbox",
    "private final void",
    ctx.reuseFieldCode().length,
    defineParams = splitDefineParams,
    callingParams = splitCallingParams
  )

  val funcCode = if (unboxingCodeSplit.isSplit) {
    // Each split method emits the per-record code before its part of the unboxing code.
    val splitMethods = unboxingCodeSplit.definitions.zip(unboxingCodeSplit.bodies).map {
      case (define, body) =>
        s"""
           |$define throws Exception {
           |  ${ctx.reusePerRecordCode()}
           |  $body
           |}
           |""".stripMargin
    }.mkString("\n")

    // In the split variant the field declarations are class members (presumably so that the
    // split methods can use them); otherwise they are declared inside the SAM.
    j"""
      public class $funcName extends ${functionClass.getCanonicalName} {
        ${ctx.reuseMemberCode()}
        ${ctx.reuseFieldCode()}

        public $funcName(Object[] references) throws Exception {
          ${ctx.reuseInitCode()}
        }

        @Override
        public void open($configurationTypeTerm parameters) throws Exception {
          ${ctx.reuseOpenCode()}
        }

        @Override
        public $signature throws Exception {
          ${inputStatements.mkString("\n")}
          ${reusePatternLists()}
          ${reuseClassifierLists()}
          ${unboxingCodeSplit.callings.mkString("\n")}
          ${ctx.reusePerRecordCode()}
          $bodyCode
        }

        $splitMethods

        @Override
        public void close() throws Exception {
          ${ctx.reuseCloseCode()}
        }
      }
    """.stripMargin
  } else {
    j"""
      public class $funcName extends ${functionClass.getCanonicalName} {
        ${ctx.reuseMemberCode()}

        public $funcName(Object[] references) throws Exception {
          ${ctx.reuseInitCode()}
        }

        @Override
        public void open($configurationTypeTerm parameters) throws Exception {
          ${ctx.reuseOpenCode()}
        }

        @Override
        public $signature throws Exception {
          ${inputStatements.mkString("\n")}
          ${reusePatternLists()}
          ${reuseClassifierLists()}
          ${ctx.reusePerRecordCode()}
          ${ctx.reuseFieldCode()}
          ${ctx.reuseInputUnboxingCode()}
          $bodyCode
        }

        @Override
        public void close() throws Exception {
          ${ctx.reuseCloseCode()}
        }
      }
    """.stripMargin
  }

  if (clazz == classOf[RichIterativeCondition[_]]) {
    GeneratedIterativeCondition(funcName, funcCode, ctx.references.toArray)
  } else if (clazz == classOf[RichPatternSelectFunction[_, _]]) {
    GeneratedPatternSelectFunction(funcName, funcCode, ctx.references.toArray)
  } else if (clazz == classOf[RichPatternFlatSelectFunction[_, _]]) {
    GeneratedPatternFlatSelectFunction(funcName, funcCode, ctx.references.toArray)
  } else if (clazz == classOf[RichPatternTimeoutFunction[_, _]]) {
    GeneratedPatternTimeoutFunction(funcName, funcCode, ctx.references.toArray)
  } else if (clazz == classOf[RichPatternFlatTimeoutFunction[_, _]]) {
    GeneratedPatternFlatTimeoutFunction(funcName, funcCode, ctx.references.toArray)
  } else {
    throw new CodeGenException("Unsupported Function.")
  }
}
/**
 * Generates the expression producing the single output row of "ONE ROW PER MATCH".
 *
 * The row consists of the partition columns followed by the measures, in the order in which the
 * measures appear in `returnSchema`.
 *
 * @param partitionKeys partition columns; each is expected to be a [[RexInputRef]]
 * @param measures      measure expressions keyed by their output field name
 * @param returnSchema  schema of the produced row
 * @return the expression building the output row
 */
def generateOneRowPerMatchExpression(
    partitionKeys: java.util.List[RexNode],
    measures: java.util.Map[String, RexNode],
    returnSchema: BaseRowSchema): GeneratedExpression = {
  setOneRowPerMatch(oneRowPerMatch = true)

  val partitionKeyExprs = partitionKeys.asScala.map {
    case keyRef: RexInputRef => generatePartitionKeyAccess(keyRef)
  }
  val measureExprs = returnSchema.fieldNames
    .filter(fieldName => measures.containsKey(fieldName))
    .map(fieldName => generateExpression(measures.get(fieldName)))
  val outputExprs = partitionKeyExprs ++ measureExprs

  val rowGenerator = new ExprCodeGenerator(ctx, nullableInput, nullCheck)
    .bindInput(input1Type, inputTerm = input1Term)
  val outputType = new RowType(returnSchema.fieldTypeInfos, returnSchema.fieldNames.toArray)
  val rowExpression =
    rowGenerator.generateResultExpression(outputExprs, outputType, classOf[GenericRow])

  // The aggregate helper functions can only be emitted once every aggregate call is known.
  for (aggBuilder <- aggregatesPerVariable.values) {
    aggBuilder.generateAggFunction()
  }
  rowExpression
}
/**
 * Generates the code emitting the output rows of "ALL ROWS PER MATCH".
 *
 * The generated code iterates over the pattern variables in `patternNames` order and over the
 * events mapped to each of them in the input map. For every event it appends to the lists left
 * unfilled for RUNNING semantics, evaluates the per-row code and collects one output row via
 * `collectorTerm`.
 *
 * @param partitionKeys partition columns; each is expected to be a [[RexInputRef]]
 * @param orderKeys     ordering columns of the match
 * @param measures      measure expressions keyed by their output field name
 * @param returnSchema  schema of the produced rows
 * @return an expression whose `code` emits all rows; its result term is empty and its type is
 *         null, so presumably only `code` is meant to be used by callers
 */
def generateAllRowsPerMatchExpression(
partitionKeys: java.util.List[RexNode],
orderKeys: RelCollation,
measures: java.util.Map[String, RexNode],
returnSchema: BaseRowSchema): GeneratedExpression = {
val patternNameTerm = newName("patternName")
val eventNameTerm = newName("event")
val eventNameListTerm = newName("eventList")
val listTypeTerm = classOf[java.util.List[_]].getCanonicalName
setOneRowPerMatch(oneRowPerMatch = false)
// For "ALL ROWS PER MATCH", the output columns are generated in this order:
// 1) the partition columns;
// 2) the ordering columns;
// 3) the remaining input columns (those not already emitted in 1 or 2);
// 4) the columns defined in the measures clause (in `returnSchema` order).
// NOTE(review): the row is built positionally against `returnSchema.fieldTypeInfos`; confirm
// that the planner's output row type lists the remaining input columns before the measures,
// as done here.
val fieldsAccessed = mutable.Set[Int]()
val resultExprs =
partitionKeys.asScala.map { case inputRef: RexInputRef =>
fieldsAccessed += inputRef.getIndex
generateFieldAccess(ctx, input1Type, eventNameTerm, inputRef.getIndex, nullCheck)
} ++ orderKeys.getFieldCollations.asScala.map { fieldCollation =>
fieldsAccessed += fieldCollation.getFieldIndex
generateFieldAccess(ctx, input1Type, eventNameTerm, fieldCollation.getFieldIndex, nullCheck)
} ++ (0 until TypeUtils.getArity(input1Type)).filterNot(fieldsAccessed.contains).map { idx =>
generateFieldAccess(ctx, input1Type, eventNameTerm, idx, nullCheck)
} ++ returnSchema.fieldNames.filter(measures.containsKey(_)).map { fieldName =>
generateExpression(measures.get(fieldName))
}
val resultCodeGenerator = new ExprCodeGenerator(ctx, nullableInput, nullCheck)
.bindInput(input1Type, inputTerm = input1Term)
val resultExpression = resultCodeGenerator.generateResultExpression(
resultExprs,
new RowType(
returnSchema.fieldTypeInfos,
returnSchema.fieldNames.toArray),
classOf[GenericRow])
// The fill helpers below read the list caches populated while generating `resultExprs`, so
// they must run after the expressions above have been generated.
val resultCode = {
// The generated loop iterates over the pattern-names array member.
addReusablePatternNames()
// Appends the current event to every pattern list left unfilled (RUNNING semantics): the "*"
// list receives every event, a variable's list only the events mapped to that variable.
def fillPatternLists(): String = {
reusablePatternLists.filterNot(_._2.filled).map { patternList =>
val patternName = patternList._1._2
val listName = patternList._2.resultTerm
if (patternName == ALL_PATTERN_VARIABLE) {
j"""
|$listName.add($eventNameTerm);
|""".stripMargin
} else {
val escapedPatternName = EncodingUtils.escapeJava(patternName)
j"""
|if ($patternNameTerm.equals("$escapedPatternName")) {
| $listName.add($eventNameTerm);
|}
|""".stripMargin
}
}.mkString("\n")
}
// Appends the current pattern name to every classifier list left unfilled.
def fillClassifierLists(): String = {
reusableClassiferList.filterNot(_._2.filled).map { classiferList =>
val listName = classiferList._2.resultTerm
j"""
|$listName.add($patternNameTerm);
|""".stripMargin
}.mkString("\n")
}
j"""
|for (String $patternNameTerm : $patternNamesTerm) {
| $listTypeTerm $eventNameListTerm = ($listTypeTerm) $input1Term.get($patternNameTerm);
| if ($eventNameListTerm != null) {
| for ($eventTypeTerm $eventNameTerm : $eventNameListTerm) {
| ${fillPatternLists()}
| ${fillClassifierLists()}
| ${reuseInputUnboxingCode()}
| ${reusePerRecordCode()}
| ${resultExpression.code}
| $collectorTerm.collect(${resultExpression.resultTerm});
| }
| }
|}
|""".stripMargin
}
// Aggregate helper functions are emitted only after all aggregate calls have been registered.
aggregatesPerVariable.values.foreach(_.generateAggFunction())
GeneratedExpression("", "false", resultCode, null)
}
/**
 * Generates code for `call` by visiting it with this generator, then emits the aggregate helper
 * functions for the aggregate calls collected during the visit.
 *
 * @param call expression to generate
 * @return the generated expression
 */
def generateCondition(call: RexNode): GeneratedExpression = {
  val conditionExpr: GeneratedExpression = call.accept(this)
  for (aggBuilder <- aggregatesPerVariable.values) {
    aggBuilder.generateAggFunction()
  }
  conditionExpr
}
/**
 * Handles the MATCH_RECOGNIZE specific operators and delegates everything else to the parent
 * generator.
 *
 * - PREV/NEXT: only a physical offset of 0 is supported; the operand is read from the last row.
 * - FIRST/LAST: the operand is read `offset` rows from the beginning or the end.
 * - CLASSIFIER: the pattern name of the referenced row.
 * - RUNNING/FINAL: RUNNING switches on running semantics except for ONE ROW PER MATCH, where it
 *   is the same as FINAL.
 * - aggregate functions: computed per referenced pattern variable, deduplicated.
 * - PROCTIME: the current processing time.
 */
override def visitCall(call: RexCall): GeneratedExpression = {
  def operand(index: Int): RexNode = call.getOperands.get(index)

  // The literal offset passed as second operand of the navigation functions.
  def literalOffset: Int =
    checkedDownCast(operand(1).asInstanceOf[RexLiteral].getValueAs(classOf[JLong]))

  // Generates `node` with the given indexing context, then restores the default context.
  def withOffsets(fromStart: Boolean, position: Int)(node: RexNode): GeneratedExpression = {
    updateOffsets(fromStart, position)
    val generated = node.accept(this)
    resetOffsets()
    generated
  }

  val operator = call.getOperator
  operator match {
    case PREV | NEXT =>
      if (literalOffset != 0) {
        throw new TableException("Flink does not support physical offsets within partition.")
      }
      withOffsets(fromStart = false, position = 0)(operand(0))

    case FIRST | LAST =>
      withOffsets(operator == FIRST, literalOffset)(operand(0))

    case CLASSIFIER =>
      findClassifierByLogicalPosition()

    case RUNNING if oneRowPerMatch =>
      // with a single output row, running semantics equal final semantics
      operand(0).accept(this)

    case RUNNING =>
      setRunning(true)
      val generated = operand(0).accept(this)
      resetRunning()
      generated

    case FINAL =>
      operand(0).accept(this)

    case _: SqlAggFunction =>
      val variable = call.accept(new AggregationPatternVariableFinder)
        .getOrElse(throw new TableException("No pattern variable specified in aggregate"))
      aggregatesPerVariable
        .getOrElseUpdate((running, variable), new AggBuilder(variable))
        .generateDeduplicatedAggAccess(call)

    case ProctimeSqlFunction =>
      MatchCodeGenerator.generateProctimeTimestamp()

    case _ =>
      super.visitCall(call)
  }
}
/**
* Extracts partition keys from any element of the match
*
* @param partitionKey partition key to be extracted
* @return generated code for the given key
*/
private def generatePartitionKeyAccess(partitionKey: RexInputRef): GeneratedExpression = {
  // partition keys are read from a representative row of the match
  val keyRowExpr = generateKeyRow()
  val keyIndex = partitionKey.getIndex
  generateFieldAccess(ctx, keyRowExpr.resultType, keyRowExpr.resultTerm, keyIndex, nullCheck)
}
/**
 * Returns access to the key row: the first event of the first non-empty event list found while
 * iterating over the input map. The lookup code is registered once as reusable unboxing code,
 * so the returned expression carries no code; its null term is true when all lists are empty.
 */
private def generateKeyRow(): GeneratedExpression = {
  val keyRowExpr = ctx.getReusableInputUnboxingExprs(keyRowTerm, 0).getOrElse {
    val isNullTerm = newName("isNull")
    ctx.addReusableMember(s"$eventTypeTerm $keyRowTerm;")
    val lookupCode =
      j"""
         |boolean $isNullTerm = true;
         |for (java.util.Map.Entry entry : $input1Term.entrySet()) {
         | java.util.List value = (java.util.List) entry.getValue();
         | if (value != null && value.size() > 0) {
         | $keyRowTerm = ($eventTypeTerm) value.get(0);
         | $isNullTerm = false;
         | break;
         | }
         |}
         |""".stripMargin
    val created = GeneratedExpression(keyRowTerm, isNullTerm, lookupCode, input1Type)
    ctx.addReusableInputUnboxingExprs(keyRowTerm, 0, created)
    created
  }
  keyRowExpr.copy(code = NO_CODE)
}
/**
 * Generates access to a field of a pattern variable reference such as `A.price`.
 *
 * Inside an aggregate's input transformation, the field is read from that function's row
 * parameter. A reference to "*" at the default position (last row) inside a condition reads the
 * input row directly. Every other reference is resolved through the pattern event lists.
 */
override def visitPatternFieldRef(fieldRef: RexPatternFieldRef): GeneratedExpression = {
  val fieldIndex = fieldRef.getIndex
  if (isWithinAggExprState) {
    generateFieldAccess(ctx, input1Type, inputAggRowTerm, fieldIndex, nullCheck)
  } else if (fieldRef.getAlpha.equals(ALL_PATTERN_VARIABLE) &&
      currentPattern.isDefined && offset == 0 && !first) {
    generateInputAccess(ctx, input1Type, input1Term, fieldIndex, nullableInput, nullCheck)
  } else {
    generatePatternFieldRef(fieldRef)
  }
}
/**
 * Generates the event list of `patternName` for a condition evaluated for `currentPattern`.
 *
 * The list is built from the condition context's events for the variable ("*" collects the
 * events of all variables). When the list belongs to the current pattern or to "*", the input
 * row under evaluation is appended as well.
 */
private def generateDefinePatternVariableExp(
    patternName: String,
    currentPattern: String)
  : GeneratedPatternList = {
  val Seq(listName, eventNameTerm) = newNames(Seq("patternEvents", "event"))
  ctx.addReusableMember(s"java.util.List $listName;")

  val appendCurrentRow =
    if (currentPattern != patternName && patternName != ALL_PATTERN_VARIABLE) {
      ""
    } else {
      j"""
         |$listName.add($input1Term);
         |""".stripMargin
    }

  val collectEvents =
    if (patternName != ALL_PATTERN_VARIABLE) {
      val escapedPatternName = EncodingUtils.escapeJava(patternName)
      j"""
         |$listName = new java.util.ArrayList();
         |for ($eventTypeTerm $eventNameTerm :
         | $contextTerm.getEventsForPattern("$escapedPatternName")) {
         | $listName.add($eventNameTerm);
         |}
         |""".stripMargin
    } else {
      addReusablePatternNames()
      val patternTerm = newName("pattern")
      j"""
         |$listName = new java.util.ArrayList();
         |for (String $patternTerm : $patternNamesTerm) {
         | for ($eventTypeTerm $eventNameTerm :
         | $contextTerm.getEventsForPattern($patternTerm)) {
         | $listName.add($eventNameTerm);
         | }
         |}
         |""".stripMargin
    }

  GeneratedPatternList(
    listName,
    j"""
       |$collectEvents
       |$appendCurrentRow
       |""".stripMargin)
}
/**
 * Generates the event list of `patternName` for MEASURES.
 *
 * With RUNNING semantics the list starts empty and is marked unfilled, so rows are appended
 * while iterating over the match. Otherwise it is read from the input map ("*" concatenates the
 * lists of all variables in `patternNames` order; a missing list becomes an empty one).
 */
private def generateMeasurePatternVariableExp(patternName: String): GeneratedPatternList = {
  val Seq(listName, patternTerm) = newNames(Seq("patternEvents", "pattern"))
  ctx.addReusableMember(s"java.util.List $listName;")

  if (running) {
    val createEmpty =
      j"""
         |$listName = new java.util.ArrayList();
         |""".stripMargin
    GeneratedPatternList(listName, createEmpty, filled = false)
  } else if (patternName == ALL_PATTERN_VARIABLE) {
    addReusablePatternNames()
    val collectAll =
      j"""
         |$listName = new java.util.ArrayList();
         |for (String $patternTerm : $patternNamesTerm) {
         | java.util.List rows = (java.util.List) $input1Term.get($patternTerm);
         | if (rows != null) {
         | $listName.addAll(rows);
         | }
         |}
         |""".stripMargin
    GeneratedPatternList(listName, collectAll, filled = true)
  } else {
    val escapedPatternName = EncodingUtils.escapeJava(patternName)
    val lookup =
      j"""
         |$listName = (java.util.List) $input1Term.get("$escapedPatternName");
         |if ($listName == null) {
         | $listName = java.util.Collections.emptyList();
         |}
         |""".stripMargin
    GeneratedPatternList(listName, lookup, filled = true)
  }
}
/**
 * Generates code selecting the event of `patternFieldAlpha` at the current logical position
 * (`offset` rows from the beginning when `first` is set, from the end otherwise). The selected
 * row member stays null when the list holds no more than `offset` events.
 */
private def findEventByLogicalPosition(
    patternFieldAlpha: String)
  : GeneratedExpression = {
  val rowNameTerm = newName("row")
  val eventsTerm = findEventsByPatternName(patternFieldAlpha).resultTerm
  val positionCode =
    if (!first) s"$eventsTerm.size() - $offset - 1"
    else offset.toString
  ctx.addReusableMember(s"$eventTypeTerm $rowNameTerm;")
  val lookupCode =
    j"""
       |$rowNameTerm = null;
       |if ($eventsTerm.size() > $offset) {
       | $rowNameTerm = (($eventTypeTerm) $eventsTerm.get($positionCode));
       |}
       |""".stripMargin
  GeneratedExpression(rowNameTerm, "", lookupCode, input1Type)
}
/**
 * Returns the (cached) event list for `patternFieldAlpha` under the current RUNNING flag,
 * generating a condition list when a current pattern is set and a measure list otherwise.
 */
private def findEventsByPatternName(
    patternFieldAlpha: String): GeneratedPatternList =
  reusablePatternLists.getOrElseUpdate(
    (running, patternFieldAlpha),
    currentPattern
      .map(pattern => generateDefinePatternVariableExp(patternFieldAlpha, pattern))
      .getOrElse(generateMeasurePatternVariableExp(patternFieldAlpha)))
/**
 * Generates the classifier list for a condition evaluated for `currentPattern`.
 *
 * For every pattern variable up to and including the current one (in `patternNames` order), its
 * name is added once per event the condition context reports for it. The current pattern is then
 * added once more for the row under evaluation.
 *
 * @param currentPattern name of the pattern variable the condition belongs to
 * @return the generated, fully filled classifier list
 */
private def generateDefineClassifierVariableExp(currentPattern: String)
  : GeneratedClassifierList = {
  val Seq(listName, eventNameTerm) = newNames(Seq("classifiers", "event"))
  ctx.addReusableMember(s"java.util.List $listName;")

  // Pattern names end up inside Java string literals and must therefore be escaped; a name
  // containing `"` or `\` would otherwise produce uncompilable code.
  val escapedCurrentPattern = EncodingUtils.escapeJava(currentPattern)
  val addCurrent =
    j"""
       |$listName.add("$escapedCurrentPattern");
       |""".stripMargin

  val patternNamesToVisit = patternNames.take(patternNames.indexOf(currentPattern) + 1)
  val listCode = patternNamesToVisit.map { patternName =>
    val escapedPatternName = EncodingUtils.escapeJava(patternName)
    j"""
       |for ($eventTypeTerm $eventNameTerm :
       | $contextTerm.getEventsForPattern("$escapedPatternName")) {
       | $listName.add("$escapedPatternName");
       |}
       |""".stripMargin
  }.mkString("\n")

  val code =
    j"""
       |$listName = new java.util.ArrayList();
       |$listCode
       |$addCurrent
       |""".stripMargin
  GeneratedClassifierList(listName, code)
}
/**
 * Generates the classifier list for MEASURES: one pattern name per matched event.
 *
 * With ALL ROWS PER MATCH and RUNNING semantics the list starts empty and is marked unfilled, so
 * names are appended while iterating over the match. Otherwise it is populated from the input
 * map in `patternNames` order.
 *
 * @return the generated classifier list
 */
private def generateMeasureClassifierVariableExp(): GeneratedClassifierList = {
  val Seq(listName, patternTerm, eventNameTerm) = newNames(Seq("classifiers", "pattern", "event"))
  ctx.addReusableMember(s"java.util.List $listName;")
  val (code, filled) = if (!oneRowPerMatch && running) {
    val code =
      j"""
         |$listName = new java.util.ArrayList();
         |""".stripMargin
    (code, false)
  } else {
    addReusablePatternNames()
    // `patternTerm` is the Java loop variable over the pattern names and is used as an
    // expression, not embedded into a string literal, so it must not be passed through
    // `EncodingUtils.escapeJava`.
    val code =
      j"""
         |$listName = new java.util.ArrayList();
         |for (String $patternTerm : $patternNamesTerm) {
         |  java.util.List rows = (java.util.List) $input1Term.get($patternTerm);
         |  if (rows != null) {
         |    for ($eventTypeTerm $eventNameTerm : rows) {
         |      $listName.add($patternTerm);
         |    }
         |  }
         |}
         |""".stripMargin
    (code, true)
  }
  GeneratedClassifierList(listName, code, filled)
}
/**
 * Generates CLASSIFIER(): the pattern name at the current logical position of the classifier
 * list, converted to the internal string representation. The result is null when the list holds
 * no more than `offset` entries.
 */
private def findClassifierByLogicalPosition(): GeneratedExpression = {
  val Seq(nullTerm, classifierTerm) = newNames(Seq("isNull", "classifier"))
  val classifiersTerm = findClassifiers().resultTerm
  val positionCode =
    if (first) offset.toString
    else s"$classifiersTerm.size() - $offset - 1"
  val stringType = StringType.INSTANCE
  val stringTypeTerm = primitiveTypeTermForType(stringType)
  val internalClassifier = genToInternal(ctx, stringType, s"$classifiersTerm.get($positionCode)")
  ctx.addReusableMember(s"$stringTypeTerm $classifierTerm;")
  val lookupCode =
    j"""
       |boolean $nullTerm = true;
       |$classifierTerm = null;
       |if ($classifiersTerm.size() > $offset) {
       | $classifierTerm = ($stringTypeTerm) $internalClassifier;
       | $nullTerm = false;
       |}
       |""".stripMargin
  GeneratedExpression(classifierTerm, nullTerm, lookupCode, stringType)
}
/**
 * Returns the (cached) classifier list for the current RUNNING flag, generating a condition list
 * when a current pattern is set and a measure list otherwise.
 */
private def findClassifiers(): GeneratedClassifierList =
  reusableClassiferList.getOrElseUpdate(
    running,
    currentPattern.fold(generateMeasureClassifierVariableExp())(
      generateDefineClassifierVariableExp))
/**
 * Generates null-safe access to a field of the row referenced by `fieldRef` at the current
 * logical position. The row lookup is cached per (variable, `first`, `offset`).
 */
private def generatePatternFieldRef(fieldRef: RexPatternFieldRef): GeneratedExpression = {
  val alpha = fieldRef.getAlpha
  val cacheKey = s"${EncodingUtils.escapeJava(alpha)}#$first"
  val eventRef = getReusableInputUnboxingExprs(cacheKey, offset).getOrElse {
    val located = findEventByLogicalPosition(alpha)
    addReusableInputUnboxingExprs(cacheKey, offset, located)
    located
  }
  generateNullableInputFieldAccess(
    ctx,
    eventRef.resultType,
    eventRef.resultTerm,
    fieldRef.getIndex,
    nullCheck)
}
/**
 * Registers a per-record statement: in this generator's own set under RUNNING semantics,
 * otherwise in the code generator context.
 */
private def addReusablePerRecordStatement(s: String): Unit =
  if (!running) ctx.addPerRecordStatement(s) else reusablePerRecordStatements.add(s)

/** Looks up a cached unboxing expression in the store matching the RUNNING flag. */
private def getReusableInputUnboxingExprs(inputTerm: String, index: Int)
  : Option[GeneratedExpression] =
  if (!running) ctx.getReusableInputUnboxingExprs(inputTerm, index)
  else reusableInputUnboxingExprs.get((inputTerm, index))

/** Caches an unboxing expression in the store matching the RUNNING flag. */
private def addReusableInputUnboxingExprs(
    inputTerm: String, index: Int, expr: GeneratedExpression): Unit =
  if (!running) ctx.addReusableInputUnboxingExprs(inputTerm, index, expr)
  else reusableInputUnboxingExprs.update((inputTerm, index), expr)
/**
 * Collects the aggregate calls that refer to one pattern variable and generates one helper that
 * evaluates all of them over the events mapped to that variable.
 *
 * @param variable name of the pattern variable whose events are aggregated
 */
class AggBuilder(variable: String) {

  // Registered aggregate calls. The position of a call in this buffer is the index of its result
  // in the row returned by the generated calculation function.
  private val aggregates = new mutable.ListBuffer[RexCall]()

  private val variableUID = newName("variable")

  private val calculateAggFuncName = s"calculateAgg_$variableUID"

  /**
   * Returns an expression reading the result of `aggCall`. Calls are deduplicated by their
   * [[RexNode#toString]] together with the current RUNNING flag, so an aggregate used several
   * times is computed only once.
   *
   * @param aggCall aggregate call over this builder's pattern variable
   * @return an expression without code; the code computing it is registered per record
   */
  def generateDeduplicatedAggAccess(aggCall: RexCall): GeneratedExpression = {
    val deduplicationKey = s"${aggCall.toString}#$running"
    reusableAggregationExpr.get(deduplicationKey) match {
      case Some(expr) =>
        expr
      case None =>
        val exp: GeneratedExpression = generateAggAccess(aggCall)
        aggregates += aggCall
        reusableAggregationExpr(deduplicationKey) = exp
        // generateAggAccess registers its statements itself and returns no code.
        exp.copy(code = NO_CODE)
    }
  }

  /**
   * Generates access to the result of `aggCall` in the row computed by the calculation
   * function. The current size of `aggregates` is used as the result index, so the caller must
   * append `aggCall` right after this method returns.
   */
  private def generateAggAccess(aggCall: RexCall): GeneratedExpression = {
    val singleAggResultTerm = newName("result")
    val singleAggNullTerm = newName("nullTerm")
    val singleAggResultType = FlinkTypeFactory.toTypeInfo(aggCall.`type`).toInternalType
    val primitiveSingleAggResultTypeTerm = primitiveTypeTermForType(singleAggResultType)
    val boxedSingleAggResultTypeTerm = boxedTypeTermForType(singleAggResultType)
    val allAggRowTerm = s"aggRow_$variableUID"
    val resultIndex = aggregates.size

    // All aggregates of the variable are computed by the same statement. The statement text is
    // identical for every call of this builder; this generator's own per-record store is a
    // LinkedHashSet, so it keeps it only once (the context's store is presumably deduplicating
    // as well — verify).
    val rowsForVariableCode = findEventsByPatternName(variable)
    val codeForAgg =
      j"""
         |$GENERIC_ROW $allAggRowTerm = $calculateAggFuncName(${rowsForVariableCode.resultTerm});
         |""".stripMargin
    addReusablePerRecordStatement(codeForAgg)

    val defaultValue = primitiveDefaultValue(singleAggResultType)
    val codeForSingleAgg = if (nullCheck) {
      j"""
         |boolean $singleAggNullTerm;
         |$primitiveSingleAggResultTypeTerm $singleAggResultTerm;
         |if ($allAggRowTerm.getField($resultIndex) != null) {
         |  $singleAggResultTerm = ($boxedSingleAggResultTypeTerm) $allAggRowTerm
         |    .getField($resultIndex);
         |  $singleAggNullTerm = false;
         |} else {
         |  $singleAggNullTerm = true;
         |  $singleAggResultTerm = $defaultValue;
         |}
         |""".stripMargin
    } else {
      // The returned expression refers to the null term, so it is declared here as well.
      j"""
         |boolean $singleAggNullTerm = false;
         |$primitiveSingleAggResultTypeTerm $singleAggResultTerm =
         |  ($boxedSingleAggResultTypeTerm) $allAggRowTerm.getField($resultIndex);
         |""".stripMargin
    }
    addReusablePerRecordStatement(codeForSingleAgg)
    GeneratedExpression(singleAggResultTerm, singleAggNullTerm, NO_CODE, singleAggResultType)
  }

  /**
   * Generates an [[AggsHandleFunction]] for all registered calls, together with the member
   * that instantiates it, the input transformation function and the calculation function.
   */
  def generateAggFunction(): Unit = {
    val matchAgg = extractAggregatesAndExpressions
    val aggCalls = matchAgg.aggregations.map { singleAgg =>
      AggregateCall.create(
        singleAgg.sqlAggFunction,
        false, // distinct
        false, // approximate
        singleAgg.exprIndices,
        -1, // no filter argument
        singleAgg.resultType,
        singleAgg.sqlAggFunction.getName)
    }
    // The accumulators are recreated for every calculation, so no retraction is needed.
    val needRetraction = Array.fill(matchAgg.aggregations.size)(false)
    val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
    val inputRelType = typeFactory.createStructType(
      matchAgg.inputExprs.map(_.getType),
      matchAgg.inputExprs.indices.map(i => s"TMP$i"))
    val aggInfoList = AggregateUtil.transformToStreamAggregateInfoList(
      aggCalls,
      inputRelType,
      needRetraction,
      needInputCount = false,
      isStateBackendDataViews = false,
      needDistinctInfo = false)
    val inputFieldTypes = matchAgg.inputExprs
      .map(expr => FlinkTypeFactory.toInternalType(expr.getType))
    val aggsHandlerCodeGenerator = new AggsHandlerCodeGenerator(
      CodeGeneratorContext(new TableConfig, supportReference = true),
      relBuilder,
      inputFieldTypes,
      needRetract = false,
      needMerge = false,
      nullCheck = true,
      copyInputField = false)
    val generatedAggsHandler = aggsHandlerCodeGenerator.generateAggsHandler(
      s"AggFunction_$variableUID",
      aggInfoList)
    val generatedTerm = ctx.addReusableObject(generatedAggsHandler, "generatedAggHandler")
    val functionTerm = s"aggregator_$variableUID"
    val declareCode = s"private $AGGS_HANDLE_FUNCTION $functionTerm;"
    val initCode = s"$functionTerm = ($AGGS_HANDLE_FUNCTION) " +
      s"$generatedTerm.newInstance($CURRENT_CLASS_LOADER);"
    ctx.addReusableMember(declareCode, initCode)
    val transformFuncName = s"transformRowForAgg_$variableUID"
    val inputTransform: String = generateAggInputExprEvaluation(
      matchAgg.inputExprs,
      transformFuncName)
    generateAggCalculation(functionTerm, transformFuncName, inputTransform)
  }

  /**
   * Splits the registered calls into aggregate functions and their distinct operand
   * expressions. Operands are deduplicated by [[RexNode#toString]]; each gets the index of its
   * column in the aggregate input row.
   */
  private def extractAggregatesAndExpressions: MatchAgg = {
    val inputRows = new mutable.LinkedHashMap[String, (RexNode, Int)]
    val singleAggregates = aggregates.map { aggCall =>
      val operandIndices = aggCall.operands.asScala.map { operand =>
        // the default is evaluated before insertion, so the new index is the current size
        val (_, index) = inputRows.getOrElseUpdate(operand.toString, (operand, inputRows.size))
        Integer.valueOf(index)
      }
      SingleAggCall(
        aggCall.getOperator.asInstanceOf[SqlAggFunction],
        aggCall.`type`,
        operandIndices)
    }
    MatchAgg(singleAggregates, inputRows.values.map(_._1).toSeq)
  }

  /**
   * Registers the input transformation function and a calculation function that resets the
   * accumulators, accumulates every transformed row of the given list and returns the results.
   * The handler is opened and closed with the generated function.
   */
  private def generateAggCalculation(
      functionTerm: String,
      transformFuncName: String,
      inputTransformFunc: String): Unit = {
    val code =
      j"""
         |$inputTransformFunc
         |
         |private $GENERIC_ROW $calculateAggFuncName(java.util.List input)
         |    throws Exception {
         |  $functionTerm.setAccumulators($functionTerm.createAccumulators());
         |  for ($BASE_ROW row : input) {
         |    $functionTerm.accumulate($transformFuncName(row));
         |  }
         |  $GENERIC_ROW result = ($GENERIC_ROW) $functionTerm.getValue();
         |  return result;
         |}
         |""".stripMargin
    ctx.addReusableMember(code)
    ctx.addReusableOpenStatement(
      s"$functionTerm.open(new $EXECUTION_CONTEXT_IMPL(null, getRuntimeContext()));")
    ctx.addReusableCloseStatement(s"$functionTerm.close();")
  }

  /**
   * Generates a function that evaluates `inputExprs` on one event and stores the results, in
   * order, in a new row that serves as the aggregate input.
   */
  private def generateAggInputExprEvaluation(
      inputExprs: Seq[RexNode],
      funcName: String): String = {
    // Pattern field references must read the function's row parameter while the input
    // expressions are generated. The flag is restored even if generation fails so that later
    // expressions are not generated in the wrong access mode.
    isWithinAggExprState = true
    val resultTerm = newName("result")
    val exprs = try {
      inputExprs.zipWithIndex.map {
        case (inputExpr, outputIndex) =>
          val expr = generateExpression(inputExpr)
          s"""
             |  ${expr.code}
             |  if (${expr.nullTerm}) {
             |    $resultTerm.update($outputIndex, null);
             |  } else {
             |    $resultTerm.update($outputIndex, ${expr.resultTerm});
             |  }
             |""".stripMargin
      }.mkString("\n")
    } finally {
      isWithinAggExprState = false
    }
    j"""
       |private $GENERIC_ROW $funcName($BASE_ROW $inputAggRowTerm) {
       |  $GENERIC_ROW $resultTerm = new $GENERIC_ROW(${inputExprs.size});
       |  $exprs
       |  return $resultTerm;
       |}
       |""".stripMargin
  }

  /** One aggregate function with the input-row indices of its operands. */
  private case class SingleAggCall(
    sqlAggFunction: SqlAggFunction,
    resultType: RelDataType,
    exprIndices: Seq[Integer]
  )

  /** All aggregates of the variable plus the distinct operand expressions they read. */
  private case class MatchAgg(
    aggregations: Seq[SingleAggCall],
    inputExprs: Seq[RexNode]
  )
}
}
/** Constants and helpers shared by the CEP code generation. */
object MatchCodeGenerator {

  /** Pattern variable name whose event list contains the events of all pattern variables. */
  val ALL_PATTERN_VARIABLE = "*"

  // Class names (as produced by `className`) referenced from the generated code.
  val EXECUTION_CONTEXT_IMPL: String = className[ExecutionContextImpl]
  val AGGS_HANDLE_FUNCTION: String = className[AggsHandleFunction]
  val GENERIC_ROW: String = className[GenericRow]
  val BASE_ROW: String = className[BaseRow]

  /** Java expression yielding the class loader used to instantiate generated aggregators. */
  val CURRENT_CLASS_LOADER = "Thread.currentThread().getContextClassLoader()"

  /**
   * Generates a never-null expression holding `System.currentTimeMillis()` as the processing
   * time.
   */
  def generateProctimeTimestamp(): GeneratedExpression = {
    val resultTerm = newName("result")
    val resultCode = s"long $resultTerm = System.currentTimeMillis();"
    GeneratedExpression(resultTerm, NEVER_NULL, resultCode, DataTypes.TIMESTAMP)
  }
}