/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.ml.common.datastream;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.functions.CoGroupFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache;
import org.apache.flink.iteration.datacache.nonkeyed.OperatorScopeManagedMemoryManager;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.common.datastream.sort.CoGroupOperator;
import org.apache.flink.ml.common.window.CountTumblingWindows;
import org.apache.flink.ml.common.window.EventTimeSessionWindows;
import org.apache.flink.ml.common.window.EventTimeTumblingWindows;
import org.apache.flink.ml.common.window.GlobalWindows;
import org.apache.flink.ml.common.window.ProcessingTimeSessionWindows;
import org.apache.flink.ml.common.window.ProcessingTimeTumblingWindows;
import org.apache.flink.ml.common.window.Windows;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.runtime.state.VoidNamespace;
import org.apache.flink.runtime.state.VoidNamespaceSerializer;
import org.apache.flink.streaming.api.datastream.AllWindowedStream;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.KeyedStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
import org.apache.flink.streaming.api.functions.windowing.ProcessAllWindowFunction;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.InternalTimer;
import org.apache.flink.streaming.api.operators.InternalTimerService;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.TimestampedCollector;
import org.apache.flink.streaming.api.operators.Triggerable;
import org.apache.flink.streaming.api.windowing.assigners.TumblingEventTimeWindows;
import org.apache.flink.streaming.api.windowing.assigners.TumblingProcessingTimeWindows;
import org.apache.flink.streaming.api.windowing.assigners.WindowAssigner;
import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
import org.apache.flink.streaming.api.windowing.windows.Window;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.StreamTask;
import org.apache.flink.util.Collector;
import org.apache.commons.collections.IteratorUtils;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import static org.apache.flink.iteration.utils.DataStreamUtils.setManagedMemoryWeight;
/** Provides utility functions for {@link DataStream}. */
@Internal
public class DataStreamUtils {
/**
* Applies allReduceSum on the input data stream. The input data stream is expected to contain
* at most one double array in each partition. The result data stream has the same parallelism as
* the input, and each of its partitions contains one double array that is the element-wise sum
* of all the double arrays in the input data stream.
*
* <p>Note that an exception is thrown when either of the following cases happens:
*
* <ul>
*   <li>There exists a partition that contains more than one double array.
*   <li>The length of the double array is not consistent across all partitions.
* </ul>
*
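* <p>A minimal usage sketch (assumes a hypothetical {@code DataStream<double[]>} named {@code
* localSums} that holds at most one double array per partition):
*
* <pre>{@code
* DataStream<double[]> globalSums = DataStreamUtils.allReduceSum(localSums);
* }</pre>
*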
* @param input The input data stream.
* @return The result data stream.
*/
public static DataStream<double[]> allReduceSum(DataStream<double[]> input) {
return AllReduceImpl.allReduceSum(input);
}
/**
* Applies a {@link MapPartitionFunction} on a bounded data stream.
*
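* <p>A minimal usage sketch that counts the records in each partition of a hypothetical bounded
* {@code DataStream<Double>} named {@code input}:
*
* <pre>{@code
* DataStream<Long> partitionCounts =
*         DataStreamUtils.mapPartition(
*                 input,
*                 new MapPartitionFunction<Double, Long>() {
*                     public void mapPartition(Iterable<Double> values, Collector<Long> out) {
*                         long count = 0;
*                         for (Double ignored : values) {
*                             count++;
*                         }
*                         out.collect(count);
*                     }
*                 });
* }</pre>
*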
* @param input The input data stream.
* @param func The user defined mapPartition function.
* @param <IN> The class type of the input.
* @param <OUT> The class type of output.
* @return The result data stream.
*/
public static <IN, OUT> DataStream<OUT> mapPartition(
DataStream<IN> input, MapPartitionFunction<IN, OUT> func) {
TypeInformation<OUT> outType =
TypeExtractor.getMapPartitionReturnTypes(func, input.getType(), null, true);
return mapPartition(input, func, outType);
}
/**
* Applies a {@link MapPartitionFunction} on a bounded data stream.
*
* @param input The input data stream.
* @param func The user defined mapPartition function.
* @param outType The type information of the output.
* @param <IN> The class type of the input.
* @param <OUT> The class type of output.
* @return The result data stream.
*/
public static <IN, OUT> DataStream<OUT> mapPartition(
DataStream<IN> input,
MapPartitionFunction<IN, OUT> func,
TypeInformation<OUT> outType) {
func = input.getExecutionEnvironment().clean(func);
return input.transform("mapPartition", outType, new MapPartitionOperator<>(func))
.setParallelism(input.getParallelism());
}
/**
* Applies a {@link ReduceFunction} on a bounded data stream. The output stream contains at most
* one stream record and its parallelism is one.
*
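* <p>A minimal usage sketch that sums a hypothetical bounded {@code DataStream<Long>} named
* {@code input}:
*
* <pre>{@code
* DataStream<Long> sum = DataStreamUtils.reduce(input, (a, b) -> a + b);
* }</pre>
*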
* @param input The input data stream.
* @param func The user defined reduce function.
* @param <T> The class type of the input.
* @return The result data stream.
*/
public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> func) {
return reduce(input, func, input.getType());
}
/**
* Applies a {@link ReduceFunction} on a bounded data stream. The output stream contains at most
* one stream record and its parallelism is one.
*
* @param input The input data stream.
* @param func The user defined reduce function.
* @param outType The type information of the output.
* @param <T> The class type of the input.
* @return The result data stream.
*/
public static <T> DataStream<T> reduce(
DataStream<T> input, ReduceFunction<T> func, TypeInformation<T> outType) {
func = input.getExecutionEnvironment().clean(func);
DataStream<T> partialReducedStream =
input.transform("reduce", outType, new ReduceOperator<>(func))
.setParallelism(input.getParallelism());
if (partialReducedStream.getParallelism() == 1) {
return partialReducedStream;
} else {
return partialReducedStream
.transform("reduce", outType, new ReduceOperator<>(func))
.setParallelism(1);
}
}
/**
* Applies a {@link ReduceFunction} on a bounded keyed data stream. The output stream contains
* one stream record for each key.
*
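* <p>A minimal usage sketch (assumes a hypothetical bounded {@code KeyedStream<Tuple2<String,
* Long>, String>} named {@code keyedInput}, e.g. obtained via {@code keyBy} elsewhere):
*
* <pre>{@code
* DataStream<Tuple2<String, Long>> perKeySums =
*         DataStreamUtils.reduce(keyedInput, (a, b) -> Tuple2.of(a.f0, a.f1 + b.f1));
* }</pre>
*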
* @param input The input keyed data stream.
* @param func The user defined reduce function.
* @param <T> The class type of input.
* @param <K> The key type of input.
* @return The result data stream.
*/
public static <T, K> DataStream<T> reduce(KeyedStream<T, K> input, ReduceFunction<T> func) {
return reduce(input, func, input.getType());
}
/**
* Applies a {@link ReduceFunction} on a bounded keyed data stream. The output stream contains
* one stream record for each key.
*
* @param input The input keyed data stream.
* @param func The user defined reduce function.
* @param outType The type information of the output.
* @param <T> The class type of input.
* @param <K> The key type of input.
* @return The result data stream.
*/
public static <T, K> DataStream<T> reduce(
KeyedStream<T, K> input, ReduceFunction<T> func, TypeInformation<T> outType) {
func = input.getExecutionEnvironment().clean(func);
return input.transform(
"Keyed Reduce",
outType,
new KeyedReduceOperator<>(
func, outType.createSerializer(input.getExecutionConfig())))
.setParallelism(input.getParallelism());
}
/**
* Aggregates the elements in each partition of the input bounded stream, and then merges the
* partial results of all partitions. The output stream contains the aggregated result and its
* parallelism is one.
*
* <p>Note: If the parallelism of the input stream is N, this method would invoke {@link
* AggregateFunction#createAccumulator()} N times and {@link AggregateFunction#merge(Object,
* Object)} N - 1 times. Thus the initial accumulator should be neutral (e.g. empty list for
* list concatenation or `0` for summation), otherwise the aggregation result would be affected
* by the parallelism of the input stream.
*
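* <p>A minimal usage sketch that counts the records of a hypothetical bounded {@code
* DataStream<Double>} named {@code input}; note the neutral accumulator {@code 0L}:
*
* <pre>{@code
* DataStream<Long> count =
*         DataStreamUtils.aggregate(
*                 input,
*                 new AggregateFunction<Double, Long, Long>() {
*                     public Long createAccumulator() {
*                         return 0L;
*                     }
*
*                     public Long add(Double value, Long acc) {
*                         return acc + 1;
*                     }
*
*                     public Long getResult(Long acc) {
*                         return acc;
*                     }
*
*                     public Long merge(Long a, Long b) {
*                         return a + b;
*                     }
*                 });
* }</pre>
*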
* @param input The input data stream.
* @param func The user defined aggregate function.
* @param <IN> The class type of the input.
* @param <ACC> The class type of the accumulated values.
* @param <OUT> The class type of the output values.
* @return The result data stream.
*/
public static <IN, ACC, OUT> DataStream<OUT> aggregate(
DataStream<IN> input, AggregateFunction<IN, ACC, OUT> func) {
TypeInformation<ACC> accType =
TypeExtractor.getAggregateFunctionAccumulatorType(
func, input.getType(), null, true);
TypeInformation<OUT> outType =
TypeExtractor.getAggregateFunctionReturnType(func, input.getType(), null, true);
return aggregate(input, func, accType, outType);
}
/**
* Aggregates the elements in each partition of the input bounded stream, and then merges the
* partial results of all partitions. The output stream contains the aggregated result and its
* parallelism is one.
*
* <p>Note: If the parallelism of the input stream is N, this method would invoke {@link
* AggregateFunction#createAccumulator()} N times and {@link AggregateFunction#merge(Object,
* Object)} N - 1 times. Thus the initial accumulator should be neutral (e.g. empty list for
* list concatenation or `0` for summation), otherwise the aggregation result would be affected
* by the parallelism of the input stream.
*
* @param input The input data stream.
* @param func The user defined aggregate function.
* @param accType The type of the accumulated values.
* @param outType The type information of the output.
* @param <IN> The class type of the input.
* @param <ACC> The class type of the accumulated values.
* @param <OUT> The class type of the output values.
* @return The result data stream.
*/
public static <IN, ACC, OUT> DataStream<OUT> aggregate(
DataStream<IN> input,
AggregateFunction<IN, ACC, OUT> func,
TypeInformation<ACC> accType,
TypeInformation<OUT> outType) {
func = input.getExecutionEnvironment().clean(func);
DataStream<ACC> partialAggregatedStream =
input.transform(
"partialAggregate", accType, new PartialAggregateOperator<>(func, accType));
DataStream<OUT> aggregatedStream =
partialAggregatedStream.transform(
"aggregate", outType, new AggregateOperator<>(func, accType));
aggregatedStream.getTransformation().setParallelism(1);
return aggregatedStream;
}
/**
* Performs an approximate uniform sampling over the elements in a bounded data stream. The
* difference between the sampling probabilities of any two data points is bounded by
* O(numSamples * p * p / (M * M)), where p is the parallelism of the input stream and M is the
* total number of data points in the input stream.
*
* <p>This method takes samples without replacement. If the number of elements in the stream is
* smaller than the expected number of samples, all elements are included in the sample.
*
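* <p>A minimal usage sketch that draws at most 100 samples from a hypothetical bounded {@code
* DataStream<Double>} named {@code input}:
*
* <pre>{@code
* DataStream<Double> sampled = DataStreamUtils.sample(input, 100, 2022L);
* }</pre>
*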
* @param input The input data stream.
* @param numSamples The number of elements to be sampled.
* @param randomSeed The seed used to randomly pick elements for the sample.
* @return A data stream containing the sampled elements.
*/
public static <T> DataStream<T> sample(DataStream<T> input, int numSamples, long randomSeed) {
int inputParallelism = input.getParallelism();
// The maximum difference in the number of data points between any two partitions after
// calling `rebalance` is `inputParallelism`. As a result, `inputParallelism` extra data
// points are sampled for each partition in the first round.
int firstRoundNumSamples =
Math.min((numSamples / inputParallelism) + inputParallelism, numSamples);
return input.rebalance()
.transform(
"firstRoundSampling",
input.getType(),
new SamplingOperator<>(firstRoundNumSamples, randomSeed))
.setParallelism(inputParallelism)
.transform(
"secondRoundSampling",
input.getType(),
new SamplingOperator<>(numSamples, randomSeed))
.setParallelism(1)
.map(x -> x, input.getType())
.setParallelism(inputParallelism);
}
/**
* Creates windows from data in the non-keyed input stream and applies the given window function
* to each window.
*
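* <p>A minimal usage sketch (assumes a hypothetical bounded {@code DataStream<Double>} named
* {@code input} and a {@link Windows} strategy named {@code windows} whose assigner produces
* {@link TimeWindow}s):
*
* <pre>{@code
* SingleOutputStreamOperator<Double> windowSums =
*         DataStreamUtils.windowAllAndProcess(
*                 input,
*                 windows,
*                 new ProcessAllWindowFunction<Double, Double, TimeWindow>() {
*                     public void process(
*                             Context context, Iterable<Double> elements, Collector<Double> out) {
*                         double sum = 0;
*                         for (double v : elements) {
*                             sum += v;
*                         }
*                         out.collect(sum);
*                     }
*                 });
* }</pre>
*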
* @param input The input data stream to be windowed and processed.
* @param windows The windowing strategy that defines how input data would be sliced into
* batches.
* @param function The user defined process function.
* @return The data stream that is the result of applying the window function to each window.
*/
public static <IN, OUT, W extends Window> SingleOutputStreamOperator<OUT> windowAllAndProcess(
DataStream<IN> input, Windows windows, ProcessAllWindowFunction<IN, OUT, W> function) {
function = input.getExecutionEnvironment().clean(function);
AllWindowedStream<IN, W> allWindowedStream = getAllWindowedStream(input, windows);
return allWindowedStream.process(function);
}
/**
* Creates windows from data in the non-keyed input stream and applies the given window function
* to each window.
*
* @param input The input data stream to be windowed and processed.
* @param windows The windowing strategy that defines how input data would be sliced into
* batches.
* @param function The user defined process function.
* @param outType The type information of the output.
* @return The data stream that is the result of applying the window function to each window.
*/
public static <IN, OUT, W extends Window> SingleOutputStreamOperator<OUT> windowAllAndProcess(
DataStream<IN> input,
Windows windows,
ProcessAllWindowFunction<IN, OUT, W> function,
TypeInformation<OUT> outType) {
function = input.getExecutionEnvironment().clean(function);
AllWindowedStream<IN, W> allWindowedStream = getAllWindowedStream(input, windows);
return allWindowedStream.process(function, outType);
}
/**
* A CoGroup transformation combines the elements of two {@link DataStream DataStreams} into one
* DataStream. It groups each DataStream individually on a key and passes groups of both
* DataStreams that share the same key to a {@link
* org.apache.flink.api.common.functions.CoGroupFunction}. If a DataStream has a group with no
* matching key in the other DataStream, the CoGroupFunction is called with an empty group for
* the non-existing group.
*
* <p>The CoGroupFunction can iterate over the elements of both groups and return any number of
* elements including none.
*
* <p>NOTE: This method assumes both inputs are bounded.
*
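* <p>A minimal usage sketch (assumes two hypothetical bounded streams {@code
* DataStream<Tuple2<Integer, String>> left} and {@code DataStream<Tuple2<Integer, Long>> right},
* and uses {@code org.apache.flink.api.common.typeinfo.Types} for the output type):
*
* <pre>{@code
* DataStream<String> coGrouped =
*         DataStreamUtils.coGroup(
*                 left,
*                 right,
*                 new KeySelector<Tuple2<Integer, String>, Integer>() {
*                     public Integer getKey(Tuple2<Integer, String> t) {
*                         return t.f0;
*                     }
*                 },
*                 new KeySelector<Tuple2<Integer, Long>, Integer>() {
*                     public Integer getKey(Tuple2<Integer, Long> t) {
*                         return t.f0;
*                     }
*                 },
*                 Types.STRING,
*                 (first, second, out) -> out.collect(first + " <-> " + second));
* }</pre>
*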
* @param input1 The first data stream.
* @param input2 The second data stream.
* @param keySelector1 The KeySelector to be used for extracting the first input's key for
* partitioning.
* @param keySelector2 The KeySelector to be used for extracting the second input's key for
* partitioning.
* @param outTypeInformation The type information describing the output type.
* @param func The user-defined co-group function.
* @param <IN1> The class type of the first input.
* @param <IN2> The class type of the second input.
* @param <KEY> The class type of the key.
* @param <OUT> The class type of the output values.
* @return The result data stream.
*/
public static <IN1, IN2, KEY extends Serializable, OUT> DataStream<OUT> coGroup(
DataStream<IN1> input1,
DataStream<IN2> input2,
KeySelector<IN1, KEY> keySelector1,
KeySelector<IN2, KEY> keySelector2,
TypeInformation<OUT> outTypeInformation,
CoGroupFunction<IN1, IN2, OUT> func) {
func = input1.getExecutionEnvironment().clean(func);
DataStream<OUT> result =
input1.connect(input2)
.keyBy(keySelector1, keySelector2)
.transform(
"CoGroupOperator", outTypeInformation, new CoGroupOperator<>(func))
.setParallelism(Math.max(input1.getParallelism(), input2.getParallelism()));
setManagedMemoryWeight(result, 100);
return result;
}
@SuppressWarnings({"rawtypes", "unchecked"})
private static <IN, W extends Window> AllWindowedStream<IN, W> getAllWindowedStream(
DataStream<IN> input, Windows windows) {
if (windows instanceof CountTumblingWindows) {
long countWindowSize = ((CountTumblingWindows) windows).getSize();
return (AllWindowedStream<IN, W>) input.countWindowAll(countWindowSize);
} else {
return input.windowAll((WindowAssigner) getDataStreamTimeWindowAssigner(windows));
}
}
private static WindowAssigner<Object, TimeWindow> getDataStreamTimeWindowAssigner(
Windows windows) {
if (windows instanceof GlobalWindows) {
return EndOfStreamWindows.get();
} else if (windows instanceof EventTimeTumblingWindows) {
return TumblingEventTimeWindows.of(
getStreamWindowTime(((EventTimeTumblingWindows) windows).getSize()));
} else if (windows instanceof ProcessingTimeTumblingWindows) {
return TumblingProcessingTimeWindows.of(
getStreamWindowTime(((ProcessingTimeTumblingWindows) windows).getSize()));
} else if (windows instanceof EventTimeSessionWindows) {
return org.apache.flink.streaming.api.windowing.assigners.EventTimeSessionWindows
.withGap(getStreamWindowTime(((EventTimeSessionWindows) windows).getGap()));
} else if (windows instanceof ProcessingTimeSessionWindows) {
return org.apache.flink.streaming.api.windowing.assigners.ProcessingTimeSessionWindows
.withGap(
getStreamWindowTime(((ProcessingTimeSessionWindows) windows).getGap()));
} else {
throw new UnsupportedOperationException(
String.format(
"Unsupported Windows subclass: %s", windows.getClass().getName()));
}
}
private static org.apache.flink.streaming.api.windowing.time.Time getStreamWindowTime(
Time time) {
return org.apache.flink.streaming.api.windowing.time.Time.of(
time.getSize(), time.getUnit());
}
/**
* A stream operator to apply {@link MapPartitionFunction} on each partition of the input
* bounded data stream.
*/
private static class MapPartitionOperator<IN, OUT>
extends AbstractUdfStreamOperator<OUT, MapPartitionFunction<IN, OUT>>
implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
private ListStateWithCache<IN> valuesState;
public MapPartitionOperator(MapPartitionFunction<IN, OUT> mapPartitionFunc) {
super(mapPartitionFunc);
}
@Override
public void initializeState(StateInitializationContext context) throws Exception {
super.initializeState(context);
final StreamTask<?, ?> containingTask = getContainingTask();
final OperatorID operatorID = config.getOperatorID();
final OperatorScopeManagedMemoryManager manager =
OperatorScopeManagedMemoryManager.getOrCreate(containingTask, operatorID);
final String stateKey = "values-state";
manager.register(stateKey, 1.);
valuesState =
new ListStateWithCache<>(
getOperatorConfig().getTypeSerializerIn(0, getClass().getClassLoader()),
stateKey,
context,
this);
}
@Override
public void snapshotState(StateSnapshotContext context) throws Exception {
super.snapshotState(context);
valuesState.snapshotState(context);
}
@Override
public void processElement(StreamRecord<IN> input) throws Exception {
valuesState.add(input.getValue());
}
@Override
public void endInput() throws Exception {
userFunction.mapPartition(valuesState.get(), new TimestampedCollector<>(output));
valuesState.clear();
}
}
/** A stream operator to apply {@link ReduceFunction} on the input bounded data stream. */
private static class ReduceOperator<T> extends AbstractUdfStreamOperator<T, ReduceFunction<T>>
implements OneInputStreamOperator<T, T>, BoundedOneInput {
/** The temp result of the reduce function. */
private T result;
private ListState<T> state;
public ReduceOperator(ReduceFunction<T> userFunction) {
super(userFunction);
}
@Override
public void endInput() {
if (result != null) {
output.collect(new StreamRecord<>(result));
}
}
@Override
public void processElement(StreamRecord<T> streamRecord) throws Exception {
if (result == null) {
result = streamRecord.getValue();
} else {
result = userFunction.reduce(streamRecord.getValue(), result);
}
}
@Override
public void initializeState(StateInitializationContext context) throws Exception {
super.initializeState(context);
state =
context.getOperatorStateStore()
.getListState(
new ListStateDescriptor<>(
"state",
getOperatorConfig()
.getTypeSerializerIn(
0, getClass().getClassLoader())));
result = OperatorStateUtils.getUniqueElement(state, "state").orElse(null);
}
@Override
public void snapshotState(StateSnapshotContext context) throws Exception {
super.snapshotState(context);
state.clear();
if (result != null) {
state.add(result);
}
}
}
/**
* A stream operator to apply {@link ReduceFunction} on the input bounded keyed data stream.
*
* <p>Note: this class is a copy of {@link
* org.apache.flink.streaming.api.operators.BatchGroupedReduceOperator} in case of unexpected
* changes of its implementation.
*/
private static class KeyedReduceOperator<IN, KEY>
extends AbstractUdfStreamOperator<IN, ReduceFunction<IN>>
implements OneInputStreamOperator<IN, IN>, Triggerable<KEY, VoidNamespace> {
private static final long serialVersionUID = 1L;
private static final String STATE_NAME = "_op_state";
private transient ValueState<IN> values;
private final TypeSerializer<IN> serializer;
private InternalTimerService<VoidNamespace> timerService;
public KeyedReduceOperator(ReduceFunction<IN> reducer, TypeSerializer<IN> serializer) {
super(reducer);
this.serializer = serializer;
}
@Override
public void open() throws Exception {
super.open();
ValueStateDescriptor<IN> stateId = new ValueStateDescriptor<>(STATE_NAME, serializer);
values = getPartitionedState(stateId);
timerService =
getInternalTimerService("end-key-timers", new VoidNamespaceSerializer(), this);
}
@Override
public void processElement(StreamRecord<IN> element) throws Exception {
IN value = element.getValue();
IN currentValue = values.value();
if (currentValue == null) {
// Registers a timer for emitting the result at the end when this is the
// first input for this key.
timerService.registerEventTimeTimer(VoidNamespace.INSTANCE, Long.MAX_VALUE);
} else {
// Otherwise, reduces things.
value = userFunction.reduce(currentValue, value);
}
values.update(value);
}
@Override
public void onEventTime(InternalTimer<KEY, VoidNamespace> timer) throws Exception {
IN currentValue = values.value();
if (currentValue != null) {
output.collect(new StreamRecord<>(currentValue, Long.MAX_VALUE));
}
}
@Override
public void onProcessingTime(InternalTimer<KEY, VoidNamespace> timer) throws Exception {}
}
/**
* A stream operator to apply {@link AggregateFunction#add(Object, Object)} on each partition of the
* input bounded data stream.
*/
private static class PartialAggregateOperator<IN, ACC, OUT>
extends AbstractUdfStreamOperator<ACC, AggregateFunction<IN, ACC, OUT>>
implements OneInputStreamOperator<IN, ACC>, BoundedOneInput {
/** Type information of the accumulated result. */
private final TypeInformation<ACC> accType;
/** The accumulated result of the aggregate function in one partition. */
private ACC acc;
/** State of acc. */
private ListState<ACC> accState;
public PartialAggregateOperator(
AggregateFunction<IN, ACC, OUT> userFunction, TypeInformation<ACC> accType) {
super(userFunction);
this.accType = accType;
}
@Override
public void endInput() {
output.collect(new StreamRecord<>(acc));
}
@Override
public void processElement(StreamRecord<IN> streamRecord) throws Exception {
acc = userFunction.add(streamRecord.getValue(), acc);
}
@Override
public void initializeState(StateInitializationContext context) throws Exception {
super.initializeState(context);
accState =
context.getOperatorStateStore()
.getListState(new ListStateDescriptor<>("accState", accType));
acc =
OperatorStateUtils.getUniqueElement(accState, "accState")
.orElse(userFunction.createAccumulator());
}
@Override
public void snapshotState(StateSnapshotContext context) throws Exception {
super.snapshotState(context);
accState.clear();
accState.add(acc);
}
}
/**
* A stream operator to apply {@link AggregateFunction#merge(Object, Object)} and {@link
* AggregateFunction#getResult(Object)} on the input bounded data stream.
*/
private static class AggregateOperator<IN, ACC, OUT>
extends AbstractUdfStreamOperator<OUT, AggregateFunction<IN, ACC, OUT>>
implements OneInputStreamOperator<ACC, OUT>, BoundedOneInput {
/** Type information of the accumulated result. */
private final TypeInformation<ACC> accType;
/** The accumulated result of the aggregate function in the final partition. */
private ACC acc;
/** State of acc. */
private ListState<ACC> accState;
public AggregateOperator(
AggregateFunction<IN, ACC, OUT> userFunction, TypeInformation<ACC> accType) {
super(userFunction);
this.accType = accType;
}
@Override
public void endInput() {
output.collect(new StreamRecord<>(userFunction.getResult(acc)));
}
@Override
public void processElement(StreamRecord<ACC> streamRecord) throws Exception {
if (acc == null) {
acc = streamRecord.getValue();
} else {
acc = userFunction.merge(streamRecord.getValue(), acc);
}
}
@Override
public void initializeState(StateInitializationContext context) throws Exception {
super.initializeState(context);
accState =
context.getOperatorStateStore()
.getListState(new ListStateDescriptor<>("accState", accType));
acc = OperatorStateUtils.getUniqueElement(accState, "accState").orElse(null);
}
@Override
public void snapshotState(StateSnapshotContext context) throws Exception {
super.snapshotState(context);
accState.clear();
if (acc != null) {
accState.add(acc);
}
}
}
/**
* Splits the input data into global batches of {@code batchSize} records. Each global batch is
* then further split into local batches for the downstream operators, one batch per worker.
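*
* <p>A minimal usage sketch (assumes a hypothetical bounded {@code DataStream<Double>} named
* {@code trainData} and a downstream parallelism of 2):
*
* <pre>{@code
* DataStream<Double[]> localBatches = DataStreamUtils.generateBatchData(trainData, 2, 32);
* }</pre>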
*/
public static <T> DataStream<T[]> generateBatchData(
DataStream<T> inputData, final int downStreamParallelism, int batchSize) {
return inputData
.countWindowAll(batchSize)
.apply(new GlobalBatchCreator<>())
.flatMap(new GlobalBatchSplitter<>(downStreamParallelism))
.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f0)
.map(
new MapFunction<Tuple2<Integer, T[]>, T[]>() {
@Override
public T[] map(Tuple2<Integer, T[]> integerTuple2) throws Exception {
return integerTuple2.f1;
}
});
}
/** Splits the input data into global batches. */
private static class GlobalBatchCreator<T> implements AllWindowFunction<T, T[], GlobalWindow> {
@Override
public void apply(GlobalWindow timeWindow, Iterable<T> iterable, Collector<T[]> collector) {
List<T> points = IteratorUtils.toList(iterable.iterator());
collector.collect(points.toArray((T[]) new Object[0]));
}
}
/**
* A function that splits a global batch into evenly sized local batches and distributes them to
* the downstream operators.
*/
private static class GlobalBatchSplitter<T>
implements FlatMapFunction<T[], Tuple2<Integer, T[]>> {
private final int downStreamParallelism;
public GlobalBatchSplitter(int downStreamParallelism) {
this.downStreamParallelism = downStreamParallelism;
}
@Override
public void flatMap(T[] values, Collector<Tuple2<Integer, T[]>> collector) {
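// Distributes the elements as evenly as possible: the first (values.length %
// downStreamParallelism) local batches get one element more than the remaining ones.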
int div = values.length / downStreamParallelism;
int mod = values.length % downStreamParallelism;
int offset = 0;
int i = 0;
int size = div + 1;
for (; i < mod; i++) {
collector.collect(Tuple2.of(i, Arrays.copyOfRange(values, offset, offset + size)));
offset += size;
}
size = div;
for (; i < downStreamParallelism; i++) {
collector.collect(Tuple2.of(i, Arrays.copyOfRange(values, offset, offset + size)));
offset += size;
}
}
}
/**
* A stream operator that takes a randomly sampled subset of elements in a bounded data stream.
*/
private static class SamplingOperator<T> extends AbstractStreamOperator<T>
implements OneInputStreamOperator<T, T>, BoundedOneInput {
private final int numSamples;
private final Random random;
private ListState<T> samplesState;
private List<T> samples;
private ListState<Integer> countState;
private int count;
SamplingOperator(int numSamples, long randomSeed) {
this.numSamples = numSamples;
this.random = new Random(randomSeed);
}
@Override
public void initializeState(StateInitializationContext context) throws Exception {
super.initializeState(context);
ListStateDescriptor<T> samplesDescriptor =
new ListStateDescriptor<>(
"samplesState",
getOperatorConfig()
.getTypeSerializerIn(0, getClass().getClassLoader()));
samplesState = context.getOperatorStateStore().getListState(samplesDescriptor);
samples = new ArrayList<>(numSamples);
samplesState.get().forEach(samples::add);
ListStateDescriptor<Integer> countDescriptor =
new ListStateDescriptor<>("countState", IntSerializer.INSTANCE);
countState = context.getOperatorStateStore().getListState(countDescriptor);
Iterator<Integer> countIterator = countState.get().iterator();
if (countIterator.hasNext()) {
count = countIterator.next();
} else {
count = 0;
}
}
@Override
public void snapshotState(StateSnapshotContext context) throws Exception {
super.snapshotState(context);
samplesState.update(samples);
countState.update(Collections.singletonList(count));
}
@Override
public void processElement(StreamRecord<T> streamRecord) throws Exception {
T value = streamRecord.getValue();
count++;
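// Reservoir sampling: the first numSamples elements are kept directly; each later element
// replaces a uniformly chosen existing sample with probability numSamples / count.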
if (samples.size() < numSamples) {
samples.add(value);
} else {
int index = random.nextInt(count);
if (index < numSamples) {
samples.set(index, value);
}
}
}
@Override
public void endInput() throws Exception {
for (T sample : samples) {
output.collect(new StreamRecord<>(sample));
}
}
}
}