| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.flink.table.codegen.agg |
| |
| import java.lang.reflect.ParameterizedType |
| import java.lang.{Iterable => JIterable} |
| |
| import org.apache.calcite.tools.RelBuilder |
| import org.apache.flink.runtime.util.SingleElementIterator |
| import org.apache.flink.table.api.functions.AggregateFunction |
| import org.apache.flink.table.api.types.{RowType, DataType, DataTypes, InternalType} |
| import org.apache.flink.table.codegen.CodeGenUtils._ |
| import org.apache.flink.table.codegen.{CodeGenException, CodeGeneratorContext, ExprCodeGenerator, GeneratedExpression} |
| import org.apache.flink.table.codegen.agg.AggsHandlerCodeGenerator._ |
| import org.apache.flink.table.expressions.ResolvedDistinctKeyReference |
| import org.apache.flink.table.dataformat.{GenericRow, UpdatableRow} |
| import org.apache.flink.table.dataview.DataViewSpec |
| import org.apache.flink.table.expressions.{Expression, ResolvedAggInputReference} |
| import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._ |
| import org.apache.flink.table.plan.util.AggregateInfo |
| import org.apache.flink.table.runtime.conversion.DataStructureConverters.{genToExternal, genToInternal} |
| import org.apache.flink.table.typeutils.TypeUtils |
| |
| import scala.collection.mutable.ArrayBuffer |
| |
/**
  * Generates code for aggregate functions that are specified in terms of
  * accumulate(), retract() and merge() methods. The aggregate accumulator is
  * embedded inside of a larger shared aggregation buffer.
  *
  * @param ctx the code gen context
  * @param aggInfo the aggregate information
  * @param filterExpression filter argument access expression, None if there is no filter
  * @param mergedAccOffset the merged accumulator may come from a local aggregate;
  *                        this is the offset of its first buffer field in the row
  * @param aggBufferOffset the offset in the buffers of this aggregate
  * @param aggBufferSize the total size of aggregate buffers
  * @param inputTypes the input field types
  * @param constantExprs the constant expressions
  * @param relBuilder the rel builder to translate expressions to calcite rex nodes
  * @param hasNamespace whether the accumulator state has a namespace
  * @param mergedAccOnHeap whether the merged accumulator is kept on heap; when true,
  *                        state-backed data views are not used for it
  * @param mergedAccExternalType the external data type of the merged accumulator, used to
  *                              convert it to its external form in merge()
  * @param inputFieldCopy copy input field element if true (only mutable types will be copied)
  */
