[FLINK-32704] Supports spilling to disk when feedback channel memory buffer is full
This closes #248.
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
index 88ff0d3..a3f3e60 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
@@ -36,7 +36,6 @@
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TypeExtractor;
-import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache;
import org.apache.flink.iteration.datacache.nonkeyed.OperatorScopeManagedMemoryManager;
import org.apache.flink.iteration.operator.OperatorStateUtils;
@@ -75,7 +74,6 @@
import org.apache.flink.streaming.api.windowing.windows.Window;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.StreamTask;
-import org.apache.flink.table.api.TableException;
import org.apache.flink.util.Collector;
import org.apache.commons.collections.IteratorUtils;
@@ -86,9 +84,10 @@
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
-import java.util.Optional;
import java.util.Random;
+import static org.apache.flink.iteration.utils.DataStreamUtils.setManagedMemoryWeight;
+
/** Provides utility functions for {@link DataStream}. */
@Internal
public class DataStreamUtils {
@@ -322,29 +321,6 @@
}
/**
- * Sets {Transformation#declareManagedMemoryUseCaseAtOperatorScope(ManagedMemoryUseCase, int)}
- * using the given bytes for {@link ManagedMemoryUseCase#OPERATOR}.
- *
- * <p>This method is in reference to Flink's ExecNodeUtil.setManagedMemoryWeight. The provided
- * bytes should be in the same scale as existing usage in Flink, for example,
- * StreamExecWindowAggregate.WINDOW_AGG_MEMORY_RATIO.
- */
- public static <T> void setManagedMemoryWeight(DataStream<T> dataStream, long memoryBytes) {
- if (memoryBytes > 0) {
- final int weightInMebibyte = Math.max(1, (int) (memoryBytes >> 20));
- final Optional<Integer> previousWeight =
- dataStream
- .getTransformation()
- .declareManagedMemoryUseCaseAtOperatorScope(
- ManagedMemoryUseCase.OPERATOR, weightInMebibyte);
- if (previousWeight.isPresent()) {
- throw new TableException(
- "Managed memory weight has been set, this should not happen.");
- }
- }
- }
-
- /**
* Creates windows from data in the non key grouped input stream and applies the given window
* function to each window.
*
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java
index 060776e..a758a5b 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java
@@ -20,6 +20,7 @@
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.IterationRecord;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
@@ -46,7 +47,7 @@
}
public TerminateOnMaxIterOrTol(Double tol) {
- this.maxIter = Integer.MAX_VALUE;
+ this.maxIter = IterationRecord.END_EPOCH_WATERMARK;
this.tol = tol;
}
diff --git a/flink-ml-core/src/test/java/org/apache/flink/iteration/datacache/nonkeyed/ListStateWithCacheTest.java b/flink-ml-core/src/test/java/org/apache/flink/iteration/datacache/nonkeyed/ListStateWithCacheTest.java
index 372b4a1..5e3a8c1 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/iteration/datacache/nonkeyed/ListStateWithCacheTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/iteration/datacache/nonkeyed/ListStateWithCacheTest.java
@@ -48,6 +48,8 @@
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
+import static org.apache.flink.iteration.utils.DataStreamUtils.setManagedMemoryWeight;
+
/** Tests {@link ListStateWithCache}. */
public class ListStateWithCacheTest {
@@ -87,7 +89,7 @@
env.fromSequence(1, n).map(d -> RandomStringUtils.randomAlphabetic(1024 * 1024));
DataStream<Integer> counter =
data.transform("cache", Types.INT, new CacheDataOperator(weights));
- DataStreamUtils.setManagedMemoryWeight(counter, 100);
+ setManagedMemoryWeight(counter, 100);
DataStream<Integer> sum = DataStreamUtils.reduce(counter, Integer::sum);
sum.addSink(
new SinkFunction<Integer>() {
diff --git a/flink-ml-iteration/flink-ml-iteration-1.15/src/main/java/org/apache/flink/iteration/Iterations.java b/flink-ml-iteration/flink-ml-iteration-1.15/src/main/java/org/apache/flink/iteration/Iterations.java
index 8f8fc3e..d3b985c 100644
--- a/flink-ml-iteration/flink-ml-iteration-1.15/src/main/java/org/apache/flink/iteration/Iterations.java
+++ b/flink-ml-iteration/flink-ml-iteration-1.15/src/main/java/org/apache/flink/iteration/Iterations.java
@@ -33,6 +33,7 @@
import org.apache.flink.iteration.operator.allround.AllRoundOperatorWrapper;
import org.apache.flink.iteration.operator.perround.PerRoundOperatorWrapper;
import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
+import org.apache.flink.iteration.utils.DataStreamUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -456,21 +457,25 @@
return new DataStreamList(
map(
inputStreams,
- (index, dataStream) ->
- ((SingleOutputStreamOperator<IterationRecord<?>>) dataStream)
- .transform(
- "head-"
- + variableStreams
- .get(index)
- .getTransformation()
- .getName(),
- (IterationRecordTypeInfo) dataStream.getType(),
- new HeadOperatorFactory(
- iterationId,
- startHeaderIndex + index,
- isCriteriaStream,
- totalInitVariableParallelism))
- .setParallelism(dataStream.getParallelism())));
+ (index, dataStream) -> {
+ DataStream ds =
+ ((SingleOutputStreamOperator<IterationRecord<?>>) dataStream)
+ .transform(
+ "head-"
+ + variableStreams
+ .get(index)
+ .getTransformation()
+ .getName(),
+ (IterationRecordTypeInfo) dataStream.getType(),
+ new HeadOperatorFactory(
+ iterationId,
+ startHeaderIndex + index,
+ isCriteriaStream,
+ totalInitVariableParallelism))
+ .setParallelism(dataStream.getParallelism());
+ DataStreamUtils.setManagedMemoryWeight(ds, 100);
+ return ds;
+ }));
}
private static DataStreamList addTails(
diff --git a/flink-ml-iteration/flink-ml-iteration-1.15/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java b/flink-ml-iteration/flink-ml-iteration-1.15/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
index 3d67b7f..61350a4 100644
--- a/flink-ml-iteration/flink-ml-iteration-1.15/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
+++ b/flink-ml-iteration/flink-ml-iteration-1.15/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
@@ -21,21 +21,28 @@
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.MemorySize;
import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.iteration.IterationID;
+import org.apache.flink.iteration.IterationRecord;
import org.apache.flink.iteration.config.IterationOptions;
+import org.apache.flink.iteration.operator.feedback.SpillableFeedbackChannel;
import org.apache.flink.iteration.proxy.ProxyKeySelector;
import org.apache.flink.iteration.typeinfo.IterationRecordSerializer;
import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
-import org.apache.flink.iteration.utils.ReflectionUtils;
import org.apache.flink.runtime.jobgraph.OperatorID;
-import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer;
import org.apache.flink.statefun.flink.core.feedback.FeedbackKey;
import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.streaming.api.graph.StreamConfig.NetworkInputConfig;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.StreamTask;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.OutputTag;
import org.apache.flink.util.function.SupplierWithException;
@@ -43,7 +50,6 @@
import java.io.IOException;
import java.nio.file.Paths;
-import java.util.Arrays;
import java.util.Random;
import java.util.UUID;
import java.util.concurrent.Executor;
@@ -66,15 +72,38 @@
/** Registers the specified {@code feedbackConsumer} to the {@code feedbackChannel}. */
public static <V> void registerFeedbackConsumer(
- FeedbackChannel<V> feedbackChannel,
+ SpillableFeedbackChannel<V> feedbackChannel,
FeedbackConsumer<V> feedbackConsumer,
Executor executor) {
- ReflectionUtils.callMethod(
- feedbackChannel,
- FeedbackChannel.class,
- "registerConsumer",
- Arrays.asList(FeedbackConsumer.class, Executor.class),
- Arrays.asList(feedbackConsumer, executor));
+ feedbackChannel.registerConsumer(feedbackConsumer, executor);
+ }
+
+ /** Initializes the given {@code feedbackChannel} with resources from {@code operator}. */
+ public static void initializeFeedbackChannel(
+ SpillableFeedbackChannel feedbackChannel, AbstractStreamOperator<?> operator)
+ throws MemoryAllocationException {
+ StreamTask<?, ?> containingTask = operator.getContainingTask();
+ TypeSerializer<StreamRecord<IterationRecord<?>>> serializer =
+ new StreamElementSerializer(
+ operator.getOperatorConfig()
+ .getTypeSerializerIn(0, operator.getUserCodeClassloader()));
+ MemorySize totalManagedMemory =
+ new MemorySize(containingTask.getEnvironment().getMemoryManager().getMemorySize());
+ double fraction =
+ containingTask
+ .getConfiguration()
+ .getManagedMemoryFractionOperatorUseCaseOfSlot(
+ ManagedMemoryUseCase.OPERATOR,
+ operator.getRuntimeContext()
+ .getTaskManagerRuntimeInfo()
+ .getConfiguration(),
+ operator.getRuntimeContext().getUserCodeClassLoader());
+ long feedbackChannelBufferSize = totalManagedMemory.multiply(fraction).getBytes();
+ feedbackChannel.initialize(
+ containingTask.getEnvironment().getIOManager(),
+ containingTask.getEnvironment().getMemoryManager(),
+ serializer,
+ feedbackChannelBufferSize);
}
public static <T> void processOperatorOrUdfIfSatisfy(
diff --git a/flink-ml-iteration/flink-ml-iteration-1.15/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java b/flink-ml-iteration/flink-ml-iteration-1.15/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java
index 4ae8885..4260830 100644
--- a/flink-ml-iteration/flink-ml-iteration-1.15/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java
+++ b/flink-ml-iteration/flink-ml-iteration-1.15/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java
@@ -213,7 +213,7 @@
// the onEpochWatermarkIncrement must be from 0 and increment by 1 each time, except
// for the last round.
try {
- if (epochWatermark < Integer.MAX_VALUE) {
+ if (epochWatermark < IterationRecord.END_EPOCH_WATERMARK) {
S wrappedOperator = wrappedOperators.remove(epochWatermark);
if (wrappedOperator != null) {
closeStreamOperator(wrappedOperator, epochWatermark, epochWatermark);
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/IterationRecord.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/IterationRecord.java
index 4f53fb4..dd68afb 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/IterationRecord.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/IterationRecord.java
@@ -22,6 +22,8 @@
/** The wrapper for the records in iterative stream. */
public class IterationRecord<T> implements Cloneable {
+ public static final int END_EPOCH_WATERMARK = Integer.MAX_VALUE - 1;
+
/** The type of iteration records. */
public enum Type {
RECORD,
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/Iterations.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/Iterations.java
index 8f8fc3e..d3b985c 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/Iterations.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/Iterations.java
@@ -33,6 +33,7 @@
import org.apache.flink.iteration.operator.allround.AllRoundOperatorWrapper;
import org.apache.flink.iteration.operator.perround.PerRoundOperatorWrapper;
import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
+import org.apache.flink.iteration.utils.DataStreamUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -456,21 +457,25 @@
return new DataStreamList(
map(
inputStreams,
- (index, dataStream) ->
- ((SingleOutputStreamOperator<IterationRecord<?>>) dataStream)
- .transform(
- "head-"
- + variableStreams
- .get(index)
- .getTransformation()
- .getName(),
- (IterationRecordTypeInfo) dataStream.getType(),
- new HeadOperatorFactory(
- iterationId,
- startHeaderIndex + index,
- isCriteriaStream,
- totalInitVariableParallelism))
- .setParallelism(dataStream.getParallelism())));
+ (index, dataStream) -> {
+ DataStream ds =
+ ((SingleOutputStreamOperator<IterationRecord<?>>) dataStream)
+ .transform(
+ "head-"
+ + variableStreams
+ .get(index)
+ .getTransformation()
+ .getName(),
+ (IterationRecordTypeInfo) dataStream.getType(),
+ new HeadOperatorFactory(
+ iterationId,
+ startHeaderIndex + index,
+ isCriteriaStream,
+ totalInitVariableParallelism))
+ .setParallelism(dataStream.getParallelism());
+ DataStreamUtils.setManagedMemoryWeight(ds, 100);
+ return ds;
+ }));
}
private static DataStreamList addTails(
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java
index d790729..2c60486 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java
@@ -125,7 +125,7 @@
@SuppressWarnings({"unchecked", "rawtypes"})
protected void notifyEpochWatermarkIncrement(IterationListener<?> listener, int epochWatermark)
throws Exception {
- if (epochWatermark != Integer.MAX_VALUE) {
+ if (epochWatermark != IterationRecord.END_EPOCH_WATERMARK) {
listener.onEpochWatermarkIncremented(
epochWatermark,
iterationContext,
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
index 23dfa04..4517dec 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
@@ -38,6 +38,8 @@
import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
import org.apache.flink.iteration.operator.event.TerminatingOnInitializeEvent;
+import org.apache.flink.iteration.operator.feedback.SpillableFeedbackChannel;
+import org.apache.flink.iteration.operator.feedback.SpillableFeedbackChannelBroker;
import org.apache.flink.iteration.operator.headprocessor.HeadOperatorRecordProcessor;
import org.apache.flink.iteration.operator.headprocessor.HeadOperatorState;
import org.apache.flink.iteration.operator.headprocessor.RegularHeadOperatorRecordProcessor;
@@ -57,14 +59,13 @@
import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
import org.apache.flink.runtime.operators.coordination.OperatorEvent;
import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
import org.apache.flink.runtime.operators.coordination.OperatorEventHandler;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StatePartitionStreamProvider;
import org.apache.flink.runtime.state.StateSnapshotContext;
-import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
-import org.apache.flink.statefun.flink.core.feedback.FeedbackChannelBroker;
import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer;
import org.apache.flink.statefun.flink.core.feedback.FeedbackKey;
import org.apache.flink.statefun.flink.core.feedback.SubtaskFeedbackKey;
@@ -127,6 +128,8 @@
private final int feedbackIndex;
+ private SpillableFeedbackChannel<StreamRecord<IterationRecord<?>>> feedbackChannel;
+
private final boolean isCriteriaStream;
private final OperatorEventGateway operatorEventGateway;
@@ -422,16 +425,19 @@
}
}
- private void registerFeedbackConsumer(Executor mailboxExecutor) {
+ private void registerFeedbackConsumer(Executor mailboxExecutor)
+ throws MemoryAllocationException {
int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
int attemptNum = getRuntimeContext().getAttemptNumber();
FeedbackKey<StreamRecord<IterationRecord<?>>> feedbackKey =
OperatorUtils.createFeedbackKey(iterationId, feedbackIndex);
SubtaskFeedbackKey<StreamRecord<IterationRecord<?>>> key =
feedbackKey.withSubTaskIndex(indexOfThisSubtask, attemptNum);
- FeedbackChannelBroker broker = FeedbackChannelBroker.get();
- FeedbackChannel<StreamRecord<IterationRecord<?>>> channel = broker.getChannel(key);
- OperatorUtils.registerFeedbackConsumer(channel, this, mailboxExecutor);
+ SpillableFeedbackChannelBroker broker = SpillableFeedbackChannelBroker.get();
+ this.feedbackChannel =
+ broker.getChannel(
+ key, channel -> OperatorUtils.initializeFeedbackChannel(channel, this));
+ OperatorUtils.registerFeedbackConsumer(feedbackChannel, this, mailboxExecutor);
}
private List<AbstractEvent> parseInputChannelEvents(InputChannel inputChannel)
@@ -499,6 +505,11 @@
}
@VisibleForTesting
+ public SpillableFeedbackChannel getFeedbackChannel() {
+ return feedbackChannel;
+ }
+
+ @VisibleForTesting
enum HeadOperatorStatus {
RUNNING,
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
index 7b81e52..f35604b 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
@@ -21,21 +21,28 @@
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.MemorySize;
import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.iteration.IterationID;
+import org.apache.flink.iteration.IterationRecord;
import org.apache.flink.iteration.config.IterationOptions;
+import org.apache.flink.iteration.operator.feedback.SpillableFeedbackChannel;
import org.apache.flink.iteration.proxy.ProxyKeySelector;
import org.apache.flink.iteration.typeinfo.IterationRecordSerializer;
import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
-import org.apache.flink.iteration.utils.ReflectionUtils;
import org.apache.flink.runtime.jobgraph.OperatorID;
-import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer;
import org.apache.flink.statefun.flink.core.feedback.FeedbackKey;
import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.streaming.api.graph.StreamConfig.NetworkInputConfig;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.StreamTask;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.OutputTag;
import org.apache.flink.util.function.SupplierWithException;
@@ -43,7 +50,6 @@
import java.io.IOException;
import java.nio.file.Paths;
-import java.util.Arrays;
import java.util.Random;
import java.util.UUID;
import java.util.concurrent.Executor;
@@ -65,15 +71,38 @@
/** Registers the specified {@code feedbackConsumer} to the {@code feedbackChannel}. */
public static <V> void registerFeedbackConsumer(
- FeedbackChannel<V> feedbackChannel,
+ SpillableFeedbackChannel<V> feedbackChannel,
FeedbackConsumer<V> feedbackConsumer,
Executor executor) {
- ReflectionUtils.callMethod(
- feedbackChannel,
- FeedbackChannel.class,
- "registerConsumer",
- Arrays.asList(FeedbackConsumer.class, Executor.class),
- Arrays.asList(feedbackConsumer, executor));
+ feedbackChannel.registerConsumer(feedbackConsumer, executor);
+ }
+
+ /** Initializes the given {@code feedbackChannel} with resources from {@code operator}. */
+ public static void initializeFeedbackChannel(
+ SpillableFeedbackChannel feedbackChannel, AbstractStreamOperator<?> operator)
+ throws MemoryAllocationException {
+ StreamTask<?, ?> containingTask = operator.getContainingTask();
+ TypeSerializer<StreamRecord<IterationRecord<?>>> serializer =
+ new StreamElementSerializer(
+ operator.getOperatorConfig()
+ .getTypeSerializerIn(0, operator.getUserCodeClassloader()));
+ MemorySize totalManagedMemory =
+ new MemorySize(containingTask.getEnvironment().getMemoryManager().getMemorySize());
+ double fraction =
+ containingTask
+ .getConfiguration()
+ .getManagedMemoryFractionOperatorUseCaseOfSlot(
+ ManagedMemoryUseCase.OPERATOR,
+ operator.getRuntimeContext()
+ .getTaskManagerRuntimeInfo()
+ .getConfiguration(),
+ operator.getRuntimeContext().getUserCodeClassLoader());
+ long feedbackChannelBufferSize = totalManagedMemory.multiply(fraction).getBytes();
+ feedbackChannel.initialize(
+ containingTask.getEnvironment().getIOManager(),
+ containingTask.getEnvironment().getMemoryManager(),
+ serializer,
+ feedbackChannelBufferSize);
}
public static <T> void processOperatorOrUdfIfSatisfy(
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/OutputOperator.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/OutputOperator.java
index d0e6971..f2cc8ea 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/OutputOperator.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/OutputOperator.java
@@ -51,7 +51,7 @@
reusable.replace(streamRecord.getValue().getValue(), streamRecord.getTimestamp());
output.collect(reusable);
} else if (streamRecord.getValue().getType() == Type.EPOCH_WATERMARK
- && streamRecord.getValue().getEpoch() == Integer.MAX_VALUE) {
+ && streamRecord.getValue().getEpoch() == IterationRecord.END_EPOCH_WATERMARK) {
output.emitWatermark(new Watermark(Long.MAX_VALUE));
}
}
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/ReplayOperator.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/ReplayOperator.java
index 133a3d5..7c45a81 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/ReplayOperator.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/ReplayOperator.java
@@ -266,7 +266,7 @@
dataCacheWriter.finish();
emitEpochWatermark(epochWatermark);
return;
- } else if (epochWatermark == Integer.MAX_VALUE) {
+ } else if (epochWatermark == IterationRecord.END_EPOCH_WATERMARK) {
emitEpochWatermark(epochWatermark);
return;
}
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/TailOperator.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/TailOperator.java
index 2e26142..7702414 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/TailOperator.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/TailOperator.java
@@ -22,8 +22,8 @@
import org.apache.flink.iteration.IterationRecord;
import org.apache.flink.iteration.checkpoint.Checkpoints;
import org.apache.flink.iteration.checkpoint.CheckpointsBroker;
-import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
-import org.apache.flink.statefun.flink.core.feedback.FeedbackChannelBroker;
+import org.apache.flink.iteration.operator.feedback.SpillableFeedbackChannel;
+import org.apache.flink.iteration.operator.feedback.SpillableFeedbackChannelBroker;
import org.apache.flink.statefun.flink.core.feedback.FeedbackKey;
import org.apache.flink.statefun.flink.core.feedback.SubtaskFeedbackKey;
import org.apache.flink.streaming.api.graph.StreamConfig;
@@ -51,7 +51,7 @@
/** We distinguish how the record is processed according to if objectReuse is enabled. */
private transient Consumer<StreamRecord<IterationRecord<?>>> recordConsumer;
- private transient FeedbackChannel<StreamRecord<IterationRecord<?>>> channel;
+ private transient SpillableFeedbackChannel<StreamRecord<IterationRecord<?>>> channel;
public TailOperator(IterationID iterationId, int feedbackIndex) {
this.iterationId = Objects.requireNonNull(iterationId);
@@ -78,8 +78,8 @@
SubtaskFeedbackKey<StreamRecord<IterationRecord<?>>> key =
feedbackKey.withSubTaskIndex(indexOfThisSubtask, attemptNum);
- FeedbackChannelBroker broker = FeedbackChannelBroker.get();
- this.channel = broker.getChannel(key);
+ SpillableFeedbackChannelBroker broker = SpillableFeedbackChannelBroker.get();
+ this.channel = broker.getChannel(key, null);
this.recordConsumer =
getExecutionConfig().isObjectReuseEnabled()
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
index bd503e9..511c07f 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
@@ -181,14 +181,14 @@
@Override
public void finish() throws Exception {
- setIterationContextRound(Integer.MAX_VALUE);
+ setIterationContextRound(IterationRecord.END_EPOCH_WATERMARK);
wrappedOperator.finish();
clearIterationContextRound();
}
@Override
public void close() throws Exception {
- setIterationContextRound(Integer.MAX_VALUE);
+ setIterationContextRound(IterationRecord.END_EPOCH_WATERMARK);
wrappedOperator.close();
clearIterationContextRound();
}
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java
index 949419f..509d8a7 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java
@@ -83,7 +83,7 @@
super.endInput(i);
if (wrappedOperator instanceof BoundedMultiInput) {
- setIterationContextRound(Integer.MAX_VALUE);
+ setIterationContextRound(IterationRecord.END_EPOCH_WATERMARK);
((BoundedMultiInput) wrappedOperator).endInput(i);
clearIterationContextRound();
}
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java
index bbfd0b3..e65b1e2 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java
@@ -87,7 +87,7 @@
@Override
public void endInput() throws Exception {
if (wrappedOperator instanceof BoundedOneInput) {
- setIterationContextRound(Integer.MAX_VALUE);
+ setIterationContextRound(IterationRecord.END_EPOCH_WATERMARK);
((BoundedOneInput) wrappedOperator).endInput();
clearIterationContextRound();
}
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java
index a91f6c4..5d4f9b4 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java
@@ -116,7 +116,7 @@
super.endInput(i);
if (wrappedOperator instanceof BoundedMultiInput) {
- setIterationContextRound(Integer.MAX_VALUE);
+ setIterationContextRound(IterationRecord.END_EPOCH_WATERMARK);
((BoundedMultiInput) wrappedOperator).endInput(i);
clearIterationContextRound();
}
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/feedback/MpscQueue.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/feedback/MpscQueue.java
new file mode 100644
index 0000000..b9d9e1b
--- /dev/null
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/feedback/MpscQueue.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator.feedback;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.io.disk.iomanager.IOManager;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.util.EmptyMutableObjectIterator;
+import org.apache.flink.statefun.flink.core.queue.Lock;
+import org.apache.flink.statefun.flink.core.queue.Locks;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.MutableObjectIterator;
+import org.apache.flink.util.Preconditions;
+
+import java.io.Closeable;
+import java.io.IOException;
+
+/**
+ * Multi-producer, single-consumer FIFO queue.
+ *
+ * @param <T> The element type.
+ */
+@Internal
+final class MpscQueue<T> implements Closeable {
+    private final Lock lock = Locks.spinLock();
+
+    /** The queue that {@link #add} currently appends to. */
+    private SpillableFeedbackQueue<T> activeQueue;
+    /** The spare queue; {@link #drainAll()} swaps it with the active one. */
+    private SpillableFeedbackQueue<T> standByQueue;
+
+    MpscQueue(
+            IOManager ioManager,
+            MemoryManager memoryManager,
+            TypeSerializer<T> serializer,
+            long inMemoryBufferSize)
+            throws MemoryAllocationException {
+        this.activeQueue =
+                new SpillableFeedbackQueue<>(
+                        ioManager, memoryManager, serializer, inMemoryBufferSize / 2);
+        this.standByQueue =
+                new SpillableFeedbackQueue<>(
+                        ioManager, memoryManager, serializer, inMemoryBufferSize / 2);
+    }
+
+    /**
+     * Adds an element to this (unbounded) queue.
+     *
+     * @param element the element to add.
+     * @return the number of elements in the queue after the addition.
+     */
+    long add(T element) {
+        Preconditions.checkState(element instanceof StreamRecord);
+        lock.lockUninterruptibly();
+        try {
+            SpillableFeedbackQueue<T> active = this.activeQueue;
+            active.add(element);
+            return active.size();
+        } finally {
+            lock.unlock();
+        }
+    }
+
+    /**
+     * Atomically drains the queue by swapping the active and the stand-by queue.
+     *
+     * @return an iterator over the drained elements, or an empty iterator if the queue was empty.
+     */
+    MutableObjectIterator<T> drainAll() {
+        lock.lockUninterruptibly();
+        try {
+            final SpillableFeedbackQueue<T> ready = this.activeQueue;
+            if (ready.size() == 0) {
+                return EmptyMutableObjectIterator.get();
+            }
+            this.activeQueue = this.standByQueue;
+            this.standByQueue = ready;
+            return ready.iterate();
+        } finally {
+            lock.unlock();
+        }
+    }
+
+    /** Resets the element count of the stand-by queue to zero. */
+    void resetStandBy() throws Exception {
+        lock.lockUninterruptibly();
+        try {
+            standByQueue.reset();
+        } finally {
+            lock.unlock();
+        }
+    }
+
+    /** Releases both underlying queues while holding the lock. */
+    @Override
+    public void close() throws IOException {
+        lock.lockUninterruptibly();
+        try {
+            activeQueue.release();
+            standByQueue.release();
+        } finally {
+            lock.unlock();
+        }
+    }
+}
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/feedback/SpillableFeedbackChannel.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/feedback/SpillableFeedbackChannel.java
new file mode 100644
index 0000000..94df726
--- /dev/null
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/feedback/SpillableFeedbackChannel.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator.feedback;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.io.disk.iomanager.IOManager;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer;
+import org.apache.flink.statefun.flink.core.feedback.SubtaskFeedbackKey;
+import org.apache.flink.util.IOUtils;
+import org.apache.flink.util.MutableObjectIterator;
+import org.apache.flink.util.Preconditions;
+
+import java.io.Closeable;
+import java.util.Objects;
+import java.util.concurrent.Executor;
+import java.util.concurrent.atomic.AtomicReference;
+
+/**
+ * Single producer, single consumer channel, which can spill the records to disk when the in-memory
+ * buffer is full.
+ */
+@Internal
+public final class SpillableFeedbackChannel<T> implements Closeable {
+    /** The key used to identify this channel. */
+    private final SubtaskFeedbackKey<T> key;
+
+    /** A single registered consumer. */
+    private final AtomicReference<ConsumerTask<T>> consumerRef = new AtomicReference<>();
+
+    /** The underlying queue used to hold the feedback results; null until initialized. */
+    private MpscQueue<T> queue;
+
+    SpillableFeedbackChannel(SubtaskFeedbackKey<T> key) {
+        this.key = Objects.requireNonNull(key);
+    }
+
+    /** Creates the underlying queue; must be called before {@link #put} is used. */
+    public void initialize(
+            IOManager ioManager,
+            MemoryManager memoryManager,
+            TypeSerializer<T> serializer,
+            long inMemoryBufferSize)
+            throws MemoryAllocationException {
+        this.queue = new MpscQueue<>(ioManager, memoryManager, serializer, inMemoryBufferSize);
+    }
+
+    /** Adds a feedback result to this channel. */
+    public void put(T element) {
+        Preconditions.checkState(
+                queue != null,
+                "The SpillableFeedbackChannel has not been initialized, "
+                        + "please call SpillableFeedbackChannel#initialize first");
+        // A size of 1 means the queue was empty before this add, so a drain must be scheduled.
+        if (queue.add(element) == 1) {
+            final ConsumerTask<T> consumer = consumerRef.get();
+            if (consumer != null) {
+                consumer.scheduleDrainAll();
+            }
+        }
+    }
+
+    /**
+     * Register a feedback iteration consumer.
+     *
+     * @param consumer the feedback events consumer.
+     * @param executor the executor to schedule feedback consumption on.
+     */
+    public void registerConsumer(final FeedbackConsumer<T> consumer, Executor executor) {
+        ConsumerTask<T> consumerTask = new ConsumerTask<>(executor, consumer, queue);
+        if (!this.consumerRef.compareAndSet(null, consumerTask)) {
+            throw new IllegalStateException(
+                    "There can be only a single consumer in a FeedbackChannel.");
+        }
+        consumerTask.scheduleDrainAll();
+    }
+
+    /** Unregisters the consumer, removes this channel from the broker and closes the queue. */
+    @Override
+    public void close() {
+        consumerRef.set(null);
+        SpillableFeedbackChannelBroker.get().removeChannel(key);
+        IOUtils.closeQuietly(queue);
+    }
+
+    public boolean isInitialized() {
+        return this.queue != null;
+    }
+
+    /** Drains the queue on the executor and hands every element to the consumer. */
+    private static final class ConsumerTask<T> implements Runnable {
+        private final Executor executor;
+        private final FeedbackConsumer<T> consumer;
+        private final MpscQueue<T> queue;
+
+        ConsumerTask(Executor executor, FeedbackConsumer<T> consumer, MpscQueue<T> queue) {
+            this.executor = Objects.requireNonNull(executor);
+            this.consumer = Objects.requireNonNull(consumer);
+            this.queue = Objects.requireNonNull(queue);
+        }
+
+        void scheduleDrainAll() {
+            executor.execute(this);
+        }
+
+        @Override
+        public void run() {
+            final MutableObjectIterator<T> buffer = queue.drainAll();
+            try {
+                T element;
+                while ((element = buffer.next()) != null) {
+                    consumer.processFeedback(element);
+                }
+                queue.resetStandBy();
+            } catch (Exception e) {
+                throw new RuntimeException(e);
+            }
+        }
+    }
+}
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/feedback/SpillableFeedbackChannelBroker.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/feedback/SpillableFeedbackChannelBroker.java
new file mode 100644
index 0000000..705eea2
--- /dev/null
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/feedback/SpillableFeedbackChannelBroker.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator.feedback;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.statefun.flink.core.feedback.SubtaskFeedbackKey;
+import org.apache.flink.util.function.ThrowingConsumer;
+
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * Broker of {@link SpillableFeedbackChannel}s.
+ *
+ * <p>It is used together with the co-location constraint so that two tasks can access the same
+ * "hand-off" channel, and communicate directly (not via the network stack) by simply passing
+ * references in one direction.
+ *
+ * <p>To obtain a feedback channel one must first obtain an {@link SubtaskFeedbackKey} and simply
+ * call one of the {@code getChannel} methods. A channel is removed from this broker on a call to
+ * {@link SpillableFeedbackChannel#close()}.
+ */
+@Internal
+public final class SpillableFeedbackChannelBroker {
+    private static final SpillableFeedbackChannelBroker INSTANCE =
+            new SpillableFeedbackChannelBroker();
+
+    private final ConcurrentHashMap<SubtaskFeedbackKey<?>, SpillableFeedbackChannel<?>> channels =
+            new ConcurrentHashMap<>();
+
+    public static SpillableFeedbackChannelBroker get() {
+        return INSTANCE;
+    }
+
+    /** Returns the channel registered for the key, creating an uninitialized one if absent. */
+    @SuppressWarnings({"unchecked"})
+    public <V> SpillableFeedbackChannel<V> getChannel(SubtaskFeedbackKey<V> key) {
+        Objects.requireNonNull(key);
+        SpillableFeedbackChannel<?> channel =
+                channels.computeIfAbsent(key, SpillableFeedbackChannelBroker::newChannel);
+
+        return (SpillableFeedbackChannel<V>) channel;
+    }
+
+    /** Same as {@link #getChannel(SubtaskFeedbackKey)}, but initializes the channel if needed. */
+    @SuppressWarnings({"unchecked"})
+    public <V> SpillableFeedbackChannel<V> getChannel(
+            SubtaskFeedbackKey<V> key,
+            ThrowingConsumer<SpillableFeedbackChannel, MemoryAllocationException> initializer)
+            throws MemoryAllocationException {
+        Objects.requireNonNull(key);
+        SpillableFeedbackChannel<?> channel =
+                channels.computeIfAbsent(key, SpillableFeedbackChannelBroker::newChannel);
+        // Two tasks may request the same channel concurrently; initialize it at most once.
+        synchronized (channel) {
+            if (!channel.isInitialized() && initializer != null) {
+                initializer.accept(channel);
+            }
+        }
+        return (SpillableFeedbackChannel<V>) channel;
+    }
+
+    @SuppressWarnings("resource")
+    void removeChannel(SubtaskFeedbackKey<?> key) {
+        channels.remove(key);
+    }
+
+    private static <V> SpillableFeedbackChannel<V> newChannel(SubtaskFeedbackKey<V> key) {
+        return new SpillableFeedbackChannel<>(key);
+    }
+}
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/feedback/SpillableFeedbackQueue.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/feedback/SpillableFeedbackQueue.java
new file mode 100644
index 0000000..82f3528
--- /dev/null
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/feedback/SpillableFeedbackQueue.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator.feedback;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputSerializer;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.io.disk.InputViewIterator;
+import org.apache.flink.runtime.io.disk.iomanager.IOManager;
+import org.apache.flink.runtime.iterative.io.SerializedUpdateBuffer;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.util.MutableObjectIterator;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * A queue that can spill the items to disks automatically when the memory buffer is full.
+ *
+ * @param <T> The element type.
+ */
+@Internal
+final class SpillableFeedbackQueue<T> {
+    private final DataOutputSerializer output = new DataOutputSerializer(256);
+    private final TypeSerializer<T> serializer;
+    private final MemoryManager memoryManager;
+    private final List<MemorySegment> freeMemory;
+    private final SerializedUpdateBuffer buffer;
+    private long size = 0L;
+
+    SpillableFeedbackQueue(
+            IOManager ioManager,
+            MemoryManager memoryManager,
+            TypeSerializer<T> serializer,
+            long inMemoryBufferSize)
+            throws MemoryAllocationException {
+        this.serializer = Objects.requireNonNull(serializer);
+        this.memoryManager = Objects.requireNonNull(memoryManager);
+        final int pageSize = memoryManager.getPageSize();
+        final int pageCount = (int) (inMemoryBufferSize / pageSize);
+        this.freeMemory = memoryManager.allocatePages(this, pageCount);
+        this.buffer = new SerializedUpdateBuffer(freeMemory, pageSize, ioManager);
+    }
+
+    void add(T item) {
+        try {
+            output.clear();
+            serializer.serialize(item, output);
+            buffer.write(output.getSharedBuffer(), 0, output.length());
+            size++;
+        } catch (IOException e) {
+            throw new IllegalStateException(e);
+        }
+    }
+
+    MutableObjectIterator<T> iterate() {
+        try {
+            return new InputViewIterator<>(buffer.switchBuffers(), serializer);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    long size() {
+        return size;
+    }
+
+    /** Resets the element count to zero; the underlying buffer is not touched here. */
+    public void reset() {
+        this.size = 0;
+    }
+
+    /** Closes the buffer and hands all memory segments back to the memory manager. */
+    void release() {
+        output.clear();
+        List<MemorySegment> segments = buffer.close();
+        segments.addAll(freeMemory);
+        freeMemory.clear();
+        memoryManager.release(segments);
+    }
+}
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/headprocessor/RegularHeadOperatorRecordProcessor.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/headprocessor/RegularHeadOperatorRecordProcessor.java
index 107a233..0648160 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/headprocessor/RegularHeadOperatorRecordProcessor.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/headprocessor/RegularHeadOperatorRecordProcessor.java
@@ -116,7 +116,7 @@
new StreamRecord<>(
IterationRecord.newEpochWatermark(
globallyAlignedEvent.isTerminated()
- ? Integer.MAX_VALUE
+ ? IterationRecord.END_EPOCH_WATERMARK
: globallyAlignedEvent.getEpoch(),
senderId),
0);
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/headprocessor/TerminatingHeadOperatorRecordProcessor.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/headprocessor/TerminatingHeadOperatorRecordProcessor.java
index 0cd93e7..bbbd7eb 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/headprocessor/TerminatingHeadOperatorRecordProcessor.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/headprocessor/TerminatingHeadOperatorRecordProcessor.java
@@ -26,7 +26,7 @@
/**
* Processor used after we received terminated globally aligned event from the coordinator, but
- * before we received the (Integer.MAX_VALUE + 1) from the feedback channel again.
+ * before we received the epoch watermark of Integer.MAX_VALUE from the feedback channel again.
*/
public class TerminatingHeadOperatorRecordProcessor implements HeadOperatorRecordProcessor {
@@ -51,7 +51,7 @@
@Override
public boolean processFeedbackElement(StreamRecord<IterationRecord<?>> record) {
if (record.getValue().getType() == IterationRecord.Type.EPOCH_WATERMARK) {
- return record.getValue().getEpoch() == Integer.MAX_VALUE + 1;
+ return record.getValue().getEpoch() == Integer.MAX_VALUE;
}
return false;
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java
index d672876..60fda47 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java
@@ -213,7 +213,7 @@
// the onEpochWatermarkIncrement must be from 0 and increment by 1 each time, except
// for the last round.
try {
- if (epochWatermark < Integer.MAX_VALUE) {
+ if (epochWatermark < IterationRecord.END_EPOCH_WATERMARK) {
S wrappedOperator = wrappedOperators.remove(epochWatermark);
if (wrappedOperator != null) {
closeStreamOperator(wrappedOperator, epochWatermark, epochWatermark);
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/progresstrack/OperatorEpochWatermarkTracker.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/progresstrack/OperatorEpochWatermarkTracker.java
index 33ddee8..766ef4b 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/progresstrack/OperatorEpochWatermarkTracker.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/progresstrack/OperatorEpochWatermarkTracker.java
@@ -19,6 +19,7 @@
package org.apache.flink.iteration.progresstrack;
import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.iteration.IterationRecord;
import java.io.IOException;
import java.util.ArrayList;
@@ -109,7 +110,7 @@
public void finish() {
for (int i = 0; i < numberOfChannels; ++i) {
- allChannelsLowerBound.updateValue(i, Integer.MAX_VALUE);
+ allChannelsLowerBound.updateValue(i, IterationRecord.END_EPOCH_WATERMARK);
}
}
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/utils/DataStreamUtils.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/utils/DataStreamUtils.java
new file mode 100644
index 0000000..05c4699
--- /dev/null
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/utils/DataStreamUtils.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.utils;
+
+import org.apache.flink.core.memory.ManagedMemoryUseCase;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.TableException;
+
+import java.util.Optional;
+
+/** Provides utility functions for {@link DataStream}. */
+public class DataStreamUtils {
+
+    /**
+     * Declares {@link ManagedMemoryUseCase#OPERATOR} managed memory at operator scope on the
+     * transformation of the given stream; does nothing when {@code memoryBytes} is not positive.
+     *
+     * <p>This method is in reference to Flink's ExecNodeUtil.setManagedMemoryWeight. The provided
+     * bytes should be in the same scale as existing usage in Flink, for example,
+     * StreamExecWindowAggregate.WINDOW_AGG_MEMORY_RATIO.
+     */
+    public static <T> void setManagedMemoryWeight(DataStream<T> dataStream, long memoryBytes) {
+        if (memoryBytes <= 0) {
+            return;
+        }
+        final int mebibytes = (int) (memoryBytes >> 20);
+        final Optional<Integer> oldWeight =
+                dataStream
+                        .getTransformation()
+                        .declareManagedMemoryUseCaseAtOperatorScope(
+                                ManagedMemoryUseCase.OPERATOR, Math.max(1, mebibytes));
+        if (oldWeight.isPresent()) {
+            throw new TableException(
+                    "Managed memory weight has been set, this should not happen.");
+        }
+    }
+}
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
index fc9bd6b..f1a320c 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
@@ -19,12 +19,15 @@
package org.apache.flink.iteration.operator;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.iteration.IterationID;
import org.apache.flink.iteration.IterationRecord;
import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
import org.apache.flink.iteration.operator.event.TerminatingOnInitializeEvent;
+import org.apache.flink.iteration.operator.feedback.SpillableFeedbackChannel;
+import org.apache.flink.iteration.operator.feedback.SpillableFeedbackChannelBroker;
import org.apache.flink.iteration.operator.headprocessor.RegularHeadOperatorRecordProcessor;
import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
@@ -35,11 +38,10 @@
import org.apache.flink.runtime.io.network.api.EndOfData;
import org.apache.flink.runtime.io.network.api.StopMode;
import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
import org.apache.flink.runtime.operators.coordination.OperatorEvent;
import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
-import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
-import org.apache.flink.statefun.flink.core.feedback.FeedbackChannelBroker;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
@@ -171,7 +173,7 @@
putFeedbackRecords(
iterationId,
IterationRecord.newEpochWatermark(
- Integer.MAX_VALUE + 1, "tail"),
+ Integer.MAX_VALUE, "tail"),
null);
return null;
@@ -204,7 +206,7 @@
new StreamRecord<>(IterationRecord.newRecord(4, 1), 4),
new StreamRecord<>(
IterationRecord.newEpochWatermark(
- Integer.MAX_VALUE,
+ IterationRecord.END_EPOCH_WATERMARK,
OperatorUtils.getUniqueSenderId(operatorId, 0)),
0)),
new ArrayList<>(harness.getOutput()));
@@ -353,7 +355,6 @@
.getLastJobManagerTaskStateSnapshot();
});
assertNotNull(taskStateSnapshot);
- cleanupFeedbackChannel(iterationId);
createHarnessAndRun(
iterationId,
operatorId,
@@ -367,6 +368,7 @@
Collections.emptyMap(),
-1,
-1);
+ harness.processAll();
return null;
});
}
@@ -409,7 +411,6 @@
.getLastJobManagerTaskStateSnapshot();
});
assertNotNull(taskStateSnapshot);
- cleanupFeedbackChannel(iterationId);
createHarnessAndRun(
iterationId,
operatorId,
@@ -484,7 +485,6 @@
.getLastJobManagerTaskStateSnapshot();
});
assertNotNull(taskStateSnapshot);
- cleanupFeedbackChannel(iterationId);
createHarnessAndRun(
iterationId,
operatorId,
@@ -498,6 +498,7 @@
Collections.singletonMap(1, 1L),
1,
0);
+ harness.processAll();
return null;
});
}
@@ -554,7 +555,6 @@
.getLastJobManagerTaskStateSnapshot();
});
assertNotNull(taskStateSnapshot);
- cleanupFeedbackChannel(iterationId);
createHarnessAndRun(
iterationId,
operatorId,
@@ -577,6 +577,7 @@
},
5,
4);
+ harness.processAll();
return null;
});
}
@@ -630,7 +631,6 @@
});
assertNotNull(taskStateSnapshot);
- cleanupFeedbackChannel(iterationId);
createHarnessAndRun(
iterationId,
operatorId,
@@ -644,6 +644,7 @@
Collections.emptyMap(),
5,
4);
+ harness.processAll();
return null;
});
}
@@ -685,7 +686,6 @@
.getLastJobManagerTaskStateSnapshot();
});
assertNotNull(taskStateSnapshot);
- cleanupFeedbackChannel(iterationId);
createHarnessAndRun(
iterationId,
operatorId,
@@ -702,7 +702,7 @@
putFeedbackRecords(
iterationId,
- IterationRecord.newEpochWatermark(Integer.MAX_VALUE + 1, "tail"),
+ IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "tail"),
null);
harness.processEvent(new EndOfData(StopMode.DRAIN));
harness.finishProcessing();
@@ -813,6 +813,12 @@
OneInputStreamTask::new,
new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
.addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+ .modifyStreamConfig(
+ // Set the fraction to 0.5 to make sure there are enough pages
+ // for each feedback queue.
+ streamConfig ->
+ streamConfig.setManagedMemoryFractionOperatorOfUseCase(
+ ManagedMemoryUseCase.OPERATOR, 0.5))
.setTaskStateSnapshot(
1, snapshot == null ? new TaskStateSnapshot() : snapshot)
.setupOutputForSingletonOperatorChain(
@@ -828,6 +834,7 @@
return runnable.apply(harness);
} finally {
RecordingHeadOperatorFactory.latestHeadOperator.close();
+ RecordingHeadOperatorFactory.latestHeadOperator.getFeedbackChannel().close();
}
}
}
@@ -850,13 +857,15 @@
}
private static void putFeedbackRecords(
- IterationID iterationId, IterationRecord<?> record, @Nullable Long timestamp) {
- FeedbackChannel<StreamRecord<IterationRecord<?>>> feedbackChannel =
- FeedbackChannelBroker.get()
+ IterationID iterationId, IterationRecord<?> record, @Nullable Long timestamp)
+ throws MemoryAllocationException {
+ SpillableFeedbackChannel<StreamRecord<IterationRecord<?>>> feedbackChannel =
+ SpillableFeedbackChannelBroker.get()
.getChannel(
OperatorUtils.<StreamRecord<IterationRecord<?>>>createFeedbackKey(
iterationId, 0)
- .withSubTaskIndex(0, 0));
+ .withSubTaskIndex(0, 0),
+ null);
feedbackChannel.put(
timestamp == null
? new StreamRecord<>(record)
@@ -889,20 +898,6 @@
}
}
- /**
- * We have to manually cleanup the feedback channel due to not be able to set the attempt
- * number.
- */
- private static void cleanupFeedbackChannel(IterationID iterationId) {
- FeedbackChannel<StreamRecord<IterationRecord<?>>> feedbackChannel =
- FeedbackChannelBroker.get()
- .getChannel(
- OperatorUtils.<StreamRecord<IterationRecord<?>>>createFeedbackKey(
- iterationId, 0)
- .withSubTaskIndex(0, 0));
- feedbackChannel.close();
- }
-
private static class RecordingOperatorEventGateway implements OperatorEventGateway {
final BlockingQueue<OperatorEvent> operatorEvents = new LinkedBlockingQueue<>();
@@ -937,7 +932,6 @@
@Override
public <T extends StreamOperator<IterationRecord<?>>> T createStreamOperator(
StreamOperatorParameters<IterationRecord<?>> streamOperatorParameters) {
-
latestHeadOperator = super.createStreamOperator(streamOperatorParameters);
return (T) latestHeadOperator;
}
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/TailOperatorTest.java b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/TailOperatorTest.java
index 45ddab2..ae8a2fd 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/TailOperatorTest.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/TailOperatorTest.java
@@ -18,17 +18,24 @@
package org.apache.flink.iteration.operator;
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.iteration.IterationID;
import org.apache.flink.iteration.IterationRecord;
+import org.apache.flink.iteration.operator.feedback.SpillableFeedbackChannel;
+import org.apache.flink.iteration.operator.feedback.SpillableFeedbackChannelBroker;
+import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
import org.apache.flink.runtime.testutils.DirectScheduledExecutorService;
-import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
-import org.apache.flink.statefun.flink.core.feedback.FeedbackChannelBroker;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
import org.apache.flink.util.TestLogger;
import org.junit.Test;
+import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@@ -42,16 +49,21 @@
public void testIncrementRoundWithoutObjectReuse() throws Exception {
IterationID iterationId = new IterationID();
+ IterationRecordTypeInfo typeInfo =
+ new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO);
+ TypeSerializer serializer = typeInfo.createSerializer(new ExecutionConfig());
OneInputStreamOperatorTestHarness<IterationRecord<?>, Void> testHarness =
- new OneInputStreamOperatorTestHarness<>(new TailOperator(iterationId, 0));
+ new OneInputStreamOperatorTestHarness<>(
+ new TailOperator(iterationId, 0), serializer);
testHarness.open();
+ SpillableFeedbackChannel channel =
+ initializeFeedbackChannel(testHarness.getOperator(), iterationId, 0, 0, 0);
testHarness.processElement(IterationRecord.newRecord(1, 1), 2);
testHarness.processElement(IterationRecord.newRecord(2, 1), 3);
testHarness.processElement(IterationRecord.newEpochWatermark(2, "sender1"), 4);
- List<StreamRecord<IterationRecord<?>>> iterationRecords =
- getFeedbackRecords(iterationId, 0, 0, 0);
+ List<StreamRecord<IterationRecord<?>>> iterationRecords = getFeedbackRecords(channel);
assertEquals(
Arrays.asList(
new StreamRecord<>(IterationRecord.newRecord(1, 2), 2),
@@ -64,10 +76,17 @@
public void testIncrementRoundWithObjectReuse() throws Exception {
IterationID iterationId = new IterationID();
+ IterationRecordTypeInfo typeInfo =
+ new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO);
+ TypeSerializer serializer = typeInfo.createSerializer(new ExecutionConfig());
+
OneInputStreamOperatorTestHarness<IterationRecord<?>, Void> testHarness =
- new OneInputStreamOperatorTestHarness<>(new TailOperator(iterationId, 0));
+ new OneInputStreamOperatorTestHarness<>(
+ new TailOperator(iterationId, 0), serializer);
testHarness.getExecutionConfig().enableObjectReuse();
testHarness.open();
+ SpillableFeedbackChannel channel =
+ initializeFeedbackChannel(testHarness.getOperator(), iterationId, 0, 0, 0);
IterationRecord<Integer> reuse = IterationRecord.newRecord(1, 1);
testHarness.processElement(reuse, 2);
@@ -80,8 +99,7 @@
reuse.setSender("sender1");
testHarness.processElement(reuse, 4);
- List<StreamRecord<IterationRecord<?>>> iterationRecords =
- getFeedbackRecords(iterationId, 0, 0, 0);
+ List<StreamRecord<IterationRecord<?>>> iterationRecords = getFeedbackRecords(channel);
assertEquals(
Arrays.asList(
new StreamRecord<>(IterationRecord.newRecord(1, 2), 2),
@@ -90,17 +108,60 @@
iterationRecords);
}
+ @Test
+ public void testSpillFeedbackToDisk() throws Exception { // counts files in the spill dir
+ IterationID iterationId = new IterationID();
+
+ IterationRecordTypeInfo typeInfo =
+ new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO);
+ TypeSerializer serializer = typeInfo.createSerializer(new ExecutionConfig()); // raw type
+ OneInputStreamOperatorTestHarness<IterationRecord<?>, Void> testHarness =
+ new OneInputStreamOperatorTestHarness<>(
+ new TailOperator(iterationId, 0), serializer);
+ testHarness.open();
+ initializeFeedbackChannel(testHarness.getOperator(), iterationId, 0, 0, 0); // result unused
+
+ File spillPath =
+ new File(
+ testHarness
+ .getOperator()
+ .getContainingTask()
+ .getEnvironment()
+ .getIOManager()
+ .getSpillingDirectoriesPaths()[0]); // only the first directory is inspected
+ assertEquals(0, spillPath.listFiles().length); // NOTE(review): listFiles() may return null
+
+ for (int i = 0; i < 10; i++) { // small batch of 10 records
+ testHarness.processElement(IterationRecord.newRecord(i, 1), i);
+ }
+ assertEquals(0, spillPath.listFiles().length); // still no file after 10 records
+
+ for (int i = 0; i < (1 << 16); i++) { // 65536 further records
+ testHarness.processElement(IterationRecord.newRecord(i, 1), i);
+ }
+ assertEquals(1, spillPath.listFiles().length); // presumably the spill file; confirm
+ }
+
static List<StreamRecord<IterationRecord<?>>> getFeedbackRecords(
- IterationID iterationId, int feedbackIndex, int subtaskIndex, int attemptNumber) {
- FeedbackChannel<StreamRecord<IterationRecord<?>>> feedbackChannel =
- FeedbackChannelBroker.get()
- .getChannel(
- OperatorUtils.<StreamRecord<IterationRecord<?>>>createFeedbackKey(
- iterationId, feedbackIndex)
- .withSubTaskIndex(subtaskIndex, attemptNumber));
+ SpillableFeedbackChannel<StreamRecord<IterationRecord<?>>> feedbackChannel) { // consumed records are appended to the returned list
List<StreamRecord<IterationRecord<?>>> iterationRecords = new ArrayList<>();
OperatorUtils.registerFeedbackConsumer(
feedbackChannel, iterationRecords::add, new DirectScheduledExecutorService());
return iterationRecords;
}
+
+ static SpillableFeedbackChannel<StreamRecord<IterationRecord<?>>> initializeFeedbackChannel( // test helper: gets the channel from the broker
+ AbstractStreamOperator operator, // raw type; forwarded to OperatorUtils below
+ IterationID iterationId,
+ int feedbackIndex,
+ int subtaskIndex,
+ int attemptNumber)
+ throws MemoryAllocationException {
+ return SpillableFeedbackChannelBroker.get()
+ .getChannel(
+ OperatorUtils.<StreamRecord<IterationRecord<?>>>createFeedbackKey(
+ iterationId, feedbackIndex)
+ .withSubTaskIndex(subtaskIndex, attemptNumber), // key narrowed to subtask/attempt
+ channel -> OperatorUtils.initializeFeedbackChannel(channel, operator)); // presumably run by the broker on channel creation; confirm
+ }
}
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java
index e07b5ae..76b4baa 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java
@@ -91,15 +91,18 @@
new StreamRecord<>(IterationRecord.newEpochWatermark(5, "only-one-2")), 2);
harness.processElement(
new StreamRecord<>(
- IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one-0")),
+ IterationRecord.newEpochWatermark(
+ IterationRecord.END_EPOCH_WATERMARK, "only-one-0")),
0);
harness.processElement(
new StreamRecord<>(
- IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one-1")),
+ IterationRecord.newEpochWatermark(
+ IterationRecord.END_EPOCH_WATERMARK, "only-one-1")),
1);
harness.processElement(
new StreamRecord<>(
- IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one-2")),
+ IterationRecord.newEpochWatermark(
+ IterationRecord.END_EPOCH_WATERMARK, "only-one-2")),
2);
// Checks the output
@@ -113,7 +116,7 @@
5, OperatorUtils.getUniqueSenderId(operatorId, 0))),
new StreamRecord<>(
IterationRecord.newEpochWatermark(
- Integer.MAX_VALUE,
+ IterationRecord.END_EPOCH_WATERMARK,
OperatorUtils.getUniqueSenderId(operatorId, 0)))),
new ArrayList<>(harness.getOutput()));
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
index 7b9d402..81eb52b 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
@@ -88,7 +88,8 @@
new StreamRecord<>(IterationRecord.newEpochWatermark(5, "only-one")));
harness.processElement(
new StreamRecord<>(
- IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one")));
+ IterationRecord.newEpochWatermark(
+ IterationRecord.END_EPOCH_WATERMARK, "only-one")));
// Checks the output
assertEquals(
@@ -100,7 +101,7 @@
5, OperatorUtils.getUniqueSenderId(operatorId, 0))),
new StreamRecord<>(
IterationRecord.newEpochWatermark(
- Integer.MAX_VALUE,
+ IterationRecord.END_EPOCH_WATERMARK,
OperatorUtils.getUniqueSenderId(operatorId, 0)))),
new ArrayList<>(harness.getOutput()));
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java
index b8aed71..7c4d791 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java
@@ -86,11 +86,13 @@
new StreamRecord<>(IterationRecord.newEpochWatermark(5, "only-one-1")), 1);
harness.processElement(
new StreamRecord<>(
- IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one-0")),
+ IterationRecord.newEpochWatermark(
+ IterationRecord.END_EPOCH_WATERMARK, "only-one-0")),
0);
harness.processElement(
new StreamRecord<>(
- IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one-1")),
+ IterationRecord.newEpochWatermark(
+ IterationRecord.END_EPOCH_WATERMARK, "only-one-1")),
1);
// Checks the output
@@ -103,7 +105,7 @@
5, OperatorUtils.getUniqueSenderId(operatorId, 0))),
new StreamRecord<>(
IterationRecord.newEpochWatermark(
- Integer.MAX_VALUE,
+ IterationRecord.END_EPOCH_WATERMARK,
OperatorUtils.getUniqueSenderId(operatorId, 0)))),
new ArrayList<>(harness.getOutput()));
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperatorTest.java b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperatorTest.java
index c58ded8..3c61f51 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperatorTest.java
@@ -119,15 +119,18 @@
new StreamRecord<>(IterationRecord.newEpochWatermark(1, "only-one-2")), 2);
harness.processElement(
new StreamRecord<>(
- IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one-0")),
+ IterationRecord.newEpochWatermark(
+ IterationRecord.END_EPOCH_WATERMARK, "only-one-0")),
0);
harness.processElement(
new StreamRecord<>(
- IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one-1")),
+ IterationRecord.newEpochWatermark(
+ IterationRecord.END_EPOCH_WATERMARK, "only-one-1")),
1);
harness.processElement(
new StreamRecord<>(
- IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one-2")),
+ IterationRecord.newEpochWatermark(
+ IterationRecord.END_EPOCH_WATERMARK, "only-one-2")),
2);
// Checks the output
@@ -138,7 +141,7 @@
1, OperatorUtils.getUniqueSenderId(operatorId, 0))),
new StreamRecord<>(
IterationRecord.newEpochWatermark(
- Integer.MAX_VALUE,
+ IterationRecord.END_EPOCH_WATERMARK,
OperatorUtils.getUniqueSenderId(operatorId, 0)))),
new ArrayList<>(harness.getOutput()));
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperatorTest.java b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperatorTest.java
index 384ef91..cd6b00a 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperatorTest.java
@@ -129,7 +129,8 @@
new StreamRecord<>(IterationRecord.newEpochWatermark(1, "only-one")));
harness.processElement(
new StreamRecord<>(
- IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one")));
+ IterationRecord.newEpochWatermark(
+ IterationRecord.END_EPOCH_WATERMARK, "only-one")));
// Checks the output
assertEquals(
@@ -139,7 +140,7 @@
1, OperatorUtils.getUniqueSenderId(operatorId, 0))),
new StreamRecord<>(
IterationRecord.newEpochWatermark(
- Integer.MAX_VALUE,
+ IterationRecord.END_EPOCH_WATERMARK,
OperatorUtils.getUniqueSenderId(operatorId, 0)))),
new ArrayList<>(harness.getOutput()));
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperatorTest.java b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperatorTest.java
index 26a2797..1d41bbd 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperatorTest.java
@@ -114,11 +114,13 @@
new StreamRecord<>(IterationRecord.newEpochWatermark(1, "only-one")), 1);
harness.processElement(
new StreamRecord<>(
- IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one")),
+ IterationRecord.newEpochWatermark(
+ IterationRecord.END_EPOCH_WATERMARK, "only-one")),
0);
harness.processElement(
new StreamRecord<>(
- IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one")),
+ IterationRecord.newEpochWatermark(
+ IterationRecord.END_EPOCH_WATERMARK, "only-one")),
1);
// Checks the output
@@ -129,7 +131,7 @@
1, OperatorUtils.getUniqueSenderId(operatorId, 0))),
new StreamRecord<>(
IterationRecord.newEpochWatermark(
- Integer.MAX_VALUE,
+ IterationRecord.END_EPOCH_WATERMARK,
OperatorUtils.getUniqueSenderId(operatorId, 0)))),
new ArrayList<>(harness.getOutput()));
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/typeinfo/IterationRecordSerializerTest.java b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/typeinfo/IterationRecordSerializerTest.java
index 275db2b..ae9d249 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/typeinfo/IterationRecordSerializerTest.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/typeinfo/IterationRecordSerializerTest.java
@@ -53,7 +53,7 @@
testSerializeAndDeserialize(
IterationRecord.newEpochWatermark(10, "sender1"), VoidSerializer.INSTANCE);
testSerializeAndDeserialize(
- IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "sender1"),
+ IterationRecord.newEpochWatermark(IterationRecord.END_EPOCH_WATERMARK, "sender1"),
VoidSerializer.INSTANCE);
}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
index 8aeb6b6..788caee 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
@@ -74,6 +74,8 @@
import java.util.Map;
import java.util.Objects;
+import static org.apache.flink.iteration.utils.DataStreamUtils.setManagedMemoryWeight;
+
/**
* An Estimator which implements the k-means clustering algorithm.
*
@@ -163,7 +165,7 @@
DenseVectorTypeInfo.INSTANCE)),
new CentroidsUpdateAccumulator(distanceMeasure));
- DataStreamUtils.setManagedMemoryWeight(centroidIdAndPoints, 100);
+ setManagedMemoryWeight(centroidIdAndPoints, 100);
int parallelism = centroidIdAndPoints.getParallelism();
DataStream<KMeansModelData> newModelData =