blob: 13266f414abde2ab2f783565bbab0abd406985b9 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.codegen.agg
import org.apache.calcite.tools.RelBuilder
import org.apache.flink.table.api.functions.DeclarativeAggregateFunction
import org.apache.flink.table.api.types.{DataType, DataTypes, InternalType}
import org.apache.flink.table.codegen.CodeGenUtils.primitiveTypeTermForType
import org.apache.flink.table.codegen.{CodeGeneratorContext, ExprCodeGenerator, GeneratedExpression}
import org.apache.flink.table.codegen.agg.AggsHandlerCodeGenerator._
import org.apache.flink.table.expressions._
import org.apache.flink.table.plan.util.AggregateInfo
/**
* Generates code for aggregate functions that are specified using expressions.
* The aggregate buffer is embedded inside of a larger shared aggregation buffer.
*
* @param ctx the code gen context
* @param aggInfo the aggregate information
* @param filterExpression filter argument access expression, none if no filter
* @param mergedAccOffset the mergedAcc may come from local aggregate,
* this is the first buffer offset in the row
* @param aggBufferOffset the offset in the buffers of this aggregate
* @param aggBufferSize the total size of aggregate buffers
* @param inputTypes the input field type infos
* @param constantExprs the constant expressions
* @param relBuilder the rel builder to translate expressions to calcite rex nodes
*/
class DeclarativeAggCodeGen(
    ctx: CodeGeneratorContext,
    aggInfo: AggregateInfo,
    filterExpression: Option[Expression],
    mergedAccOffset: Int,
    aggBufferOffset: Int,
    aggBufferSize: Int,
    inputTypes: Seq[InternalType],
    constantExprs: Seq[GeneratedExpression],
    relBuilder: RelBuilder)
  extends AggCodeGen {

  /** The aggregate function of `aggInfo`, viewed as a [[DeclarativeAggregateFunction]]. */
  val function: DeclarativeAggregateFunction =
    aggInfo.function.asInstanceOf[DeclarativeAggregateFunction]

  /** External accumulator types of this aggregate, one per buffer attribute. */
  val bufferTypes: Array[DataType] = aggInfo.externalAccTypes

  /** Consecutive field indexes of this aggregate's buffers, starting at `aggBufferOffset`. */
  val bufferIndexes: Array[Int] = Array.range(
    aggBufferOffset, aggBufferOffset + bufferTypes.length)

  /** Names of the generated member variables that hold the buffer values. */
  val bufferTerms: Array[String] = function.aggBufferAttributes
    .map(a => s"agg${aggInfo.aggIndex}_${a.name}").toArray

  /** Names of the generated member variables that hold the buffer null flags. */
  val bufferNullTerms: Array[String] = bufferTerms.map(_ + "_isNull")

  /** Argument indexes of this aggregate; indexes past the input fields address constants. */
  val argIndexes: Array[Int] = aggInfo.argIndexes

  /** Types of the arguments, looked up in the input field types followed by the constant types. */
  val argTypes: Array[InternalType] = {
    val types = inputTypes ++ constantExprs.map(_.resultType)
    argIndexes.map(types(_))
  }

  /** Generates one expression per initial value expression of the function. */
  def createAccumulator(generator: ExprCodeGenerator): Seq[GeneratedExpression] = {
    function.initialValuesExpressions
      .map(expr => generator.generateExpression(expr.toRexNode(relBuilder)))
  }

  /**
    * Generates code that copies every buffer field of this aggregate out of the input into
    * the buffer member variables. The member variables and their null flags are registered
    * on the code generator context as reusable members.
    */
  def setAccumulator(generator: ExprCodeGenerator): String = {
    val aggBufferAccesses = function.aggBufferAttributes.zipWithIndex
      .map { case (attr, index) =>
        ResolvedAggInputReference(
          attr.name, bufferIndexes(index), bufferTypes(index).toInternalType)
      }
      .map(expr => generator.generateExpression(expr.toRexNode(relBuilder)))

    val setters = aggBufferAccesses.zipWithIndex.map {
      case (access, index) =>
        val typeTerm = primitiveTypeTermForType(access.resultType)
        val memberName = bufferTerms(index)
        val memberNullTerm = bufferNullTerms(index)
        ctx.addReusableMember(s"private $typeTerm $memberName;")
        ctx.addReusableMember(s"private boolean $memberNullTerm;")
        s"""
           |${access.copyResultTermToTargetIfChanged(ctx, memberName)};
           |$memberNullTerm = ${access.nullTerm};
           |""".stripMargin
    }

    setters.mkString("\n")
  }

  /** Generates code that re-assigns every buffer member from the initial value expressions. */
  override def resetAccumulator(generator: ExprCodeGenerator): String = {
    val initialExprs = function.initialValuesExpressions
      .map(expr => generator.generateExpression(expr.toRexNode(relBuilder)))

    val codes = initialExprs.zipWithIndex.map {
      case (init, index) =>
        val memberName = bufferTerms(index)
        val memberNullTerm = bufferNullTerms(index)
        s"""
           |${init.code}
           |$memberName = ${init.resultTerm};
           |$memberNullTerm = ${init.nullTerm};
           |""".stripMargin
    }

    codes.mkString("\n")
  }

  /** Exposes the buffer member variables as generated expressions that carry no code. */
  def getAccumulator(generator: ExprCodeGenerator): Seq[GeneratedExpression] = {
    bufferTypes.zipWithIndex.map { case (bufferType, index) =>
      GeneratedExpression(
        bufferTerms(index), bufferNullTerms(index), "", bufferType.toInternalType)
    }
  }

  /**
    * Generates code that applies the function's accumulate expressions to the buffer
    * members, wrapped in an `if` on the filter when a filter expression is given.
    */
  def accumulate(generator: ExprCodeGenerator): String =
    genGuardedBufferAssignments(generator, function.accumulateExpressions)

  /**
    * Generates code that applies the function's retract expressions to the buffer
    * members, wrapped in an `if` on the filter when a filter expression is given.
    */
  def retract(generator: ExprCodeGenerator): String =
    genGuardedBufferAssignments(generator, function.retractExpressions)

  /**
    * Generates code that applies the function's merge expressions to the buffer members.
    * Input references are resolved against the merged accumulator (see `resolveReference`).
    */
  def merge(generator: ExprCodeGenerator): String = {
    val resolvedExprs = function.mergeExpressions
      .map(_.postOrderTransform(resolveReference(isMerge = true)))
    genBufferAssignments(generator, resolvedExprs).mkString("\n")
  }

  /** Generates the expression computing the aggregate's result from the buffer members. */
  def getValue(generator: ExprCodeGenerator): GeneratedExpression = {
    val resolvedGetValueExpression = function.getValueExpression
      .postOrderTransform(resolveReference())
    generator.generateExpression(resolvedGetValueExpression.toRexNode(relBuilder))
  }

  /**
    * Shared implementation of `accumulate` and `retract`: resolves the references in the
    * given expressions, generates the buffer assignments, and guards them with the filter
    * expression if there is one.
    *
    * @param generator   the expression code generator to use
    * @param expressions the unresolved update expressions, one per buffer attribute
    */
  private def genGuardedBufferAssignments(
      generator: ExprCodeGenerator,
      expressions: Seq[Expression]): String = {
    // an input term starting with DISTINCT_KEY_TERM means we are called from distinct merge
    val isDistinctMerge = generator.input1Term.startsWith(DISTINCT_KEY_TERM)
    val resolvedExprs = expressions
      .map(_.postOrderTransform(resolveReference(isDistinctMerge = isDistinctMerge)))

    // the assignments are generated before the filter, keeping the generator call order stable
    val codes = genBufferAssignments(generator, resolvedExprs)

    filterExpression match {
      case Some(expr) =>
        val generated = generator.generateExpression(expr.toRexNode(relBuilder))
        s"""
           |if (${generated.resultTerm}) {
           | ${codes.mkString("\n")}
           |}
           |""".stripMargin
      case None =>
        codes.mkString("\n")
    }
  }

  /**
    * Generates, for the i-th resolved expression, code that evaluates it and stores its
    * result and null flag in the i-th buffer member variables.
    */
  private def genBufferAssignments(
      generator: ExprCodeGenerator,
      resolvedExprs: Seq[Expression]): Seq[String] = {
    val exprs = resolvedExprs
      .map(_.toRexNode(relBuilder)) // rex nodes
      .map(generator.generateExpression) // generated expressions

    exprs.zipWithIndex.map { case (expr, index) =>
      s"""
         |${expr.code}
         |${expr.copyResultTermToTargetIfChanged(ctx, bufferTerms(index))};
         |${bufferNullTerms(index)} = ${expr.nullTerm};
         |""".stripMargin
    }
  }

  /**
    * Returns `index` if it is a valid lookup result, otherwise fails with a descriptive
    * message instead of letting a later array access fail on index -1.
    */
  private def requireResolved(index: Int, description: => String): Int = {
    require(
      index >= 0,
      s"Cannot resolve $description of aggregate function ${function.getClass.getName}.")
    index
  }

  /**
    * Builds the rewrite rule that replaces unresolved references in the function's
    * expressions with resolved ones:
    *
    *  - an [[UnresolvedFieldReference]] becomes a reference into the merged accumulator
    *    (merge), a reference to a reused constant, a reference to the distinct key
    *    (distinct merge), or a reference to a field of the input row;
    *  - an [[UnresolvedAggBufferReference]] becomes a reference to the buffer member variable.
    *
    * Fails with an `IllegalArgumentException` if a reference is not declared by the function.
    *
    * @param isMerge this is called from merge() method
    * @param isDistinctMerge this is called from distinct merge method
    */
  private def resolveReference(isMerge: Boolean = false, isDistinctMerge: Boolean = false)
    : PartialFunction[Expression, Expression] = {
    case input: UnresolvedFieldReference =>
      // We always use UnresolvedFieldReference to represent reference of input field.
      // In non-merge case, the input is the operand of the aggregate function.
      // In merge case, the input is the aggregate buffers sent by local aggregate.
      if (isMerge) {
        val localIndex = requireResolved(
          function.inputAggBufferAttributes.indexOf(input),
          s"input aggregate buffer attribute '${input.name}'")
        // in merge case, the input1 is mergedAcc
        ResolvedAggInputReference(
          input.name,
          mergedAccOffset + bufferIndexes(localIndex),
          bufferTypes(localIndex).toInternalType)
      } else {
        val localIndex = requireResolved(
          function.operands.indexOf(input),
          s"operand '${input.name}'")
        val inputIndex = argIndexes(localIndex)
        // index to constant
        if (inputIndex >= inputTypes.length) {
          val constant = constantExprs(inputIndex - inputTypes.length)
          // constant is reused as member variable
          ResolvedAggLocalReference(constant.resultTerm, constant.nullTerm, constant.resultType)
        } else if (isDistinctMerge) { // this is called from distinct merge
          if (function.inputCount == 1) {
            // the distinct key is a BoxedValue
            ResolvedDistinctKeyReference(input.name, argTypes(0))
          } else {
            // the distinct key is a BaseRow
            ResolvedAggInputReference(input.name, localIndex, argTypes(localIndex))
          }
        } else {
          // the input is the inputRow
          ResolvedAggInputReference(input.name, inputIndex, argTypes(localIndex))
        }
      }

    case buffer: UnresolvedAggBufferReference =>
      val localIndex = requireResolved(
        function.aggBufferAttributes.indexOf(buffer),
        s"aggregate buffer attribute '$buffer'")
      val name = bufferTerms(localIndex)
      val nullTerm = bufferNullTerms(localIndex)
      // buffer access is reused as member variable
      ResolvedAggLocalReference(name, nullTerm, bufferTypes(localIndex).toInternalType)
  }

  /** No-op: the needed-method check is skipped for declarative aggregate functions for now. */
  def checkNeededMethods(
      needAccumulate: Boolean = false,
      needRetract: Boolean = false,
      needMerge: Boolean = false,
      needReset: Boolean = false): Unit = {
    // skip the check for DeclarativeAggregateFunction for now
  }
}