blob: 53748f5c71cd0eb9569329d213ffa1577b8a0e8c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan.nodes.dataset
import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
import org.apache.flink.api.common.operators.Order
import org.apache.flink.api.java.DataSet
import org.apache.flink.api.java.typeutils.{ResultTypeQueryable, RowTypeInfo}
import org.apache.flink.table.api.{BatchQueryConfig, BatchTableEnvironment, TableConfig}
import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.AggregationCodeGenerator
import org.apache.flink.table.expressions.ExpressionUtils._
import org.apache.flink.table.plan.logical._
import org.apache.flink.table.plan.nodes.CommonAggregate
import org.apache.flink.table.runtime.aggregate.AggregateUtil.{CalcitePair, _}
import org.apache.flink.table.typeutils.TypeCheckUtils.{isLong, isTimePoint}
import org.apache.flink.types.Row
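/**
* Flink RelNode for a group-window aggregation on a DataSet (batch); it is the DataSet
* counterpart of a LogicalWindowAggregate and translates event-time tumbling, session
* and sliding group windows into a DataSet program.
*/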
class DataSetWindowAggregate(
window: LogicalWindow,
namedProperties: Seq[NamedWindowProperty],
cluster: RelOptCluster,
traitSet: RelTraitSet,
inputNode: RelNode,
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
rowRelDataType: RelDataType,
inputType: RelDataType,
grouping: Array[Int])
extends SingleRel(cluster, traitSet, inputNode) with CommonAggregate with DataSetRel {
/** Returns the row type this node was constructed with. */
override def deriveRowType(): RelDataType = {
  rowRelDataType
}
/**
 * Creates a new DataSetWindowAggregate with the given trait set and with the first
 * element of `inputs` as its input. All other constructor arguments are carried over
 * from this node.
 */
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
  // only the first entry of the list is used as the new input
  val newInput = inputs.get(0)
  val currentRowType = getRowType

  new DataSetWindowAggregate(
    window, namedProperties, cluster, traitSet, newInput,
    namedAggregates, currentRowType, inputType, grouping)
}
/**
 * Renders this node as `Aggregate(...)`: the grouping fields (only when there are any),
 * followed by the window and the selected aggregates / window properties.
 */
override def toString: String = {
  // the groupBy part is left out entirely for non-grouped windows
  val groupByPart =
    if (grouping.nonEmpty) s"groupBy: (${groupingToString(inputType, grouping)}), " else ""
  val head = s"Aggregate(${groupByPart}window: ($window), "

  val selectPart =
    aggregationToString(inputType, grouping, getRowType, namedAggregates, namedProperties)

  head + s"select: ($selectPart))"
}
/**
 * Adds the terms of this node to the plan explanation: `groupBy` (only if the grouping
 * is non-empty), `window` and `select`, on top of the terms written by the super class.
 */
override def explainTerms(pw: RelWriter): RelWriter = {
  val withGroupBy = super.explainTerms(pw)
    .itemIf("groupBy", groupingToString(inputType, grouping), grouping.nonEmpty)
  val withWindow = withGroupBy.item("window", window)

  val selectText =
    aggregationToString(inputType, grouping, getRowType, namedAggregates, namedProperties)

  withWindow.item("select", selectText)
}
/**
 * Estimates the cost of this node from the input's row count, the estimated size of an
 * input row and the number of aggregates.
 */
override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
  val input = this.getInput
  val inputRowCount = metadata.getRowCount(input)
  val inputRowSize = this.estimateRowSize(input.getRowType)
  val aggregateCount = this.namedAggregates.size

  // NOTE(review): makeCost presumably takes (row count, cpu, io) - confirm against Calcite
  val rowsTimesAggregates = inputRowCount * aggregateCount
  val rowsTimesRowSize = inputRowCount * inputRowSize
  planner.getCostFactory.makeCost(inputRowCount, rowsTimesAggregates, rowsTimesRowSize)
}
/**
 * Translates this node into a DataSet program.
 *
 * The input node is translated first. Tumbling, session and sliding group windows whose
 * time field has a time-point or long result type are then handed to the matching
 * `createEventTime...WindowDataSet` method of this class.
 *
 * @param tableEnv    batch table environment that supplies the table and parser configuration
 * @param queryConfig query configuration, forwarded to the translation of the input
 * @return the DataSet that computes the windowed aggregation
 * @throws UnsupportedOperationException if the window matches none of the handled cases
 */
override def translateToPlan(
    tableEnv: BatchTableEnvironment,
    queryConfig: BatchQueryConfig): DataSet[Row] = {

  val translatedInput = getInput.asInstanceOf[DataSetRel].translateToPlan(tableEnv, queryConfig)

  val config = tableEnv.getConfig
  val codeGenerator = new AggregationCodeGenerator(config, false, translatedInput.getType, None)

  // whether identifiers are matched case-sensitively
  val isCaseSensitive = tableEnv.getFrameworkConfig.getParserConfig.caseSensitive()

  window match {
    case TumblingGroupWindow(_, timeField, size)
        if isTimePoint(timeField.resultType) || isLong(timeField.resultType) =>
      val isTimeWindow = isTimeIntervalLiteral(size)
      createEventTimeTumblingWindowDataSet(
        codeGenerator, translatedInput, isTimeWindow, isCaseSensitive, config)

    case SessionGroupWindow(_, timeField, _)
        if isTimePoint(timeField.resultType) || isLong(timeField.resultType) =>
      createEventTimeSessionWindowDataSet(codeGenerator, translatedInput, isCaseSensitive, config)

    case SlidingGroupWindow(_, timeField, size, slide)
        if isTimePoint(timeField.resultType) || isLong(timeField.resultType) =>
      val isTimeWindow = isTimeIntervalLiteral(size)
      val windowSize = asLong(size)
      val windowSlide = asLong(slide)
      createEventTimeSlidingWindowDataSet(
        codeGenerator, translatedInput, isTimeWindow, windowSize, windowSlide,
        isCaseSensitive, config)

    case _ =>
      throw new UnsupportedOperationException(
        s"Window $window is not supported in a batch environment.")
  }
}
/**
 * Builds the DataSet program for an event-time tumbling window.
 *
 * The input is mapped with the prepare function and then reduced per group. A time
 * window groups by the grouping keys plus the last field of the mapped row. A grouped
 * count window groups by the grouping keys and sorts each group on that last field.
 *
 * @param generator             code generator handed to the AggregateUtil factory methods
 * @param inputDS               the translated input DataSet
 * @param isTimeWindow          true for a time window, false for a count window
 * @param isParserCaseSensitive forwarded to the prepare map function factory
 * @param tableConfig           table configuration forwarded to the factory methods
 * @return the DataSet produced by the group reduce, typed with this node's row type
 * @throws UnsupportedOperationException for a count window without grouping keys
 */
private def createEventTimeTumblingWindowDataSet(
    generator: AggregationCodeGenerator,
    inputDS: DataSet[Row],
    isTimeWindow: Boolean,
    isParserCaseSensitive: Boolean,
    tableConfig: TableConfig): DataSet[Row] = {

  val inputRowType = inputNode.asInstanceOf[DataSetRel].getRowType
  val inputFieldTypes = inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes

  // the factory calls are kept in their original order (prepare first, then aggregate)
  val prepareFunction = createDataSetWindowPrepareMapFunction(
    generator, window, namedAggregates, grouping,
    inputRowType, inputFieldTypes, isParserCaseSensitive, tableConfig)

  val aggregateFunction = createDataSetWindowAggregationGroupReduceFunction(
    generator, window, namedAggregates,
    inputRowType, inputFieldTypes, getRowType,
    grouping, namedProperties, tableConfig)

  val preparedInput = inputDS
    .map(prepareFunction)
    .name(prepareOperatorName)
    .asInstanceOf[DataSet[Row]]

  val resultTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType)
  val preparedType = prepareFunction.asInstanceOf[ResultTypeQueryable[Row]].getProducedType
  // the rowtime field is the last field of the prepared row
  val rowtimePos = preparedType.getArity - 1

  if (isTimeWindow) {
    // grouped time window aggregation: group by the grouping keys and the rowtime field
    val keysAndRowtime = grouping.indices :+ rowtimePos
    preparedInput
      .groupBy(keysAndRowtime: _*)
      .reduceGroup(aggregateFunction)
      .returns(resultTypeInfo)
      .name(aggregateOperatorName)
  } else if (grouping.nonEmpty) {
    // grouped count window: group by the keys and sort each group on the time field
    val keys = grouping.indices.toArray
    preparedInput
      .groupBy(keys: _*)
      .sortGroup(rowtimePos, Order.ASCENDING)
      .reduceGroup(aggregateFunction)
      .returns(resultTypeInfo)
      .name(aggregateOperatorName)
  } else {
    // TODO: count tumbling all window on event-time should sort all the data set
    // on event time before applying the windowing logic.
    throw new UnsupportedOperationException(
      "Count tumbling non-grouping windows on event-time are currently not supported.")
  }
}
/**
 * Builds the DataSet program for an event-time session window.
 *
 * The input is first mapped with the function created by
 * createDataSetWindowPrepareMapFunction. Depending on the result of
 * doAllSupportPartialMerge and on whether there are grouping keys, one of four plans
 * is built:
 *  - partial merge, grouped: sorted combineGroup followed by a sorted reduceGroup
 *  - partial merge, non-grouped: sorted mapPartition followed by a reduceGroup on data
 *    that is sorted with parallelism 1
 *  - no partial merge, grouped: reduceGroup on groups sorted by the rowtime field
 *  - no partial merge, non-grouped: reduceGroup on data sorted with parallelism 1
 *
 * @param generator             code generator handed to the AggregateUtil factory methods
 * @param inputDS               the translated input DataSet
 * @param isParserCaseSensitive forwarded to the prepare map function factory
 * @param tableConfig           table configuration forwarded to the factory methods
 * @return the DataSet produced by the final group reduce, typed with this node's row type
 */
private[this] def createEventTimeSessionWindowDataSet(
generator: AggregationCodeGenerator,
inputDS: DataSet[Row],
isParserCaseSensitive: Boolean,
tableConfig: TableConfig): DataSet[Row] = {
val input = inputNode.asInstanceOf[DataSetRel]
// the positions 0 until grouping.length are used as group keys on the mapped rows
val groupingKeys = grouping.indices.toArray
// type information of the rows emitted by this node
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType)
// create mapFunction for initializing the aggregations
val mapFunction = createDataSetWindowPrepareMapFunction(
generator,
window,
namedAggregates,
grouping,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
isParserCaseSensitive,
tableConfig)
val mappedInput = inputDS.map(mapFunction).name(prepareOperatorName)
val mapReturnType = mapFunction.asInstanceOf[ResultTypeQueryable[Row]].getProducedType
// the position of the rowtime field in the intermediate result for map output
val rowTimeFieldPos = mapReturnType.getArity - 1
// do incremental aggregation
if (doAllSupportPartialMerge(
namedAggregates.map(_.getKey),
inputType,
grouping.length,
tableConfig)) {
// gets the window-start and window-end position in the intermediate result.
// NOTE(review): presumably the combine / mapPartition function writes the window start
// at the rowtime position and appends the window end after it - confirm in AggregateUtil
val windowStartPos = rowTimeFieldPos
val windowEndPos = windowStartPos + 1
// grouping window
if (groupingKeys.length > 0) {
// create groupCombineFunction for combine the aggregations
val combineGroupFunction = createDataSetWindowAggregationCombineFunction(
generator,
window,
namedAggregates,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
grouping,
tableConfig)
// create groupReduceFunction for calculating the aggregations
val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
generator,
window,
namedAggregates,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
rowRelDataType,
grouping,
namedProperties,
tableConfig,
isInputCombined = true)
// combine per group on rowtime-sorted rows, then regroup the combined rows and
// reduce them sorted by window start and window end
mappedInput
.groupBy(groupingKeys: _*)
.sortGroup(rowTimeFieldPos, Order.ASCENDING)
.combineGroup(combineGroupFunction)
.groupBy(groupingKeys: _*)
.sortGroup(windowStartPos, Order.ASCENDING)
.sortGroup(windowEndPos, Order.ASCENDING)
.reduceGroup(groupReduceFunction)
.returns(rowTypeInfo)
.name(aggregateOperatorName)
} else {
// non-grouping window
val mapPartitionFunction = createDataSetWindowAggregationMapPartitionFunction(
generator,
window,
namedAggregates,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
grouping,
tableConfig)
// create groupReduceFunction for calculating the aggregations
val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
generator,
window,
namedAggregates,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
rowRelDataType,
grouping,
namedProperties,
tableConfig,
isInputCombined = true)
// pre-aggregate each rowtime-sorted partition, then sort the partial results by
// window start and window end with parallelism 1 before the final reduce
mappedInput.sortPartition(rowTimeFieldPos, Order.ASCENDING)
.mapPartition(mapPartitionFunction)
.sortPartition(windowStartPos, Order.ASCENDING).setParallelism(1)
.sortPartition(windowEndPos, Order.ASCENDING).setParallelism(1)
.reduceGroup(groupReduceFunction)
.returns(rowTypeInfo)
.name(aggregateOperatorName)
.asInstanceOf[DataSet[Row]]
}
// do non-incremental aggregation
} else {
// grouping window
if (groupingKeys.length > 0) {
// create groupReduceFunction for calculating the aggregations
// (isInputCombined is not passed here, so the factory's default applies)
val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
generator,
window,
namedAggregates,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
rowRelDataType,
grouping,
namedProperties,
tableConfig)
// reduce each group on rows sorted by the rowtime field
mappedInput.groupBy(groupingKeys: _*)
.sortGroup(rowTimeFieldPos, Order.ASCENDING)
.reduceGroup(groupReduceFunction)
.returns(rowTypeInfo)
.name(aggregateOperatorName)
} else {
// non-grouping window
val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
generator,
window,
namedAggregates,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
rowRelDataType,
grouping,
namedProperties,
tableConfig)
// sort all rows by the rowtime field with parallelism 1, then reduce them
mappedInput.sortPartition(rowTimeFieldPos, Order.ASCENDING).setParallelism(1)
.reduceGroup(groupReduceFunction)
.returns(rowTypeInfo)
.name(aggregateOperatorName)
.asInstanceOf[DataSet[Row]]
}
}
}
/**
 * Builds the DataSet program for an event-time sliding window.
 *
 * The input is mapped with the prepare function. For a time window the mapped rows are
 * then either pre-tumbled with a group reduce (when all aggregates support partial
 * merging and determineLargestTumblingSize(size, slide) is larger than 1) or passed
 * through a flat map. The prepared rows are finally grouped by the grouping keys plus
 * the last field of the prepared row and reduced to the output rows.
 *
 * @param generator             code generator handed to the AggregateUtil factory methods
 * @param inputDS               the translated input DataSet
 * @param isTimeWindow          true for a time window; count windows are rejected
 * @param size                  window size as a long value
 * @param slide                 window slide as a long value
 * @param isParserCaseSensitive forwarded to the prepare function factories
 * @param tableConfig           table configuration forwarded to the factory methods
 * @return the DataSet produced by the final group reduce, typed with this node's row type
 * @throws UnsupportedOperationException if isTimeWindow is false (count window)
 */
private def createEventTimeSlidingWindowDataSet(
generator: AggregationCodeGenerator,
inputDS: DataSet[Row],
isTimeWindow: Boolean,
size: Long,
slide: Long,
isParserCaseSensitive: Boolean,
tableConfig: TableConfig)
: DataSet[Row] = {
val input = inputNode.asInstanceOf[DataSetRel]
// create MapFunction for initializing the aggregations
// it aligns the rowtime for pre-tumbling in case of a time-window for partial aggregates
val mapFunction = createDataSetWindowPrepareMapFunction(
generator,
window,
namedAggregates,
grouping,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
isParserCaseSensitive,
tableConfig)
val mappedDataSet = inputDS
.map(mapFunction)
.name(prepareOperatorName)
// type of the rows produced by the prepare map function
val mapReturnType = mappedDataSet.getType
// type information of the rows emitted by this node
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType)
// the positions 0 until grouping.length are used as group keys on the mapped rows
val groupingKeys = grouping.indices.toArray
// do partial aggregation if possible
val isPartial = doAllSupportPartialMerge(
namedAggregates.map(_.getKey),
inputType,
grouping.length,
tableConfig)
// only pre-tumble if it is worth it
// NOTE(review): a largest tumbling size of at most 1 presumably means pre-tumbling
// would not reduce the number of rows - confirm in determineLargestTumblingSize
val isLittleTumblingSize = determineLargestTumblingSize(size, slide) <= 1
val preparedDataSet = if (isTimeWindow) {
// time window
if (isPartial && !isLittleTumblingSize) {
// partial aggregates
// group on the grouping keys plus the last field of the mapped row (aligned rowtime)
val groupingKeysAndAlignedRowtime = groupingKeys :+ mapReturnType.getArity - 1
// create GroupReduceFunction
// for pre-tumbling and replicating/omitting the content for each pane
val prepareReduceFunction = createDataSetSlideWindowPrepareGroupReduceFunction(
generator,
window,
namedAggregates,
grouping,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
isParserCaseSensitive,
tableConfig)
mappedDataSet.asInstanceOf[DataSet[Row]]
.groupBy(groupingKeysAndAlignedRowtime: _*)
.reduceGroup(prepareReduceFunction) // pre-tumbles and replicates/omits
.name(prepareOperatorName)
} else {
// non-partial aggregates
// create FlatMapFunction
// for replicating/omitting the content for each pane
val prepareFlatMapFunction = createDataSetSlideWindowPrepareFlatMapFunction(
window,
namedAggregates,
grouping,
mapReturnType,
isParserCaseSensitive)
// note: unlike the other operators in this method, the flat map gets no explicit name
mappedDataSet
.flatMap(prepareFlatMapFunction) // replicates/omits
}
} else {
// count window
throw new UnsupportedOperationException(
"Count sliding group windows on event-time are currently not supported.")
}
// type of the rows produced by the prepare step (group reduce or flat map)
val prepareReduceReturnType = preparedDataSet.getType
// create GroupReduceFunction for final aggregation and conversion to output row
val aggregateReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
generator,
window,
namedAggregates,
input.getRowType,
inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
rowRelDataType,
grouping,
namedProperties,
tableConfig,
isInputCombined = false)
// gets the window-start position in the intermediate result.
// (it is the last field of the prepared row)
val windowStartPos = prepareReduceReturnType.getArity - 1
val groupingKeysAndWindowStart = groupingKeys :+ windowStartPos
// final aggregation: one group per combination of grouping keys and window start
preparedDataSet
.groupBy(groupingKeysAndWindowStart: _*)
.reduceGroup(aggregateReduceFunction)
.returns(rowTypeInfo)
.name(aggregateOperatorName)
}
/** Name of the prepare operator: `prepare select: (...)` with the aggregation string. */
private def prepareOperatorName: String = {
  val selectPart =
    aggregationToString(inputType, grouping, getRowType, namedAggregates, namedProperties)
  s"prepare select: ($selectPart)"
}
/**
 * Name of the aggregate operator: the window and the aggregation string, prefixed by
 * the grouping fields when the grouping is non-empty.
 */
private def aggregateOperatorName: String = {
  val selectPart =
    aggregationToString(inputType, grouping, getRowType, namedAggregates, namedProperties)
  val windowAndSelect = s"window: ($window), select: ($selectPart)"

  if (grouping.isEmpty) {
    windowAndSelect
  } else {
    s"groupBy: (${groupingToString(inputType, grouping)}), " + windowAndSelect
  }
}
}