| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.flink.table.codegen |
| |
| import java.lang.reflect.Modifier |
| import java.lang.{Iterable => JIterable} |
| import java.util.{List => JList} |
| |
| import org.apache.calcite.rex.RexLiteral |
| import org.apache.flink.api.common.state.{ListStateDescriptor, MapStateDescriptor, State, StateDescriptor} |
| import org.apache.flink.api.common.typeinfo.TypeInformation |
| import org.apache.flink.api.java.typeutils.TypeExtractionUtils.{extractTypeArgument, getRawClass} |
| import org.apache.flink.table.api.TableConfig |
| import org.apache.flink.table.api.dataview._ |
| import org.apache.flink.table.codegen.CodeGenUtils.{newName, reflectiveFieldWriteAccess} |
| import org.apache.flink.table.codegen.Indenter.toISC |
| import org.apache.flink.table.dataview.{StateListView, StateMapView} |
| import org.apache.flink.table.functions.AggregateFunction |
| import org.apache.flink.table.functions.aggfunctions.DistinctAccumulator |
| import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils |
| import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getUserDefinedMethod, signatureToString} |
| import org.apache.flink.table.runtime.aggregate.{GeneratedAggregations, SingleElementIterable} |
| import org.apache.flink.table.utils.EncodingUtils |
| import org.apache.flink.types.Row |
| |
| import scala.collection.JavaConversions._ |
| import scala.collection.mutable |
| |
| /** |
| * A code generator for generating [[GeneratedAggregations]]. |
| * |
| * @param config configuration that determines runtime behavior |
| * @param nullableInput input(s) can be null. |
| * @param input type information about the input of the Function |
| * @param constants constant expressions that act like a second input in the parameter indices. |
| */ |
class AggregationCodeGenerator(
    config: TableConfig,
    nullableInput: Boolean,
    input: TypeInformation[_ <: Any],
    constants: Option[Seq[RexLiteral]])
  extends CodeGenerator(config, nullableInput, input) {

  // set of statements for cleanup dataview that will be added only once
  // we use a LinkedHashSet to keep the insertion order
  private val reusableCleanupStatements = mutable.LinkedHashSet[String]()

  /**
    * @return code block of statements that need to be placed in the cleanup() method of
    *         [[GeneratedAggregations]]
    */
  def reuseCleanupCode(): String = {
    reusableCleanupStatements.mkString("", "\n", "\n")
  }

  /**
    * Generates a [[org.apache.flink.table.runtime.aggregate.GeneratedAggregations]] that can be
    * passed to a Java compiler.
    *
    * @param name        Class name of the function.
    *                    Does not need to be unique but has to be a valid Java class identifier.
    * @param physicalInputTypes Physical input row types
    * @param aggregates  All aggregate functions
    * @param aggFields   Indexes of the input fields for all aggregate functions. Indexes equal
    *                    to or larger than the number of physical input fields refer to
    *                    constants; negative indexes are ignored.
    * @param aggMapping  The mapping of aggregates to output fields
    * @param distinctAccMapping The mapping of the distinct accumulator index to the
    *                    corresponding aggregates. A negative distinct accumulator index marks
    *                    a group of aggregates without a distinct filter.
    * @param isStateBackedDataViews a flag to indicate if distinct filter uses state backend.
    * @param partialResults A flag defining whether final or partial results (accumulators) are set
    *                    to the output row.
    * @param fwdMapping  The mapping of input fields to output fields
    * @param mergeMapping An optional mapping to specify the accumulators to merge. If not set, we
    *                     assume that both rows have the accumulators at the same position.
    * @param outputArity The number of fields in the output row.
    * @param needRetract a flag to indicate if the aggregate needs the retract method
    * @param needMerge   a flag to indicate if the aggregate needs the merge method
    * @param needReset   a flag to indicate if the aggregate needs the resetAccumulator method
    * @param accConfig   Data view specification for accumulators
    *
    * @return A GeneratedAggregationsFunction
    * @throws CodeGenException if a required accumulate/retract/merge/resetAccumulator method
    *                          is missing, if partial results are requested together with
    *                          state-backed DISTINCT tracking, or if merge is requested for
    *                          state-backed data views.
    */
  def generateAggregations(
      name: String,
      physicalInputTypes: Seq[TypeInformation[_]],
      aggregates: Array[AggregateFunction[_ <: Any, _ <: Any]],
      aggFields: Array[Array[Int]],
      aggMapping: Array[Int],
      distinctAccMapping: Array[(Integer, JList[Integer])],
      isStateBackedDataViews: Boolean,
      partialResults: Boolean,
      fwdMapping: Array[Int],
      mergeMapping: Option[Array[Int]],
      outputArity: Int,
      needRetract: Boolean,
      needMerge: Boolean,
      needReset: Boolean,
      accConfig: Option[Array[Seq[DataViewSpec[_]]]])
    : GeneratedAggregationsFunction = {

    // get unique function name
    val funcName = newName(name)
    // register UDAGGs
    val aggs = aggregates.map(a => addReusableFunction(a, contextTerm))

    // get java types of accumulators
    val accTypeClasses = aggregates.map { a =>
      a.getClass.getMethod("createAccumulator").getReturnType
    }
    val accTypes = accTypeClasses.map(_.getCanonicalName)

    // create constants
    val constantExprs = constants.map(_.map(generateExpression)).getOrElse(Seq())
    val constantTypes = constantExprs.map(_.resultType)
    val constantFields = constantExprs.map(addReusableBoxedConstant)

    // get parameter lists for aggregation functions
    val parametersCode = aggFields.map { inFields =>
      val fields = inFields.filter(_ > -1).map { f =>
        // index to constant
        if (f >= physicalInputTypes.length) {
          constantFields(f - physicalInputTypes.length)
        }
        // index to input field
        else {
          s"(${CodeGenUtils.boxedTypeTermForTypeInfo(physicalInputTypes(f))}) input.getField($f)"
        }
      }

      fields.mkString(", ")
    }

    // get parameter lists for distinct acc, constant fields are not necessary.
    // The resulting Row (the distinct key) holds the physical input fields in the order in
    // which they appear in inFields, duplicates included.
    val parametersCodeForDistinctAcc = aggFields.map { inFields =>
      val fields = inFields.filter(i => i > -1 && i < physicalInputTypes.length).map { f =>
        // index to input field
        s"(${CodeGenUtils.boxedTypeTermForTypeInfo(physicalInputTypes(f))}) input.getField($f)"
      }

      fields.mkString(", ")
    }

    // get parameter lists to re-accumulate a distinct key `k` during merge.
    // `k` is laid out exactly like parametersCodeForDistinctAcc: only physical fields, in
    // their original order, duplicates included. Hence the j-th physical field of inFields is
    // read from k.getField(j) and constants are passed directly. The original argument order
    // is preserved, e.g. (4, C, 2, 2) with constant C is mapped to
    // (k.getField(0), C, k.getField(1), k.getField(2)).
    // NOTE: the distinct key is built from the first aggregate of each distinct group
    // (aggIndexes.get(0)), so all aggregates of a group are assumed to share its argument list.
    val parametersCodeForDistinctMerge = aggFields.map { inFields =>
      val validFields = inFields.filter(_ > -1)
      // keyPositions(j) = number of physical fields before position j in validFields,
      // i.e., the position of validFields(j) in the distinct key row (if it is physical)
      val keyPositions = validFields.scanLeft(0) { (pos, f) =>
        if (f < physicalInputTypes.length) pos + 1 else pos
      }
      val fields = validFields.zip(keyPositions).map { case (f, keyPos) =>
        // index to constant
        if (f >= physicalInputTypes.length) {
          constantFields(f - physicalInputTypes.length)
        }
        // index to input field
        else {
          s"(${CodeGenUtils.boxedTypeTermForTypeInfo(physicalInputTypes(f))}) k.getField($keyPos)"
        }
      }

      fields.mkString(", ")
    }

    // get method signatures
    val classes = UserDefinedFunctionUtils.typeInfoToClass(physicalInputTypes)
    val constantClasses = UserDefinedFunctionUtils.typeInfoToClass(constantTypes)
    val methodSignaturesList = aggFields.map { inFields =>
      inFields.filter(_ > -1).map { f =>
        // index to constant
        if (f >= physicalInputTypes.length) {
          constantClasses(f - physicalInputTypes.length)
        }
        // index to input field
        else {
          classes(f)
        }
      }
    }

    // get distinct filter of acc fields for each aggregate functions
    val distinctAccType = s"${classOf[DistinctAccumulator].getName}"

    val distinctAccCount = distinctAccMapping.count(_._1 >= 0)

    if (distinctAccCount > 0 && partialResults && isStateBackedDataViews) {
      // should not happen, but add an error message just in case.
      throw new CodeGenException(
        s"Cannot emit partial results if DISTINCT values are tracked in state-backed maps. " +
          s"Please report this bug."
      )
    }

    // initialize and create data views for accumulators & distinct filters
    addAccumulatorDataViews()

    // check and validate the needed methods (only performed for the exceptions it may throw)
    aggregates.zipWithIndex.foreach {
      case (a, i) =>
        getUserDefinedMethod(a, "accumulate", Array(accTypeClasses(i)) ++ methodSignaturesList(i))
          .getOrElse(
            throw new CodeGenException(
              s"No matching accumulate method found for AggregateFunction " +
                s"'${a.getClass.getCanonicalName}' " +
                s"with parameters '${signatureToString(methodSignaturesList(i))}'.")
          )

        if (needRetract) {
          getUserDefinedMethod(a, "retract", Array(accTypeClasses(i)) ++ methodSignaturesList(i))
            .getOrElse(
              throw new CodeGenException(
                s"No matching retract method found for AggregateFunction " +
                  s"'${a.getClass.getCanonicalName}' " +
                  s"with parameters '${signatureToString(methodSignaturesList(i))}'.")
            )
        }

        if (needMerge) {
          val method =
            getUserDefinedMethod(a, "merge", Array(accTypeClasses(i), classOf[JIterable[Any]]))
              .getOrElse(
                throw new CodeGenException(
                  s"No matching merge method found for AggregateFunction " +
                    s"'${a.getClass.getCanonicalName}'.")
              )

          // use the TypeExtractionUtils here to support nested GenericArrayTypes and
          // other complex types
          val iterableGenericType = extractTypeArgument(method.getGenericParameterTypes()(1), 0)
          val iterableTypeClass = getRawClass(iterableGenericType)

          if (iterableTypeClass != accTypeClasses(i)) {
            throw new CodeGenException(
              s"Merge method in AggregateFunction ${a.getClass.getCanonicalName} does not have " +
                s"the correct Iterable type. Actually: ${iterableTypeClass.toString}. " +
                s"Expected: ${accTypeClasses(i).toString}")
          }
        }

        if (needReset) {
          getUserDefinedMethod(a, "resetAccumulator", Array(accTypeClasses(i)))
            .getOrElse(
              throw new CodeGenException(
                s"No matching resetAccumulator method found for " +
                  s"aggregate '${a.getClass.getCanonicalName}'.")
            )
        }
    }

    /**
      * Add all data views for all field accumulators and distinct filters defined by
      * aggregation functions. Does nothing if no accConfig is given.
      *
      * @throws CodeGenException if a [[DataViewSpec]] refers to an unknown state id.
      */
    def addAccumulatorDataViews(): Unit = {
      accConfig.foreach { accSpecs =>
        // create state handles for DataView backed accumulator fields.
        val descMapping: Map[String, StateDescriptor[_, _]] = accSpecs
          .flatMap(specs => specs.map(s => (s.stateId, s.toStateDescriptor)))
          .toMap[String, StateDescriptor[_ <: State, _]]

        for (i <- 0 until aggs.length + distinctAccCount; spec <- accSpecs(i)) {
          // Check if state descriptor exists.
          val desc: StateDescriptor[_, _] = descMapping.getOrElse(spec.stateId,
            throw new CodeGenException(
              s"Can not find DataView in accumulator by id: ${spec.stateId}"))

          addReusableDataView(spec, desc, i)
        }
      }
    }

    /**
      * Create DataView Term, for example, acc1_map_dataview.
      *
      * @param aggIndex index of aggregate function
      * @param fieldName field name of DataView
      * @return term to access [[MapView]] or [[ListView]]
      */
    def createDataViewTerm(aggIndex: Int, fieldName: String): String = {
      s"acc${aggIndex}_${fieldName}_dataview"
    }

    /**
      * Adds a reusable [[org.apache.flink.table.api.dataview.DataView]] to the open, cleanup,
      * close and member area of the generated function.
      *
      * @param spec the [[DataViewSpec]] of the desired data view term.
      * @param desc the [[StateDescriptor]] of the desired data view term.
      * @param aggIndex the aggregation function index associate with the data view.
      * @throws CodeGenException if the data view field is neither a [[MapView]] nor a
      *                          [[ListView]].
      */
    def addReusableDataView(
        spec: DataViewSpec[_],
        desc: StateDescriptor[_, _],
        aggIndex: Int): Unit = {
      val dataViewField = spec.field
      val dataViewTypeTerm = dataViewField.getType.getCanonicalName

      // define the DataView variables
      val serializedData = EncodingUtils.encodeObjectToString(desc)
      val dataViewFieldTerm = createDataViewTerm(aggIndex, dataViewField.getName)
      val field =
        s"""
           |    final $dataViewTypeTerm $dataViewFieldTerm;
           |""".stripMargin
      reusableMemberStatements.add(field)

      // create DataViews: the descriptor is serialized at code generation time and
      // deserialized with the user code class loader in open()
      val descFieldTerm = s"${dataViewFieldTerm}_desc"
      val descClassQualifier = classOf[StateDescriptor[_, _]].getCanonicalName
      val descDeserializeCode =
        s"""
           |    $descClassQualifier $descFieldTerm = ($descClassQualifier)
           |      ${classOf[EncodingUtils].getCanonicalName}.decodeStringToObject(
           |        "$serializedData",
           |        $descClassQualifier.class,
           |        $contextTerm.getUserCodeClassLoader());
           |""".stripMargin
      val createDataView = if (dataViewField.getType == classOf[MapView[_, _]]) {
        s"""
           |    $descDeserializeCode
           |    $dataViewFieldTerm = new ${classOf[StateMapView[_, _]].getCanonicalName}(
           |      $contextTerm.getMapState(
           |        (${classOf[MapStateDescriptor[_, _]].getCanonicalName}) $descFieldTerm));
           |""".stripMargin
      } else if (dataViewField.getType == classOf[ListView[_]]) {
        s"""
           |    $descDeserializeCode
           |    $dataViewFieldTerm = new ${classOf[StateListView[_]].getCanonicalName}(
           |      $contextTerm.getListState(
           |        (${classOf[ListStateDescriptor[_]].getCanonicalName}) $descFieldTerm));
           |""".stripMargin
      } else {
        throw new CodeGenException(s"Unsupported dataview type: $dataViewTypeTerm")
      }
      reusableOpenStatements.add(createDataView)

      // cleanup DataViews
      val cleanup =
        s"""
           |    $dataViewFieldTerm.clear();
           |""".stripMargin
      reusableCleanupStatements.add(cleanup)
    }

    /**
      * Generates the statements that set the data view fields of accumulator term `str`
      * (the accumulator with index `i`). Returns an empty string if no accConfig is given.
      */
    def genAccDataViewFieldSetter(str: String, i: Int): String = {
      if (accConfig.isDefined) {
        genDataViewFieldSetter(accConfig.get(i), str, i)
      } else {
        ""
      }
    }

    /**
      * Generate statements to set data view field when use state backend.
      *
      * @param specs aggregation [[DataViewSpec]]s for this aggregation term.
      * @param accTerm aggregation term
      * @param aggIndex index of aggregation
      * @return data view field set statements
      */
    def genDataViewFieldSetter(
        specs: Seq[DataViewSpec[_]],
        accTerm: String,
        aggIndex: Int): String = {
      val setters = for (spec <- specs) yield {
        val field = spec.field
        val dataViewTerm = createDataViewTerm(aggIndex, field.getName)
        // public fields are assigned directly, non-public ones through reflection
        val fieldSetter = if (Modifier.isPublic(field.getModifiers)) {
          s"$accTerm.${field.getName} = $dataViewTerm;"
        } else {
          val fieldTerm = addReusablePrivateFieldAccess(field.getDeclaringClass, field.getName)
          s"${reflectiveFieldWriteAccess(fieldTerm, field, accTerm, dataViewTerm)};"
        }

        s"""
           |    $fieldSetter
          """.stripMargin
      }
      setters.mkString("\n")
    }

    /**
      * Generates `setAggregationResults(accs, output)`.
      * With partialResults, the (distinct) accumulators are copied to the output fields given
      * by aggMapping; otherwise getValue() of each aggregate function is written there.
      */
    def genSetAggregationResults: String = {

      val sig: String =
        j"""
           |  public final void setAggregationResults(
           |    org.apache.flink.types.Row accs,
           |    org.apache.flink.types.Row output) throws Exception """.stripMargin

      val setAggs: String = {
        for ((i, aggIndexes) <- distinctAccMapping) yield {
          if (partialResults) {
            def setAggs(aggIndexes: JList[Integer]) = {
              for (i <- aggIndexes) yield {
                j"""
                   |output.setField(
                   |  ${aggMapping(i)},
                   |  (${accTypes(i)}) accs.getField($i));
                 """.stripMargin
              }
            }.mkString("\n")

            if (i >= 0) {
              j"""
                 |    output.setField(
                 |      ${aggMapping(i)},
                 |      ($distinctAccType) accs.getField($i));
                 |    ${setAggs(aggIndexes)}
               """.stripMargin
            } else {
              j"""
                 |    ${setAggs(aggIndexes)}
               """.stripMargin
            }
          } else {
            def setAggs(aggIndexes: JList[Integer]) = {
              for (i <- aggIndexes) yield {
                val setAccOutput =
                  j"""
                     |${genAccDataViewFieldSetter(s"acc$i", i)}
                     |output.setField(
                     |  ${aggMapping(i)},
                     |  baseClass$i.getValue(acc$i));
                   """.stripMargin

                j"""
                   |org.apache.flink.table.functions.AggregateFunction baseClass$i =
                   |  (org.apache.flink.table.functions.AggregateFunction) ${aggs(i)};
                   |${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
                   |$setAccOutput
                 """.stripMargin
              }
            }.mkString("\n")

            j"""
               |    ${setAggs(aggIndexes)}
             """.stripMargin
          }
        }
      }.mkString("\n")

      j"""
         |$sig {
         |$setAggs
         |  }""".stripMargin
    }

    /**
      * Generates `accumulate(accs, input)`. Aggregates that share a distinct accumulator are
      * only accumulated when `add` on that distinct accumulator returns true.
      */
    def genAccumulate: String = {

      val sig: String =
        j"""
           |  public final void accumulate(
           |    org.apache.flink.types.Row accs,
           |    org.apache.flink.types.Row input) throws Exception """.stripMargin

      val accumulate: String = {
        def accumulateAcc(aggIndexes: JList[Integer]) = {
          for (i <- aggIndexes) yield {
            j"""
               |${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
               |${genAccDataViewFieldSetter(s"acc$i", i)}
               |${aggs(i)}.accumulate(acc$i
               |  ${if (!parametersCode(i).isEmpty) "," else ""} ${parametersCode(i)});
             """.stripMargin
          }
        }.mkString("\n")

        for ((i, aggIndexes) <- distinctAccMapping) yield {
          if (i >= 0) {
            j"""
               |    $distinctAccType distinctAcc$i = ($distinctAccType) accs.getField($i);
               |    ${genAccDataViewFieldSetter(s"distinctAcc$i", i)}
               |    if (distinctAcc$i.add(${classOf[Row].getCanonicalName}.of(
               |        ${parametersCodeForDistinctAcc(aggIndexes.get(0))}))) {
               |      ${accumulateAcc(aggIndexes)}
               |    }
             """.stripMargin
          } else {
            j"""
               |    ${accumulateAcc(aggIndexes)}
             """.stripMargin
          }
        }
      }.mkString("\n")

      j"""$sig {
         |$accumulate
         |  }""".stripMargin
    }

    /**
      * Generates `retract(accs, input)`. Aggregates that share a distinct accumulator are only
      * retracted when `remove` on that distinct accumulator returns true. The method body is
      * empty if needRetract is false.
      */
    def genRetract: String = {

      val sig: String =
        j"""
           |  public final void retract(
           |    org.apache.flink.types.Row accs,
           |    org.apache.flink.types.Row input) throws Exception """.stripMargin

      val retract: String = {
        def retractAcc(aggIndexes: JList[Integer]) = {
          for (i <- aggIndexes) yield {
            j"""
               |${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
               |${genAccDataViewFieldSetter(s"acc$i", i)}
               |${aggs(i)}.retract(acc$i
               |  ${if (!parametersCode(i).isEmpty) "," else ""} ${parametersCode(i)});
             """.stripMargin
          }
        }.mkString("\n")

        for ((i, aggIndexes) <- distinctAccMapping) yield {
          if (i >= 0) {
            j"""
               |    $distinctAccType distinctAcc$i = ($distinctAccType) accs.getField($i);
               |    ${genAccDataViewFieldSetter(s"distinctAcc$i", i)}
               |    if (distinctAcc$i.remove(${classOf[Row].getCanonicalName}.of(
               |        ${parametersCodeForDistinctAcc(aggIndexes.get(0))}))) {
               |      ${retractAcc(aggIndexes)}
               |    }
             """.stripMargin
          } else {
            j"""
               |    ${retractAcc(aggIndexes)}
             """.stripMargin
          }
        }
      }.mkString("\n")

      if (needRetract) {
        j"""
           |$sig {
           |$retract
           |  }""".stripMargin
      } else {
        j"""
           |$sig {
           |  }""".stripMargin
      }
    }

    /**
      * Generates `createAccumulators()`, which returns a Row with one field per aggregate
      * accumulator plus one field per distinct accumulator.
      */
    def genCreateAccumulators: String = {

      val sig: String =
        j"""
           |  public final org.apache.flink.types.Row createAccumulators() throws Exception
           |  """.stripMargin
      val init: String =
        j"""
           |      org.apache.flink.types.Row accs =
           |          new org.apache.flink.types.Row(${aggs.length + distinctAccCount});"""
          .stripMargin
      val create: String = {
        def createAcc(aggIndexes: JList[Integer]) = {
          for (i <- aggIndexes) yield {
            j"""
               |${accTypes(i)} acc$i = (${accTypes(i)}) ${aggs(i)}.createAccumulator();
               |accs.setField(
               |  $i,
               |  acc$i);
             """.stripMargin
          }
        }.mkString("\n")

        for ((i, aggIndexes) <- distinctAccMapping) yield {
          if (i >= 0) {
            j"""
               |    $distinctAccType distinctAcc$i = ($distinctAccType)
               |      new ${classOf[DistinctAccumulator].getCanonicalName}();
               |    accs.setField(
               |      $i,
               |      distinctAcc$i);
               |    ${createAcc(aggIndexes)}
             """.stripMargin
          } else {
            j"""
               |    ${createAcc(aggIndexes)}
             """.stripMargin
          }
        }
      }.mkString("\n")
      val ret: String =
        j"""
           |      return accs;"""
          .stripMargin

      j"""$sig {
         |$init
         |$create
         |$ret
         |  }""".stripMargin
    }

    /**
      * Generates `setForwardedFields(input, output)`, which copies input field fwdMapping(i)
      * to output field i for every i with fwdMapping(i) >= 0.
      */
    def genSetForwardedFields: String = {

      val sig: String =
        j"""
           |  public final void setForwardedFields(
           |    org.apache.flink.types.Row input,
           |    org.apache.flink.types.Row output)
           |  """.stripMargin

      val forward: String = {
        for (i <- fwdMapping.indices if fwdMapping(i) >= 0) yield
          {
            j"""
               |    output.setField(
               |      $i,
               |      input.getField(${fwdMapping(i)}));"""
              .stripMargin
          }
      }.mkString("\n")

      j"""$sig {
         |$forward
         |  }""".stripMargin
    }

    /** Generates `createOutputRow()`, which returns a new Row of arity outputArity. */
    def genCreateOutputRow: String = {
      j"""
         |  public final org.apache.flink.types.Row createOutputRow() {
         |    return new org.apache.flink.types.Row($outputArity);
         |  }""".stripMargin
    }

    /**
      * Generates `mergeAccumulatorsPair(a, b)`, which merges the accumulators of `b` into `a`
      * and returns `a`. Accumulators of `b` are looked up via mergeMapping (identity if not
      * set). For distinct groups, every key of `b`'s distinct accumulator is added to `a`'s one
      * and the group's aggregates are accumulated with that key when `add` returns true. Other
      * aggregates are merged via their merge() method. If needMerge is false, `a` is returned
      * unchanged.
      *
      * @throws CodeGenException if needMerge is set and state-backed data views are used.
      */
    def genMergeAccumulatorsPair: String = {
      val mapping = mergeMapping.getOrElse((0 until aggs.length + distinctAccCount).toArray)

      val sig: String =
        j"""
           |  public final org.apache.flink.types.Row mergeAccumulatorsPair(
           |    org.apache.flink.types.Row a,
           |    org.apache.flink.types.Row b)
           """.stripMargin
      val merge: String = {
        def accumulateAcc(aggIndexes: JList[Integer]) = {
          for (i <- aggIndexes) yield {
            // avoid generating "accumulate(aAcc, )" for an empty parameter list
            val params = parametersCodeForDistinctMerge(i)
            val paramsCode = if (params.isEmpty) "" else s", $params"
            j"""
               |${accTypes(i)} aAcc$i = (${accTypes(i)}) a.getField($i);
               |${aggs(i)}.accumulate(aAcc$i${paramsCode});
               |a.setField($i, aAcc$i);
             """.stripMargin
          }
        }.mkString("\n")

        def mergeAcc(aggIndexes: JList[Integer]) = {
          for (i <- aggIndexes) yield {
            j"""
               |${accTypes(i)} aAcc$i = (${accTypes(i)}) a.getField($i);
               |${accTypes(i)} bAcc$i = (${accTypes(i)}) b.getField(${mapping(i)});
               |accIt$i.setElement(bAcc$i);
               |${aggs(i)}.merge(aAcc$i, accIt$i);
               |a.setField($i, aAcc$i);
             """.stripMargin
          }
        }.mkString("\n")

        for ((i, aggIndexes) <- distinctAccMapping) yield {
          if (i >= 0) {
            j"""
               |    $distinctAccType aDistinctAcc$i = ($distinctAccType) a.getField($i);
               |    $distinctAccType bDistinctAcc$i = ($distinctAccType) b.getField(${mapping(i)});
               |    java.util.Iterator<java.util.Map.Entry> mergeIt$i =
               |        bDistinctAcc$i.elements().iterator();
               |
               |    while (mergeIt$i.hasNext()) {
               |      java.util.Map.Entry entry = (java.util.Map.Entry) mergeIt$i.next();
               |      ${classOf[Row].getCanonicalName} k =
               |          (${classOf[Row].getCanonicalName}) entry.getKey();
               |      Long v = (Long) entry.getValue();
               |      if (aDistinctAcc$i.add(k, v)) {
               |        ${accumulateAcc(aggIndexes)}
               |      }
               |    }
               |    a.setField($i, aDistinctAcc$i);
             """.stripMargin
          } else {
            j"""
               |    ${mergeAcc(aggIndexes)}
             """.stripMargin
          }
        }
      }.mkString("\n")
      val ret: String =
        j"""
           |      return a;
           """.stripMargin

      if (needMerge) {
        if (accConfig.isDefined) {
          throw new CodeGenException("DataView doesn't support merge when the backend uses " +
            s"state when generate aggregation for $funcName.")
        }
        j"""
           |$sig {
           |$merge
           |$ret
           |  }""".stripMargin
      } else {
        j"""
           |$sig {
           |$ret
           |  }""".stripMargin
      }
    }

    /**
      * Generates one reusable SingleElementIterable member (accIt<i>) per aggregate, used to
      * pass a single accumulator to the aggregate's merge() method.
      */
    def genMergeList: String = {
      {
        val singleIterableClass = classOf[SingleElementIterable[_]].getCanonicalName
        for (i <- accTypes.indices) yield
          j"""
             |    private final $singleIterableClass<${accTypes(i)}> accIt$i =
             |      new $singleIterableClass<${accTypes(i)}>();
             """.stripMargin
      }.mkString("\n")
    }

    /**
      * Generates `resetAccumulator(accs)`, which resets every distinct accumulator and calls
      * resetAccumulator() on every aggregate. The method body is empty if needReset is false.
      */
    def genResetAccumulator: String = {

      val sig: String =
        j"""
           |  public final void resetAccumulator(
           |    org.apache.flink.types.Row accs) throws Exception """.stripMargin

      val reset: String = {
        def resetAcc(aggIndexes: JList[Integer]) = {
          for (i <- aggIndexes) yield {
            j"""
               |${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
               |${genAccDataViewFieldSetter(s"acc$i", i)}
               |${aggs(i)}.resetAccumulator(acc$i);
             """.stripMargin
          }
        }.mkString("\n")

        for ((i, aggIndexes) <- distinctAccMapping) yield {
          if (i >= 0) {
            j"""
               |    $distinctAccType distinctAcc$i = ($distinctAccType) accs.getField($i);
               |    ${genAccDataViewFieldSetter(s"distinctAcc$i", i)}
               |    distinctAcc$i.reset();
               |    ${resetAcc(aggIndexes)}
             """.stripMargin
          } else {
            j"""
               |    ${resetAcc(aggIndexes)}
             """.stripMargin
          }
        }
      }.mkString("\n")

      if (needReset) {
        j"""$sig {
           |$reset
           |  }""".stripMargin
      } else {
        j"""$sig {
           |  }""".stripMargin
      }
    }

    // the method code must be generated before the member/init code below is assembled,
    // because generating it may register reusable members (e.g. private field accessors)
    val aggFuncCode = Seq(
      genSetAggregationResults,
      genAccumulate,
      genRetract,
      genCreateAccumulators,
      genSetForwardedFields,
      genCreateOutputRow,
      genMergeAccumulatorsPair,
      genResetAccumulator).mkString("\n")

    val generatedAggregationsClass = classOf[GeneratedAggregations].getCanonicalName
    val funcCode =
      j"""
         |public final class $funcName extends $generatedAggregationsClass {
         |
         |  ${reuseMemberCode()}
         |  $genMergeList
         |  public $funcName() throws Exception {
         |    ${reuseInitCode()}
         |  }
         |  ${reuseConstructorCode(funcName)}
         |
         |  public final void open(
         |    org.apache.flink.api.common.functions.RuntimeContext $contextTerm) throws Exception {
         |    ${reuseOpenCode()}
         |  }
         |
         |  $aggFuncCode
         |
         |  public final void cleanup() throws Exception {
         |    ${reuseCleanupCode()}
         |  }
         |
         |  public final void close() throws Exception {
         |    ${reuseCloseCode()}
         |  }
         |}
       """.stripMargin

    GeneratedAggregationsFunction(funcName, funcCode)
  }
}