/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.codegen.agg
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.tools.RelBuilder
import org.apache.flink.api.common.typeutils.{TypeComparator, TypeSerializer}
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
import org.apache.flink.metrics.Gauge
import org.apache.flink.table.api.TableConfig
import org.apache.flink.table.api.functions.{AggregateFunction, DeclarativeAggregateFunction, UserDefinedFunction}
import org.apache.flink.table.api.types.{DataTypes, InternalType, RowType}
import org.apache.flink.table.codegen.CodeGenUtils.{binaryRowFieldSetAccess, binaryRowSetNull}
import org.apache.flink.table.codegen._
import org.apache.flink.table.codegen.operator.OperatorCodeGenerator
import org.apache.flink.table.dataformat.{BaseRow, BinaryRow, GenericRow, JoinedRow}
import org.apache.flink.table.expressions._
import org.apache.flink.table.plan.util.SortUtil
import org.apache.flink.table.runtime.sort.{BufferedKVExternalSorter, NormalizedKeyComputer, RecordComparator}
import org.apache.flink.table.runtime.util.{BytesHashMap, BytesHashMapSpillMemorySegmentPool}
import org.apache.flink.table.typeutils.{BinaryRowSerializer, TypeUtils}
import org.apache.flink.table.util.NodeResourceUtil
trait BatchExecHashAggregateCodeGen extends BatchExecAggregateCodeGen {
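/**
* Declares reusable member arrays holding the internal field types of the
* aggregate map's key row and aggregate buffer row.
*/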
private[flink] def prepareHashAggKVTypes(
ctx: CodeGeneratorContext,
aggMapKeyTypesTerm: String,
aggBufferTypesTerm: String,
aggMapKeyType: RowType,
aggBufferType: RowType): Unit = {
val tpTerm = classOf[InternalType].getName
ctx.addReusableMember(
s"private transient $tpTerm[] $aggMapKeyTypesTerm;",
s"$aggMapKeyTypesTerm = ${ctx.addReferenceObj(
aggMapKeyType.getFieldInternalTypes, s"$tpTerm[]")};")
ctx.addReusableMember(
s"private transient $tpTerm[] $aggBufferTypesTerm;",
s"$aggBufferTypesTerm = ${ctx.addReferenceObj(
aggBufferType.getFieldInternalTypes, s"$tpTerm[]")};")
}
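/**
* Declares the BytesHashMap member that serves as the aggregate map,
* allocates it in the open statement with the reserved and max managed
* memory, and frees its memory segments in the close statement.
*/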
private[flink] def prepareHashAggMap(
ctx: CodeGeneratorContext,
config: TableConfig,
reservedManagedMemory: Long,
maxManagedMemory: Long,
groupKeyTypesTerm: String,
aggBufferTypesTerm: String,
aggregateMapTerm: String): Unit = {
// allocate memory segments for aggregate map
// create aggregate map
val mapTypeTerm = classOf[BytesHashMap].getName
val perRequestSize = NodeResourceUtil.getPerRequestManagedMemory(config.getConf) *
NodeResourceUtil.SIZE_IN_MB
ctx.addReusableMember(s"private transient $mapTypeTerm $aggregateMapTerm;")
ctx.addReusableOpenStatement(s"$aggregateMapTerm " +
s"= new $mapTypeTerm(" +
s"this.getContainingTask()," +
s"this.getContainingTask().getEnvironment().getMemoryManager()," +
s"${reservedManagedMemory}L," +
s"${maxManagedMemory}L," +
s"${perRequestSize}L," +
s" $groupKeyTypesTerm," +
s" $aggBufferTypesTerm);")
// close aggregate map and release memory segments
ctx.addReusableCloseStatement(s"$aggregateMapTerm.free();")
ctx.addReusableCloseStatement(s"")
}
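/**
* The concrete BaseRow subclass of the output record.
*/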
def getOutputRowClass: Class[_ <: BaseRow]
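/**
* Declares the reusable BinaryRow terms for the map key and the aggregate
* buffer, and the map entry wrapping both, which are used when iterating
* over the aggregate map. Returns (entryTerm, keyTerm, bufferTerm).
*/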
private[flink] def prepareTermForAggMapIteration(
ctx: CodeGeneratorContext,
outputTerm: String,
outputType: RowType,
aggMapKeyType: RowType,
aggBufferType: RowType): (String, String, String) = {
// prepare iteration var terms
val reuseAggMapKeyTerm = CodeGenUtils.newName("reuseAggMapKey")
val reuseAggBufferTerm = CodeGenUtils.newName("reuseAggBuffer")
val reuseAggMapEntryTerm = CodeGenUtils.newName("reuseAggMapEntry")
// gen code to prepare agg output using agg buffer and key from the aggregate map
val binaryRow = classOf[BinaryRow].getName
val mapEntryTypeTerm = classOf[BytesHashMap.Entry].getCanonicalName
ctx.addOutputRecord(outputType, getOutputRowClass, outputTerm)
ctx.addReusableMember(
s"private transient $binaryRow $reuseAggMapKeyTerm = " +
s"new $binaryRow(${aggMapKeyType.getArity});")
ctx.addReusableMember(
s"private transient $binaryRow $reuseAggBufferTerm = " +
s"new $binaryRow(${aggBufferType.getArity});")
ctx.addReusableMember(
s"private transient $mapEntryTypeTerm $reuseAggMapEntryTerm = " +
s"new $mapEntryTypeTerm($reuseAggMapKeyTerm, $reuseAggBufferTerm);"
)
(reuseAggMapEntryTerm, reuseAggMapKeyTerm, reuseAggBufferTerm)
}
/**
* Generates code that iterates over the aggregation map and
* retrieves the aggregate values.
*/
private[flink] def genAggMapIterationAndOutput(
ctx: CodeGeneratorContext,
config: TableConfig,
isFinal: Boolean,
aggregateMapTerm: String,
reuseAggMapEntryTerm: String,
reuseAggBufferTerm: String,
outputExpr: GeneratedExpression): String = {
// gen code to iterate over the aggregate map and emit results downstream
val inputUnboxingCode =
if (isFinal) ctx.reuseInputUnboxingCode(Set(reuseAggBufferTerm)) else ""
val iteratorTerm = CodeGenUtils.newName("iterator")
val mapEntryTypeTerm = classOf[BytesHashMap.Entry].getCanonicalName
s"""
|org.apache.flink.util.MutableObjectIterator<$mapEntryTypeTerm> $iteratorTerm =
| $aggregateMapTerm.getEntryIterator();
|while ($iteratorTerm.next($reuseAggMapEntryTerm) != null) {
| // set result and output
| $inputUnboxingCode
| ${outputExpr.code}
| ${OperatorCodeGenerator.generatorCollect(outputExpr.resultTerm)}
|}
""".stripMargin
}
// ===============================================================================================
/**
* In the case of hash aggregation, the aggregate buffer is stored as the
* BytesHashMap's value in the form of a BinaryRow.
* An index is used to locate an aggregate buffer field.
*/
private[flink] def bindReference(
isMerge: Boolean,
offset: Int,
agg: DeclarativeAggregateFunction,
aggIndex: Int,
argsMapping: Array[Array[(Int, InternalType)]],
aggBuffMapping: Array[Array[(Int, InternalType)]])
: PartialFunction[Expression, Expression] = {
case input: UnresolvedFieldReference =>
// We always use UnresolvedFieldReference to represent a reference to an input field.
// In the non-merge case, the input is an operand of the aggregate function. In the
// merge case, the input is the aggregate buffer sent by the local aggregate.
val localIndex = if (isMerge) agg.inputAggBufferAttributes.indexOf(input)
else agg.operands.indexOf(input)
val (inputIndex, inputType) = argsMapping(aggIndex)(localIndex)
ResolvedAggInputReference(input.name, inputIndex, inputType)
case aggBuffAttr: UnresolvedAggBufferReference =>
val localIndex = agg.aggBufferAttributes.indexOf(aggBuffAttr)
val (aggBuffAttrIndex, aggBuffAttrType) = aggBuffMapping(aggIndex)(localIndex)
ResolvedAggInputReference(
aggBuffAttr.name, offset + aggBuffAttrIndex, aggBuffAttrType)
}
/**
* Generates code that reads the input, accumulates the aggregate
* buffers and updates the aggregation map.
*/
private[flink] def genAccumulateAggBuffer(
ctx: CodeGeneratorContext,
config: TableConfig,
builder: RelBuilder,
inputRelDataType: RelDataType,
inputTerm: String,
inputType: RowType,
currentAggBufferTerm: String,
auxGrouping: Array[Int],
aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
argsMapping: Array[Array[(Int, InternalType)]],
aggBuffMapping: Array[Array[(Int, InternalType)]],
aggBufferType: RowType): GeneratedExpression = {
val exprCodegen = new ExprCodeGenerator(ctx, false, config.getNullCheck)
.bindInput(inputType, inputTerm = inputTerm)
.bindSecondInput(aggBufferType, inputTerm = currentAggBufferTerm)
val accumulateExprsWithFilterArgs = aggCallToAggFunction.zipWithIndex.flatMap {
case (aggCallToAggFun, aggIndex) =>
val idx = auxGrouping.length + aggIndex
val bindRefOffset = inputRelDataType.getFieldCount
val aggCall = aggCallToAggFun._1
aggCallToAggFun._2 match {
case agg: DeclarativeAggregateFunction =>
agg.accumulateExpressions.map(
_.postOrderTransform(bindReference(
isMerge = false, bindRefOffset, agg, idx, argsMapping, aggBuffMapping))
).map(e => (e, aggCall))
}
}.map {
case (expr: Expression, aggCall: AggregateCall) =>
(exprCodegen.generateExpression(expr.toRexNode(builder)), aggCall.filterArg)
}
// update agg buff in-place
val code = accumulateExprsWithFilterArgs.zipWithIndex.map({
case ((accumulateExpr, filterArg), index) =>
val idx = auxGrouping.length + index
val t = aggBufferType.getInternalTypeAt(idx)
val writeCode = binaryRowFieldSetAccess(
idx, currentAggBufferTerm, t.toInternalType, accumulateExpr.resultTerm)
val innerCode = if (config.getNullCheck) {
s"""
|${accumulateExpr.code}
|if (${accumulateExpr.nullTerm}) {
| ${binaryRowSetNull(idx, currentAggBufferTerm, t.toInternalType)};
|} else {
| $writeCode;
|}
|""".stripMargin.trim
}
else {
s"""
|${accumulateExpr.code}
|$writeCode;
|""".stripMargin.trim
}
if (filterArg >= 0) {
s"""
|if ($inputTerm.getBoolean($filterArg)) {
| $innerCode
|}
""".stripMargin
} else {
innerCode
}
}).mkString("\n")
GeneratedExpression(currentAggBufferTerm, "false", code, aggBufferType.toInternalType)
}
/**
* Generates code that initializes the empty agg buffer.
*/
private[flink] def genReusableEmptyAggBuffer(
ctx: CodeGeneratorContext,
config: TableConfig,
builder: RelBuilder,
inputTerm: String,
inputType: RowType,
auxGrouping: Array[Int],
aggregates: Seq[UserDefinedFunction],
aggBufferType: RowType): GeneratedExpression = {
val exprCodegen = new ExprCodeGenerator(ctx, false, config.getNullCheck)
.bindInput(inputType, inputTerm = inputTerm, inputFieldMapping = Some(auxGrouping))
val initAuxGroupingExprs = auxGrouping.map { idx =>
CodeGenUtils.generateFieldAccess(
ctx, inputType.toInternalType, inputTerm, idx, nullCheck = true)
}
val initAggCallBufferExprs = aggregates.flatMap(a =>
a.asInstanceOf[DeclarativeAggregateFunction].initialValuesExpressions)
.map(_.toRexNode(builder))
.map(exprCodegen.generateExpression)
val initAggBufferExprs = initAuxGroupingExprs ++ initAggCallBufferExprs
// empty agg buffer and writer will be reused
val emptyAggBufferTerm = CodeGenUtils.newName("emptyAggBuffer")
val emptyAggBufferWriterTerm = CodeGenUtils.newName("emptyAggBufferWriterTerm")
exprCodegen.generateResultExpression(
initAggBufferExprs,
aggBufferType,
classOf[BinaryRow],
emptyAggBufferTerm,
Some(emptyAggBufferWriterTerm)
)
}
/**
* Generates code that reads the input, merges the aggregate
* buffers and updates the aggregation map.
*/
private[flink] def genMergeAggBuffer(
ctx: CodeGeneratorContext,
config: TableConfig,
builder: RelBuilder,
inputRelDataType: RelDataType,
inputTerm: String,
inputType: RowType,
currentAggBufferTerm: String,
auxGrouping: Array[Int],
aggregates: Seq[UserDefinedFunction],
argsMapping: Array[Array[(Int, InternalType)]],
aggBuffMapping: Array[Array[(Int, InternalType)]],
aggBufferType: RowType): GeneratedExpression = {
val exprCodegen = new ExprCodeGenerator(ctx, false, config.getNullCheck)
.bindInput(inputType.toInternalType, inputTerm = inputTerm)
.bindSecondInput(aggBufferType.toInternalType, inputTerm = currentAggBufferTerm)
val mergeExprs = aggregates.zipWithIndex.flatMap {
case (agg: DeclarativeAggregateFunction, aggIndex) =>
val idx = auxGrouping.length + aggIndex
val bindRefOffset = inputRelDataType.getFieldCount
agg.mergeExpressions.map(
_.postOrderTransform(bindReference(
isMerge = true, bindRefOffset, agg, idx, argsMapping, aggBuffMapping)))
}.map(_.toRexNode(builder)).map(exprCodegen.generateExpression)
val aggBufferTypeWithoutAuxGrouping = if (auxGrouping.nonEmpty) {
// auxGrouping fields do not need merge code
new RowType(
aggBufferType.getFieldTypes.slice(auxGrouping.length, aggBufferType.getArity),
aggBufferType.getFieldNames.slice(auxGrouping.length, aggBufferType.getArity))
} else {
aggBufferType
}
val mergeExprIdxToOutputRowPosMap = mergeExprs.indices.map{
i => i -> (i + auxGrouping.length)
}.toMap
// update agg buff in-place
exprCodegen.generateResultExpression(
mergeExprs,
mergeExprIdxToOutputRowPosMap,
aggBufferTypeWithoutAuxGrouping,
classOf[BinaryRow],
outRow = currentAggBufferTerm,
outRowWriter = None,
reusedOutRow = true,
outRowAlreadyExists = true
)
}
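/**
* Generates the code that updates the current aggregate buffer: merges the
* incoming agg buffers in the merge case, otherwise accumulates the input.
*/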
private[flink] def genAggregate(
isMerge: Boolean,
ctx: CodeGeneratorContext,
config: TableConfig,
builder: RelBuilder,
inputRelDataType: RelDataType,
inputType: RowType,
inputTerm: String,
auxGrouping: Array[Int],
aggregates: Seq[UserDefinedFunction],
aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
argsMapping: Array[Array[(Int, InternalType)]],
aggBuffMapping: Array[Array[(Int, InternalType)]],
currentAggBufferTerm: String,
aggBufferRowType: RowType): GeneratedExpression = {
if (isMerge) {
genMergeAggBuffer(ctx, config, builder, inputRelDataType, inputTerm, inputType,
currentAggBufferTerm, auxGrouping, aggregates, argsMapping, aggBuffMapping,
aggBufferRowType)
} else {
genAccumulateAggBuffer(ctx, config, builder, inputRelDataType, inputTerm, inputType,
currentAggBufferTerm, auxGrouping, aggCallToAggFunction, argsMapping, aggBuffMapping,
aggBufferRowType)
}
}
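/**
* Generates the expression producing the output row. In the final
* aggregate, the aux grouping fields and agg values are extracted from the
* aggregate buffer; otherwise the buffer itself is forwarded. If a group
* key term is given, the output row joins the key with the result.
*/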
private[flink] def genHashAggOutputExpr(
isMerge: Boolean,
isFinal: Boolean,
ctx: CodeGeneratorContext,
config: TableConfig,
builder: RelBuilder,
inputRelDataType: RelDataType,
auxGrouping: Array[Int],
aggregates: Seq[UserDefinedFunction],
argsMapping: Array[Array[(Int, InternalType)]],
aggBuffMapping: Array[Array[(Int, InternalType)]],
outputTerm: String,
outputType: RowType,
inputTerm: String,
inputType: RowType,
groupKeyTerm: Option[String],
aggBufferTerm: String,
aggBufferType: RowType): GeneratedExpression = {
// gen code to get agg result
val exprCodegen = new ExprCodeGenerator(ctx, false, config.getNullCheck)
.bindInput(inputType.toInternalType, inputTerm = inputTerm)
.bindSecondInput(aggBufferType.toInternalType, inputTerm = aggBufferTerm)
val resultExpr = if (isFinal) {
val bindRefOffset = inputRelDataType.getFieldCount
val getAuxGroupingExprs = auxGrouping.indices.map { idx =>
val (_, resultType) = aggBuffMapping(idx)(0)
ResolvedAggInputReference("aux_group", bindRefOffset + idx, resultType)
}.map(_.toRexNode(builder)).map(exprCodegen.generateExpression)
val getAggValueExprs = aggregates.zipWithIndex.map {
case (agg: DeclarativeAggregateFunction, aggIndex) =>
val idx = auxGrouping.length + aggIndex
agg.getValueExpression.postOrderTransform(
bindReference(isMerge, bindRefOffset, agg, idx, argsMapping, aggBuffMapping))
}.map(_.toRexNode(builder)).map(exprCodegen.generateExpression)
val getValueExprs = getAuxGroupingExprs ++ getAggValueExprs
val aggValueTerm = CodeGenUtils.newName("aggVal")
val valueType = new RowType(getValueExprs.map(_.resultType): _*)
exprCodegen.generateResultExpression(
getValueExprs,
valueType,
classOf[GenericRow],
aggValueTerm)
} else {
new GeneratedExpression(aggBufferTerm, "false", "", aggBufferType)
}
// prepend the grouping keys if they exist
groupKeyTerm match {
case Some(key) =>
val output =
s"""
|${resultExpr.code}
|$outputTerm.replace($key, ${resultExpr.resultTerm});
""".stripMargin
new GeneratedExpression(outputTerm, "false", output, outputType)
case _ => resultExpr
}
}
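/**
* Generates the three core expressions of hash aggregation: the reusable
* empty agg buffer, the code that updates the buffer in-place, and the
* output expression.
*/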
private[flink] def genHashAggCodes(
isMerge: Boolean,
isFinal: Boolean,
ctx: CodeGeneratorContext,
config: TableConfig,
builder: RelBuilder,
groupingAndAuxGrouping: (Array[Int], Array[Int]),
inputRelDataType: RelDataType,
inputTerm: String,
inputType: RowType,
aggregateCalls: Seq[AggregateCall],
aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
aggregates: Seq[UserDefinedFunction],
currentAggBufferTerm: String,
aggBufferRowType: RowType,
aggBufferNames: Array[Array[String]],
aggBufferTypes: Array[Array[InternalType]],
outputTerm: String,
outputType: RowType,
groupKeyTerm: String,
aggBufferTerm: String): (GeneratedExpression, GeneratedExpression, GeneratedExpression) = {
val (grouping, auxGrouping) = groupingAndAuxGrouping
// build mappings for binding references of DeclarativeAggregateFunction
val argsMapping = buildAggregateArgsMapping(
isMerge, grouping.length, inputRelDataType, auxGrouping, aggregateCalls, aggBufferTypes)
val aggBuffMapping = buildAggregateAggBuffMapping(aggBufferTypes)
// gen code to create empty agg buffer
val initedAggBuffer = genReusableEmptyAggBuffer(
ctx, config, builder, inputTerm, inputType, auxGrouping, aggregates, aggBufferRowType)
if (auxGrouping.isEmpty) {
// create an empty agg buffer, initialize it once and reuse it
ctx.addReusableOpenStatement(initedAggBuffer.code)
}
// gen code to update agg buffer from the aggregate map
val aggregate = genAggregate(isMerge, ctx, config, builder, inputRelDataType,
inputType, inputTerm, auxGrouping, aggregates, aggCallToAggFunction,
argsMapping, aggBuffMapping, currentAggBufferTerm, aggBufferRowType)
val outputExpr = genHashAggOutputExpr(isMerge, isFinal, ctx, config, builder, inputRelDataType,
auxGrouping, aggregates, argsMapping, aggBuffMapping, outputTerm, outputType, inputTerm,
inputType, Some(groupKeyTerm), aggBufferTerm, aggBufferRowType)
(initedAggBuffer, aggregate, outputExpr)
}
// ===============================================================================================
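/**
* Generates code that resets the aggregate map and retries appending the
* initial agg buffer, translating a failed append into an OutOfMemoryError.
*/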
private[flink] def genRetryAppendToMap(
aggregateMapTerm: String,
currentKeyTerm: String,
initedAggBuffer: GeneratedExpression,
lookupInfo: String,
currentAggBufferTerm: String): String = {
s"""
| // reset the aggregate map and retry the append
|$aggregateMapTerm.reset();
|$lookupInfo = $aggregateMapTerm.lookup($currentKeyTerm);
|try {
| $currentAggBufferTerm =
| $aggregateMapTerm.append($lookupInfo, ${initedAggBuffer.resultTerm});
|} catch (java.io.EOFException e) {
| throw new OutOfMemoryError("BytesHashMap Out of Memory.");
|}
""".stripMargin
}
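/**
* Generates the handling code for a full aggregate map. A final aggregate
* sorts and spills the map's records and later falls back to sort-based
* aggregation; a non-final aggregate outputs the map's content directly.
* Returns the OOM-handling code and the fallback-to-sort-agg code.
*/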
private[flink] def genAggMapOOMHandling(
isFinal: Boolean,
ctx: CodeGeneratorContext,
config: TableConfig,
builder: RelBuilder,
groupingAndAuxGrouping: (Array[Int], Array[Int]),
inputRelDataType: RelDataType,
aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
aggregates: Seq[UserDefinedFunction],
udaggs: Map[AggregateFunction[_, _], String],
logTerm: String,
aggregateMapTerm: String,
aggMapKVTypesTerm: (String, String),
aggMapKVRowType: (RowType, RowType),
aggBufferNames: Array[Array[String]],
aggBufferTypes: Array[Array[InternalType]],
outputTerm: String,
outputType: RowType,
outputResultFromMap: String,
sorterTerm: String,
retryAppend: String): (String, String) = {
val (grouping, auxGrouping) = groupingAndAuxGrouping
if (isFinal) {
val logMapSpilling =
CodeGenUtils.genLogInfo(
logTerm, s"BytesHashMap out of memory with {} entries, start spilling.",
s"$aggregateMapTerm.getNumElements()")
// gen fallback to sort agg
val (groupKeyTypesTerm, aggBufferTypesTerm) = aggMapKVTypesTerm
val (groupKeyRowType, aggBufferRowType) = aggMapKVRowType
prepareFallbackSorter(ctx, sorterTerm)
val createSorter = genCreateFallbackSorter(
ctx, groupKeyRowType, groupKeyTypesTerm, aggBufferTypesTerm, sorterTerm)
val fallbackToSortAggCode = genFallbackToSortAgg(
ctx, config, builder, grouping, auxGrouping, inputRelDataType, aggCallToAggFunction,
aggregates, udaggs, aggregateMapTerm, (groupKeyRowType, aggBufferRowType), aggregateMapTerm,
sorterTerm, outputTerm, outputType, aggBufferNames, aggBufferTypes)
val memPoolTypeTerm = classOf[BytesHashMapSpillMemorySegmentPool].getName
val dealWithAggHashMapOOM =
s"""
|$logMapSpilling
| // hash map out of memory, spill to external sorter
|if ($sorterTerm == null) {
| $createSorter
|}
| // sort and spill
|$sorterTerm.sortAndSpill(
| $aggregateMapTerm.getRecordAreaMemorySegments(),
| $aggregateMapTerm.getNumElements(),
| new $memPoolTypeTerm($aggregateMapTerm.getBucketAreaMemorySegments()));
| // retry append
|$retryAppend
""".stripMargin
(dealWithAggHashMapOOM, fallbackToSortAggCode)
} else {
val logMapOutput =
CodeGenUtils.genLogInfo(
logTerm, s"BytesHashMap out of memory with {} entries, output directly.",
s"$aggregateMapTerm.getNumElements()")
val dealWithAggHashMapOOM =
s"""
|$logMapOutput
| // hash map out of memory, output directly
|$outputResultFromMap
| // retry append
|$retryAppend
""".stripMargin
(dealWithAggHashMapOOM, "")
}
}
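/**
* Declares the reusable BufferedKVExternalSorter member used for the
* sort-agg fallback and registers its close statement.
*/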
private[flink] def prepareFallbackSorter(ctx: CodeGeneratorContext, sorterTerm: String): Unit = {
val sorterTypeTerm = classOf[BufferedKVExternalSorter].getName
ctx.addReusableMember(s"transient $sorterTypeTerm $sorterTerm;")
ctx.addReusableCloseStatement(s"if ($sorterTerm != null) $sorterTerm.close();")
}
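/**
* Registers gauges for the number of spill files and the used memory size,
* and, when a fallback sorter exists, for the number of spilled bytes.
*/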
private[flink] def prepareMetrics(
ctx: CodeGeneratorContext, hashTerm: String, sorterTerm: String): Unit = {
val gauge = classOf[Gauge[_]].getCanonicalName
val longType = classOf[java.lang.Long].getCanonicalName
val numSpillFiles =
s"""
|getMetricGroup().gauge("numSpillFiles", new $gauge<$longType>() {
| @Override
| public $longType getValue() {
| return $hashTerm.getNumSpillFiles();
| }
| });
""".stripMargin.trim
val memoryUsedSizeInBytes =
s"""
|getMetricGroup().gauge("memoryUsedSizeInBytes", new $gauge<$longType>() {
| @Override
| public $longType getValue() {
| return $hashTerm.getUsedMemoryInBytes();
| }
| });
""".stripMargin.trim
ctx.addReusableOpenStatement(numSpillFiles)
ctx.addReusableOpenStatement(memoryUsedSizeInBytes)
if (sorterTerm != null) {
val spillInBytes =
s"""
| getMetricGroup().gauge("spillInBytes", new $gauge<$longType>() {
| @Override
| public $longType getValue() {
| return $hashTerm.getSpillInBytes();
| }
|});
""".stripMargin.trim
ctx.addReusableOpenStatement(spillInBytes)
}
}
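/**
* Generates code that creates the fallback BufferedKVExternalSorter with
* the generated key computer and record comparator.
*/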
private[flink] def genCreateFallbackSorter(
ctx: CodeGeneratorContext,
groupKeyRowType: RowType,
groupKeyTypesTerm: String,
aggBufferTypesTerm: String,
sorterTerm: String): String = {
val keyComputerTerm = CodeGenUtils.newName("keyComputer")
val recordComparatorTerm = CodeGenUtils.newName("recordComparator")
val prepareSorterCode = genKVSorterPrepareCode(
ctx, keyComputerTerm, recordComparatorTerm, groupKeyRowType)
val binaryRowSerializerTypeTerm = classOf[BinaryRowSerializer].getName
val sorterTypeTerm = classOf[BufferedKVExternalSorter].getName
s"""
| $prepareSorterCode
| $sorterTerm = new $sorterTypeTerm(
| getContainingTask().getEnvironment().getIOManager(),
| new $binaryRowSerializerTypeTerm($groupKeyTypesTerm),
| new $binaryRowSerializerTypeTerm($aggBufferTypesTerm),
| $keyComputerTerm, $recordComparatorTerm,
| getContainingTask().getEnvironment().getMemoryManager().getPageSize(),
| getSqlConf()
| );
""".stripMargin
}
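/**
* Generates the sort-agg fallback: iterates over the sorted key/value
* pairs, merges the spilled agg buffers of each key group and emits one
* result row per group.
*/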
private[flink] def genFallbackToSortAgg(
ctx: CodeGeneratorContext,
config: TableConfig,
builder: RelBuilder,
grouping: Array[Int],
auxGrouping: Array[Int],
inputRelDataType: RelDataType,
aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
aggregates: Seq[UserDefinedFunction],
udaggs: Map[AggregateFunction[_, _], String],
mapTerm: String,
mapKVRowTypes: (RowType, RowType),
aggregateMapTerm: String,
sorterTerm: String,
outputTerm: String,
outputType: RowType,
aggBufferNames: Array[Array[String]],
aggBufferTypes: Array[Array[InternalType]]): String = {
val (groupKeyRowType, aggBufferRowType) = mapKVRowTypes
val keyTerm = CodeGenUtils.newName("key")
val lastKeyTerm = CodeGenUtils.newName("lastKey")
val keyNotEquals = genGroupKeyChangedCheckCode(keyTerm, lastKeyTerm)
val joinedRow = classOf[JoinedRow].getName
val fallbackInputTerm = ctx.newReusableField("fallbackInput", joinedRow)
val fallbackInputType = new RowType(
groupKeyRowType.getFieldTypes ++ aggBufferRowType.getFieldTypes,
groupKeyRowType.getFieldNames ++ aggBufferRowType.getFieldNames)
val (initAggBufferCode, updateAggBufferCode, resultExpr) = genSortAggCodes(
isMerge = true, isFinal = true, ctx, config, builder, grouping, auxGrouping, inputRelDataType,
aggCallToAggFunction, aggregates, udaggs, fallbackInputTerm, fallbackInputType,
aggBufferNames, aggBufferTypes, outputType, forHashAgg = true)
val kvPairTerm = CodeGenUtils.newName("kvPair")
val kvPairTypeTerm = classOf[JTuple2[BinaryRow, BinaryRow]].getName
val aggBuffTerm = CodeGenUtils.newName("val")
val binaryRow = classOf[BinaryRow].getName
s"""
| $binaryRow $lastKeyTerm = null;
| $kvPairTypeTerm<$binaryRow, $binaryRow> $kvPairTerm = null;
| $binaryRow $keyTerm = null;
| $binaryRow $aggBuffTerm = null;
| $fallbackInputTerm = new $joinedRow();
|
| // free the hash map memory, but do not release it back to the memory manager
|
| org.apache.flink.util.MutableObjectIterator<$kvPairTypeTerm<$binaryRow, $binaryRow>>
| iterator = $sorterTerm.getKVIterator();
|
| while (
| ($kvPairTerm = ($kvPairTypeTerm<$binaryRow, $binaryRow>) iterator.next()) != null) {
| $keyTerm = ($binaryRow) $kvPairTerm.f0;
| $aggBuffTerm = ($binaryRow) $kvPairTerm.f1;
| // prepare input
| $fallbackInputTerm.replace($keyTerm, $aggBuffTerm);
| if ($lastKeyTerm == null) {
| // found first key group
| $lastKeyTerm = $keyTerm.copy();
| $initAggBufferCode
| } else if ($keyNotEquals) {
| // output current group aggregate result
| ${resultExpr.code}
| $outputTerm.replace($lastKeyTerm, ${resultExpr.resultTerm});
| ${OperatorCodeGenerator.generatorCollect(outputTerm)}
| // found new group
| $lastKeyTerm = $keyTerm.copy();
| $initAggBufferCode
| }
| // reusable field access codes for agg buffer merge
| ${ctx.reuseInputUnboxingCode(Set(fallbackInputTerm))}
| // merge aggregate map's value into aggregate buffer fields
| $updateAggBufferCode
| }
|
| // output last key group aggregate result
| ${resultExpr.code}
| $outputTerm.replace($lastKeyTerm, ${resultExpr.resultTerm});
| ${OperatorCodeGenerator.generatorCollect(outputTerm)}
""".stripMargin
}
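/**
* Generates code that instantiates and initializes the generated
* NormalizedKeyComputer and RecordComparator used by the fallback sorter,
* sorting on all key fields in ascending order.
*/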
private[flink] def genKVSorterPrepareCode(
ctx: CodeGeneratorContext,
keyComputerTerm: String,
recordComparatorTerm: String,
aggMapKeyType: RowType): String = {
val keyFieldTypes = aggMapKeyType.getFieldInternalTypes
val keys = keyFieldTypes.indices.toArray
val orders = keys.map(_ => true)
val nullsIsLast = SortUtil.getNullDefaultOrders(orders)
val (comparators, serializers) = TypeUtils.flattenComparatorAndSerializer(
keyFieldTypes.length, keys, orders, keyFieldTypes)
val sortCodeGenerator = new SortCodeGenerator(
keys, keyFieldTypes, comparators, orders, nullsIsLast)
val genedSorter = GeneratedSorter(
sortCodeGenerator.generateNormalizedKeyComputer("AggMapKeyComputer"),
sortCodeGenerator.generateRecordComparator("AggMapValueComparator"),
serializers,
comparators)
val keyComputerTypeTerm = classOf[NormalizedKeyComputer].getName
val keyComputeInnerClassTerm = genedSorter.computer.name
val recordComparatorTypeTerm = classOf[RecordComparator].getName
val recordComparatorInnerClassTerm = genedSorter.comparator.name
ctx.addReusableInnerClass(keyComputeInnerClassTerm, genedSorter.computer.code)
ctx.addReusableInnerClass(recordComparatorInnerClassTerm, genedSorter.comparator.code)
val serArrayTerm = s"${classOf[TypeSerializer[_]].getCanonicalName}[]"
val compArrayTerm = s"${classOf[TypeComparator[_]].getCanonicalName}[]"
val serializersTerm = ctx.addReferenceObj(serializers, serArrayTerm)
val comparatorsTerm = ctx.addReferenceObj(comparators, compArrayTerm)
s"""
| $keyComputerTypeTerm $keyComputerTerm = new $keyComputeInnerClassTerm();
| $recordComparatorTypeTerm $recordComparatorTerm = new $recordComparatorInnerClassTerm();
| $keyComputerTerm.init($serializersTerm, $comparatorsTerm);
| $recordComparatorTerm.init($serializersTerm, $comparatorsTerm);
|""".stripMargin
}
}