class ImperativeAggCodeGen(
    ctx: CodeGeneratorContext,
    aggInfo: AggregateInfo,
    filterExpression: Option[Expression],
    mergedAccOffset: Int,
    aggBufferOffset: Int,
    aggBufferSize: Int,
    inputTypes: Seq[InternalType],
    constantExprs: Seq[GeneratedExpression],
    relBuilder: RelBuilder,
    hasNamespace: Boolean,
    mergedAccOnHeap: Boolean,
    mergedAccExternalType: DataType,
    inputFieldCopy: Boolean)
  extends AggCodeGen {

  private val SINGLE_ITERABLE = className[SingleElementIterator[_]]
  private val UPDATABLE_ROW = className[UpdatableRow]

  val function: AggregateFunction[_, _] = aggInfo.function.asInstanceOf[AggregateFunction[_, _]]
  val functionTerm: String = ctx.addReusableFunction(
    function,
    constructorTerm = s"$CONTEXT_TERM.getRuntimeContext()")
  val aggIndex: Int = aggInfo.aggIndex

  val externalAccType = aggInfo.externalAccTypes(0)

  /**
    * Whether the accumulator type is an internal type.
    * Currently only GenericRow (an accumulator declared as a RowType) is supported
    * as internal acc type.
    */
  val isAccTypeInternal: Boolean = externalAccType match {
    case _: RowType => true
    case _ => false
  }

  /** Name of the generated member holding the accumulator in internal form. */
  val accInternalTerm: String = s"agg${aggIndex}_acc_internal"
  /** Name of the generated member holding the accumulator in external form. */
  val accExternalTerm: String = s"agg${aggIndex}_acc_external"
  val accTypeInternalTerm: String = if (isAccTypeInternal) {
    GENERIC_ROW
  } else {
    boxedTypeTermForType(externalAccType.toInternalType)
  }
  val accTypeExternalTerm: String = externalBoxedTermForType(externalAccType)

  /** Types of the aggregate arguments; indexes past the input fields refer to constants. */
  val argTypes: Array[InternalType] = {
    val types = inputTypes ++ constantExprs.map(_.resultType)
    aggInfo.argIndexes.map(types(_))
  }

  val resultType: DataType = aggInfo.externalResultType

  val viewSpecs: Array[DataViewSpec] = aggInfo.viewSpecs
  // add reusable dataviews to context
  addReusableStateDataViews(ctx, viewSpecs, hasNamespace, !mergedAccOnHeap)

  /**
    * Generates code that creates a fresh accumulator and stores it, in internal form,
    * into a new local variable. Data views are intentionally not set into the accumulator here.
    */
  def createAccumulator(generator: ExprCodeGenerator): Seq[GeneratedExpression] = {
    val createCall = s"$functionTerm.createAccumulator()"
    val accField = if (isAccTypeInternal) {
      // the accumulator is already internal, no conversion needed
      createCall
    } else {
      genToInternal(ctx, externalAccType, createCall)
    }
    val accInternal = newName("acc_internal")
    val code = s"$accTypeInternalTerm $accInternal = $accField;"
    Seq(GeneratedExpression(accInternal, "false", code, externalAccType.toInternalType))
  }

  /**
    * Generates code that loads this aggregate's accumulator from the input buffer row into
    * the accumulator member(s), also refreshing the external form when it is not internal.
    */
  def setAccumulator(generator: ExprCodeGenerator): String = {
    // generate internal acc field
    val expr = generateAccumulatorAccess(
      ctx,
      generator.input1Type,
      generator.input1Term,
      aggBufferOffset,
      viewSpecs,
      useStateDataView = true,
      useBackupDataView = false,
      nullCheck = true)

    ctx.addReusableMember(s"private $accTypeInternalTerm $accInternalTerm;")
    if (isAccTypeInternal) {
      s"""
         |$accInternalTerm = ${expr.resultTerm};
       """.stripMargin
    } else {
      ctx.addReusableMember(s"private $accTypeExternalTerm $accExternalTerm;")
      s"""
         |$accInternalTerm = ${expr.resultTerm};
         |$accExternalTerm = ${genToExternal(ctx, externalAccType, accInternalTerm)};
       """.stripMargin
    }
  }

  /** Generates code that replaces the accumulator member(s) with a freshly created accumulator. */
  override def resetAccumulator(generator: ExprCodeGenerator): String = {
    if (isAccTypeInternal) {
      s"$accInternalTerm = $functionTerm.createAccumulator();"
    } else {
      s"""
         |$accExternalTerm = $functionTerm.createAccumulator();
         |$accInternalTerm = ${genToInternal(ctx, externalAccType, accExternalTerm)};
       """.stripMargin
    }
  }

  /**
    * Returns the accumulator in internal form. When the accumulator is external, the code
    * first converts the external member back into the internal member.
    */
  def getAccumulator(generator: ExprCodeGenerator): Seq[GeneratedExpression] = {
    val code = if (isAccTypeInternal) {
      // do not need convert to internal type
      ""
    } else {
      s"$accInternalTerm = ${genToInternal(ctx, externalAccType, accExternalTerm)};"
    }
    Seq(GeneratedExpression(accInternalTerm, "false", code, externalAccType.toInternalType))
  }

  /** Generates the accumulate() call, guarded by the filter expression if there is one. */
  def accumulate(generator: ExprCodeGenerator): String =
    generateFilteredFunctionCall(generator, "accumulate")

  /** Generates the retract() call, guarded by the filter expression if there is one. */
  def retract(generator: ExprCodeGenerator): String =
    generateFilteredFunctionCall(generator, "retract")

  /**
    * Generates `functionTerm.<methodName>(acc, args...)` together with the code that computes
    * the arguments. If a filter expression is present, both are wrapped in an `if` on it.
    */
  private def generateFilteredFunctionCall(
      generator: ExprCodeGenerator,
      methodName: String): String = {
    val (parameters, code) = aggParametersCode(generator)
    val call = s"$functionTerm.$methodName($parameters);"
    filterExpression match {
      case None =>
        s"""
           |$code
           |$call
         """.stripMargin
      case Some(expr) =>
        val generated = generator.generateExpression(expr.toRexNode(relBuilder))
        s"""
           |if (${generated.resultTerm}) {
           |  $code
           |  $call
           |}
         """.stripMargin
    }
  }

  /**
    * Generates code that merges the accumulator found at `mergedAccOffset + aggBufferOffset` of
    * the input row into this aggregate's accumulator via a single-element iterable.
    */
  def merge(generator: ExprCodeGenerator): String = {
    val accIterTerm = s"agg${aggIndex}_acc_iter"
    ctx.addReusableMember(s"private final $SINGLE_ITERABLE $accIterTerm = new $SINGLE_ITERABLE();")

    // generate internal acc field
    val expr = generateAccumulatorAccess(
      ctx,
      generator.input1Type,
      generator.input1Term,
      mergedAccOffset + aggBufferOffset,
      viewSpecs,
      useStateDataView = !mergedAccOnHeap,
      useBackupDataView = true,
      nullCheck = true)

    if (isAccTypeInternal) {
      s"""
         |$accIterTerm.set(${expr.resultTerm});
         |$functionTerm.merge($accInternalTerm, $accIterTerm);
       """.stripMargin
    } else {
      val otherAccExternal = newName("other_acc_external")
      s"""
         |$accTypeExternalTerm $otherAccExternal = ${
            genToExternal(ctx, mergedAccExternalType, expr.resultTerm)};
         |$accIterTerm.set($otherAccExternal);
         |$functionTerm.merge($accExternalTerm, $accIterTerm);
       """.stripMargin
    }
  }

  /**
    * Generates code that calls getValue() on the accumulator and converts the external result
    * into its internal form; the result is flagged as null when the converted value is null.
    */
  def getValue(generator: ExprCodeGenerator): GeneratedExpression = {
    val valueExternalTerm = newName("value_external")
    val valueExternalTypeTerm = externalBoxedTermForType(resultType)
    val valueInternalTerm = newName("value_internal")
    val valueInternalTypeTerm = boxedTypeTermForType(resultType.toInternalType)
    val nullTerm = newName("valueIsNull")
    val accTerm = if (isAccTypeInternal) accInternalTerm else accExternalTerm
    val code =
      s"""
         |$valueExternalTypeTerm $valueExternalTerm = ($valueExternalTypeTerm)
         |  $functionTerm.getValue($accTerm);
         |$valueInternalTypeTerm $valueInternalTerm =
         |  ${genToInternal(ctx, resultType, valueExternalTerm)};
         |boolean $nullTerm = $valueInternalTerm == null;
       """.stripMargin

    GeneratedExpression(valueInternalTerm, nullTerm, code, resultType.toInternalType)
  }

  /**
    * Builds the argument list `acc, arg1, arg2, ...` for accumulate()/retract() and the code
    * needed to compute the input-field arguments.
    *
    * @return (comma separated parameter terms, code computing the input-field arguments)
    */
  private def aggParametersCode(generator: ExprCodeGenerator): (String, String) = {
    val externalUDITypes = getAggUserDefinedInputTypes(
      function,
      externalAccType,
      argTypes)
    val codes = ArrayBuffer.empty[String]
    val inputFields = aggInfo.argIndexes.zipWithIndex.map { case (f, index) =>
      if (f >= inputTypes.length) {
        // index to constant
        val expr = constantExprs(f - inputTypes.length)
        s"${expr.nullTerm} ? null : ${
          genToExternal(ctx, externalUDITypes(index), expr.resultTerm)}"
      } else {
        // index to input field
        val inputRef = if (generator.input1Term.startsWith(DISTINCT_KEY_TERM)) {
          if (argTypes.length == 1) {
            // called from distinct merge and the inputTerm is the only argument
            ResolvedDistinctKeyReference(generator.input1Term, inputTypes(f))
          } else {
            // called from distinct merge call and the inputTerm is BaseRow type
            ResolvedAggInputReference(f.toString, index, inputTypes(f))
          }
        } else {
          // called from accumulate
          ResolvedAggInputReference(f.toString, f, inputTypes(f))
        }
        val inputExpr = generator
          .generateExpression(inputRef.toRexNode(relBuilder))
          // TODO: the copy result is not reused (i.e. MAX(a), MIN(a))
          .copyResultIfNeeded(ctx, inputFieldCopy)
        codes += inputExpr.code
        val externalTerm =
          s"${genToExternal(ctx, externalUDITypes(index), inputExpr.resultTerm)}"
        // TODO: we have to keep this currently, because copy maybe not enabled
        // TODO: maybe we want to provide an CopyOption including (none, deepCopy, shallowCopy)
        // TODO: deepCopy means copy result, shallowCopy means only copy the reference
        val term = if (needCloneRefForDataType(externalUDITypes(index))) {
          s"$externalTerm.cloneReference()"
        } else {
          externalTerm
        }
        s"${inputExpr.nullTerm} ? null : $term"
      }
    }

    val accTerm = if (isAccTypeInternal) accInternalTerm else accExternalTerm
    // insert acc to the head of the list
    val fields = Seq(accTerm) ++ inputFields
    // acc, arg1, arg2
    (fields.mkString(", "), codes.mkString("\n"))
  }

  /**
    * This method is mainly the same as CodeGenUtils.generateFieldAccess(), the only difference is
    * that this method uses UpdatableRow to wrap BaseRow to handle DataViews.
    *
    * The access code is registered as reusable input unboxing code for (inputTerm, index), so
    * the returned expression carries no code of its own.
    *
    * @param ctx the code gen context
    * @param inputType the type of the input row
    * @param inputTerm the term of the input row
    * @param index the field index of the accumulator in the input row
    * @param viewSpecs the data view specs of the accumulator
    * @param useStateDataView whether data views should be set into the accumulator
    * @param useBackupDataView whether the backup data view terms are used (namespace case only)
    * @param nullCheck whether null checking is enabled
    */
  def generateAccumulatorAccess(
      ctx: CodeGeneratorContext,
      inputType: InternalType,
      inputTerm: String,
      index: Int,
      viewSpecs: Array[DataViewSpec],
      useStateDataView: Boolean,
      useBackupDataView: Boolean,
      nullCheck: Boolean): GeneratedExpression = {

    // if input has been used before, we can reuse the code that
    // has already been generated
    val inputExpr = ctx.getReusableInputUnboxingExprs(inputTerm, index) match {
      // input access and unboxing has already been generated
      case Some(expr) => expr

      // generate input access and unboxing if necessary
      case None =>
        // this field access is not need to reuse
        val expr = generateFieldAccess(ctx, inputType, inputTerm, index, nullCheck)

        val newExpr = inputType match {
          case ct: RowType if isAccTypeInternal =>
            // acc is never be null
            val fieldType = ct.getInternalTypeAt(index).asInstanceOf[RowType]
            val exprGenerator = new ExprCodeGenerator(ctx, false, nullCheck)
              .bindInput(fieldType, inputTerm = expr.resultTerm)
            val converted = exprGenerator.generateConverterResultExpression(
              fieldType,
              classOf[GenericRow],
              outRecordTerm = newName("acc"),
              reusedOutRow = false,
              fieldCopy = inputFieldCopy)
            val code =
              s"""
                 |${expr.code}
                 |${ctx.reuseInputUnboxingCode(Set(expr.resultTerm))}
                 |${converted.code}
               """.stripMargin

            GeneratedExpression(
              converted.resultTerm,
              converted.nullTerm,
              code,
              converted.resultType)
          case _ => expr
        }

        val exprWithDataView = inputType match {
          case ct: RowType if viewSpecs.nonEmpty && useStateDataView =>
            if (isAccTypeInternal) {
              val code =
                s"""
                   |${newExpr.code}
                   |${generateDataViewFieldSetter(newExpr.resultTerm, viewSpecs, useBackupDataView)}
                 """.stripMargin
              GeneratedExpression(newExpr.resultTerm, newExpr.nullTerm, code, newExpr.resultType)
            } else {
              val fieldType = ct.getInternalTypeAt(index)
              val fieldTerm = ctx.newReusableField("field", UPDATABLE_ROW)
              val code =
                s"""
                   |${newExpr.code}
                   |$fieldTerm = null;
                   |if (!${newExpr.nullTerm}) {
                   |  $fieldTerm = new $UPDATABLE_ROW(${newExpr.resultTerm}, ${
                        TypeUtils.getArity(fieldType)});
                   |  ${generateDataViewFieldSetter(fieldTerm, viewSpecs, useBackupDataView)}
                   |}
                 """.stripMargin
              GeneratedExpression(fieldTerm, newExpr.nullTerm, code, newExpr.resultType)
            }

          case _ => newExpr
        }

        ctx.addReusableInputUnboxingExprs(inputTerm, index, exprWithDataView)
        exprWithDataView
    }
    // hide the generated code as it will be executed only once
    GeneratedExpression(inputExpr.resultTerm, inputExpr.nullTerm, "", inputExpr.resultType)
  }

  /**
    * Generate statements to set data view field when use state backend.
    *
    * @param accTerm aggregation term
    * @param viewSpecs the data view specs whose fields are set into the accumulator
    * @param useBackupDataView use the backup data view terms (only when hasNamespace is true)
    * @return data view field set statements
    */
  private def generateDataViewFieldSetter(
      accTerm: String,
      viewSpecs: Array[DataViewSpec],
      useBackupDataView: Boolean): String = {
    ctx.addAllReusableFields(Set(s"$BASE_ROW $CURRENT_KEY = ctx.currentKey();"))
    val setters = for (spec <- viewSpecs) yield {
      if (hasNamespace) {
        val dataViewTerm = if (useBackupDataView) {
          createDataViewBackupTerm(spec)
        } else {
          createDataViewTerm(spec)
        }

        s"""
           |// when namespace is null, the dataview is used in heap, no key and namespace set
           |if ($NAMESPACE_TERM != null) {
           |  $dataViewTerm.setCurrentKey($CURRENT_KEY);
           |  $dataViewTerm.setCurrentNamespace($NAMESPACE_TERM);
           |  $accTerm.update(${spec.fieldIndex}, $dataViewTerm);
           |}
         """.stripMargin
      } else {
        val dataViewTerm = createDataViewTerm(spec)

        s"""
           |$dataViewTerm.setCurrentKey($CURRENT_KEY);
           |$accTerm.update(${spec.fieldIndex}, $dataViewTerm);
         """.stripMargin
      }
    }
    setters.mkString("\n")
  }

  /**
    * Validates that the aggregate function declares the methods required by the caller.
    *
    * @param needAccumulate check for an accumulate(acc, args...) method
    * @param needRetract check for a retract(acc, args...) method
    * @param needMerge check for a merge(acc, Iterable) method whose Iterable element type
    *                  matches the accumulator class
    * @param needReset check for a resetAccumulator(acc) method
    * @throws CodeGenException if a required method is missing or has an unexpected signature
    */
  def checkNeededMethods(
      needAccumulate: Boolean = false,
      needRetract: Boolean = false,
      needMerge: Boolean = false,
      needReset: Boolean = false): Unit = {

    val methodSignatures = typesToClasses(argTypes)

    if (needAccumulate) {
      getAggFunctionUDIMethod(function, "accumulate", externalAccType, argTypes)
        .getOrElse(
          throw new CodeGenException(
            s"No matching accumulate method found for AggregateFunction " +
              s"'${function.getClass.getCanonicalName}' " +
              s"with parameters '${signatureToString(methodSignatures)}'.")
        )
    }

    if (needRetract) {
      getAggFunctionUDIMethod(function, "retract", externalAccType, argTypes)
        .getOrElse(
          throw new CodeGenException(
            s"No matching retract method found for AggregateFunction " +
              s"'${function.getClass.getCanonicalName}' " +
              s"with parameters '${signatureToString(methodSignatures)}'.")
        )
    }

    if (needMerge) {
      val iterType = DataTypes.extractDataType(classOf[JIterable[Any]])
      val mergeMethod =
        getUserDefinedMethod(function, "merge", Array(externalAccType, iterType))
          .getOrElse(
            throw new CodeGenException(
              s"No matching merge method found for AggregateFunction " +
                s"'${function.getClass.getCanonicalName}'.")
          )

      val expectedAccClass = TypeUtils.getExternalClassForType(externalAccType)
      def invalidIterableType(actual: AnyRef): CodeGenException = new CodeGenException(
        s"merge method in AggregateFunction ${function.getClass.getCanonicalName} does not " +
          s"have the correct Iterable type. Actually: $actual. " +
          s"Expected: $expectedAccClass")

      // element type T of the second parameter, merge(acc, Iterable<T>)
      val iterableElementType = mergeMethod.getGenericParameterTypes.apply(1) match {
        case iterable: ParameterizedType =>
          iterable.getActualTypeArguments.apply(0) match {
            // further extract the raw class if the accumulator has generic type
            case generic: ParameterizedType => generic.getRawType
            case other => other
          }
        // a raw Iterable carries no element type that could be validated
        case other => throw invalidIterableType(other)
      }

      val isValidIterableType = iterableElementType match {
        case clazz: Class[_] =>
          clazz == expectedAccClass ||
            // iterableTypeClass can be GenericRow, so classOf[BaseRow] is assignable from it.
            TypeUtils.getInternalClassForType(externalAccType).isAssignableFrom(clazz)
        // e.g. a type variable or wildcard, which cannot match the accumulator class
        case _ => false
      }
      if (!isValidIterableType) {
        throw invalidIterableType(iterableElementType)
      }
    }

    if (needReset) {
      getUserDefinedMethod(function, "resetAccumulator", Array(externalAccType))
        .getOrElse(
          throw new CodeGenException(
            s"No matching resetAccumulator method found for " +
              s"aggregate '${function.getClass.getCanonicalName}'.")
        )
    }
  }
}