/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.checkpointing;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.time.Deadline;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.client.program.ClusterClient;
import org.apache.flink.configuration.CheckpointingOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
import org.apache.flink.configuration.StateBackendOptions;
import org.apache.flink.runtime.client.JobExecutionException;
import org.apache.flink.runtime.client.JobStatusMessage;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
import org.apache.flink.runtime.testutils.TestingUtils;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.test.util.MiniClusterWithClientResource;
import org.apache.flink.util.Collector;
import org.apache.flink.util.TestLogger;
import org.apache.flink.util.concurrent.FutureUtils;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.io.File;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import static org.apache.flink.runtime.testutils.CommonTestUtils.waitForAllTaskRunning;
import static org.apache.flink.test.util.TestUtils.submitJobAndWaitForResult;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/** Test savepoint rescaling. */
@RunWith(Parameterized.class)
public class RescalingITCase extends TestLogger {
private static final int numTaskManagers = 2;
private static final int slotsPerTaskManager = 2;
private static final int numSlots = numTaskManagers * slotsPerTaskManager;
@Parameterized.Parameters(name = "backend = {0}, buffersPerChannel = {1}")
public static Collection<Object[]> data() {
return Arrays.asList(
new Object[][] {
{"filesystem", 2}, {"rocksdb", 0}, {"filesystem", 0}, {"rocksdb", 2}
});
}
public RescalingITCase(String backend, int buffersPerChannel) {
this.backend = backend;
this.buffersPerChannel = buffersPerChannel;
}
private final String backend;
private final int buffersPerChannel;
private String currentBackend = null;
enum OperatorCheckpointMethod {
NON_PARTITIONED,
CHECKPOINTED_FUNCTION,
CHECKPOINTED_FUNCTION_BROADCAST,
LIST_CHECKPOINTED
}
private static MiniClusterWithClientResource cluster;
@ClassRule public static TemporaryFolder temporaryFolder = new TemporaryFolder();
@Before
public void setup() throws Exception {
// detect parameter change
        if (!backend.equals(currentBackend)) {
shutDownExistingCluster();
currentBackend = backend;
Configuration config = new Configuration();
final File checkpointDir = temporaryFolder.newFolder();
final File savepointDir = temporaryFolder.newFolder();
config.setString(StateBackendOptions.STATE_BACKEND, currentBackend);
config.setString(
CheckpointingOptions.CHECKPOINTS_DIRECTORY, checkpointDir.toURI().toString());
config.setString(
CheckpointingOptions.SAVEPOINT_DIRECTORY, savepointDir.toURI().toString());
config.setInteger(
NettyShuffleEnvironmentOptions.NETWORK_BUFFERS_PER_CHANNEL, buffersPerChannel);
cluster =
new MiniClusterWithClientResource(
new MiniClusterResourceConfiguration.Builder()
.setConfiguration(config)
.setNumberTaskManagers(numTaskManagers)
.setNumberSlotsPerTaskManager(numSlots)
.build());
cluster.before();
}
}
@AfterClass
public static void shutDownExistingCluster() {
if (cluster != null) {
cluster.after();
cluster = null;
}
}
@Test
public void testSavepointRescalingInKeyedState() throws Exception {
testSavepointRescalingKeyedState(false, false);
}
@Test
public void testSavepointRescalingOutKeyedState() throws Exception {
testSavepointRescalingKeyedState(true, false);
}
@Test
public void testSavepointRescalingInKeyedStateDerivedMaxParallelism() throws Exception {
testSavepointRescalingKeyedState(false, true);
}
@Test
public void testSavepointRescalingOutKeyedStateDerivedMaxParallelism() throws Exception {
testSavepointRescalingKeyedState(true, true);
}
/**
* Tests that a job with purely keyed state can be restarted from a savepoint with a different
* parallelism.
*/
public void testSavepointRescalingKeyedState(boolean scaleOut, boolean deriveMaxParallelism)
throws Exception {
final int numberKeys = 42;
final int numberElements = 1000;
final int numberElements2 = 500;
final int parallelism = scaleOut ? numSlots / 2 : numSlots;
final int parallelism2 = scaleOut ? numSlots : numSlots / 2;
final int maxParallelism = 13;
Duration timeout = Duration.ofMinutes(3);
Deadline deadline = Deadline.now().plus(timeout);
ClusterClient<?> client = cluster.getClusterClient();
try {
JobGraph jobGraph =
createJobGraphWithKeyedState(
parallelism, maxParallelism, numberKeys, numberElements, false, 100);
final JobID jobID = jobGraph.getJobID();
client.submitJob(jobGraph).get();
            // wait until the sources have emitted numberElements for each key and completed a
            // checkpoint
assertTrue(
SubtaskIndexFlatMapper.workCompletedLatch.await(
deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS));
// verify the current state
Set<Tuple2<Integer, Integer>> actualResult = CollectionSink.getElementsSet();
Set<Tuple2<Integer, Integer>> expectedResult = new HashSet<>();
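            // The loop below mirrors Flink's key-group mapping: each key is hash-assigned to one
            // of maxParallelism key groups, and computeOperatorIndexForKeyGroup maps a key group
            // to the subtask index keyGroupIndex * parallelism / maxParallelism. The
            // SubtaskIndexFlatMapper emits (subtaskIndex, sum) once it has seen numberElements
            // elements for a key, so the expected sum per key is numberElements * key. For
            // example, with maxParallelism = 13 and parallelism = 2, key groups 0..6 map to
            // subtask 0 and key groups 7..12 map to subtask 1.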
for (int key = 0; key < numberKeys; key++) {
int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
expectedResult.add(
Tuple2.of(
KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
maxParallelism, parallelism, keyGroupIndex),
numberElements * key));
}
assertEquals(expectedResult, actualResult);
// clear the CollectionSink set for the restarted job
CollectionSink.clearElementsSet();
waitForAllTaskRunning(cluster.getMiniCluster(), jobGraph.getJobID());
CompletableFuture<String> savepointPathFuture = client.triggerSavepoint(jobID, null);
final String savepointPath =
savepointPathFuture.get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
client.cancel(jobID).get();
while (!getRunningJobs(client).isEmpty()) {
Thread.sleep(50);
}
int restoreMaxParallelism =
deriveMaxParallelism ? JobVertex.MAX_PARALLELISM_DEFAULT : maxParallelism;
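            // With MAX_PARALLELISM_DEFAULT the rescaled job graph does not set an explicit max
            // parallelism (see createJobGraphWithKeyedState), so it has to be derived from the
            // savepoint on restore instead.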
JobGraph scaledJobGraph =
createJobGraphWithKeyedState(
parallelism2,
restoreMaxParallelism,
numberKeys,
numberElements2,
true,
100);
scaledJobGraph.setSavepointRestoreSettings(
SavepointRestoreSettings.forPath(savepointPath));
submitJobAndWaitForResult(client, scaledJobGraph, getClass().getClassLoader());
Set<Tuple2<Integer, Integer>> actualResult2 = CollectionSink.getElementsSet();
Set<Tuple2<Integer, Integer>> expectedResult2 = new HashSet<>();
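            // The keyed ValueState (the running sum) survives the savepoint, so after the second
            // run each key's total is key * (numberElements + numberElements2), now attributed to
            // the subtask index computed with the new parallelism (parallelism2).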
for (int key = 0; key < numberKeys; key++) {
int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
expectedResult2.add(
Tuple2.of(
KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
maxParallelism, parallelism2, keyGroupIndex),
key * (numberElements + numberElements2)));
}
assertEquals(expectedResult2, actualResult2);
} finally {
// clear the CollectionSink set for the restarted job
CollectionSink.clearElementsSet();
}
}
    /**
     * Tests that a job cannot be restarted from a savepoint with a different parallelism if the
     * rescaled operator has non-partitioned state.
     */
@Test
public void testSavepointRescalingNonPartitionedStateCausesException() throws Exception {
final int parallelism = numSlots / 2;
final int parallelism2 = numSlots;
final int maxParallelism = 13;
Duration timeout = Duration.ofMinutes(3);
Deadline deadline = Deadline.now().plus(timeout);
ClusterClient<?> client = cluster.getClusterClient();
try {
JobGraph jobGraph =
createJobGraphWithOperatorState(
parallelism, maxParallelism, OperatorCheckpointMethod.NON_PARTITIONED);
// make sure the job does not finish before we take the savepoint
StateSourceBase.canFinishLatch = new CountDownLatch(1);
final JobID jobID = jobGraph.getJobID();
client.submitJob(jobGraph).get();
// wait until the operator is started
waitForAllTaskRunning(cluster.getMiniCluster(), jobGraph.getJobID());
// wait until the operator handles some data
StateSourceBase.workStartedLatch.await();
CompletableFuture<String> savepointPathFuture = client.triggerSavepoint(jobID, null);
final String savepointPath =
savepointPathFuture.get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
// we took a savepoint, the job can finish now
StateSourceBase.canFinishLatch.countDown();
client.cancel(jobID).get();
while (!getRunningJobs(client).isEmpty()) {
Thread.sleep(50);
}
// job successfully removed
JobGraph scaledJobGraph =
createJobGraphWithOperatorState(
parallelism2, maxParallelism, OperatorCheckpointMethod.NON_PARTITIONED);
scaledJobGraph.setSavepointRestoreSettings(
SavepointRestoreSettings.forPath(savepointPath));
submitJobAndWaitForResult(client, scaledJobGraph, getClass().getClassLoader());
} catch (JobExecutionException exception) {
if (exception.getCause() instanceof IllegalStateException) {
                // we expect an IllegalStateException wrapped in a JobExecutionException,
                // because the job containing non-partitioned state is being rescaled
} else {
throw exception;
}
}
}
    /**
     * Tests that a job with non-partitioned state can be restarted from a savepoint with a
     * different parallelism if the operators with non-partitioned state are not rescaled.
     */
@Test
public void testSavepointRescalingWithKeyedAndNonPartitionedState() throws Exception {
int numberKeys = 42;
int numberElements = 1000;
int numberElements2 = 500;
int parallelism = numSlots / 2;
int parallelism2 = numSlots;
int maxParallelism = 13;
Duration timeout = Duration.ofMinutes(3);
Deadline deadline = Deadline.now().plus(timeout);
ClusterClient<?> client = cluster.getClusterClient();
try {
JobGraph jobGraph =
createJobGraphWithKeyedAndNonPartitionedOperatorState(
parallelism,
maxParallelism,
parallelism,
numberKeys,
numberElements,
false,
100);
final JobID jobID = jobGraph.getJobID();
// make sure the job does not finish before we take the savepoint
StateSourceBase.canFinishLatch = new CountDownLatch(1);
client.submitJob(jobGraph).get();
            // wait until the sources have emitted numberElements for each key and completed a
            // checkpoint
assertTrue(
SubtaskIndexFlatMapper.workCompletedLatch.await(
deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS));
// verify the current state
Set<Tuple2<Integer, Integer>> actualResult = CollectionSink.getElementsSet();
Set<Tuple2<Integer, Integer>> expectedResult = new HashSet<>();
for (int key = 0; key < numberKeys; key++) {
int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
expectedResult.add(
Tuple2.of(
KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
maxParallelism, parallelism, keyGroupIndex),
numberElements * key));
}
assertEquals(expectedResult, actualResult);
// clear the CollectionSink set for the restarted job
CollectionSink.clearElementsSet();
waitForAllTaskRunning(cluster.getMiniCluster(), jobGraph.getJobID());
CompletableFuture<String> savepointPathFuture = client.triggerSavepoint(jobID, null);
final String savepointPath =
savepointPathFuture.get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
// we took a savepoint, the job can finish now
StateSourceBase.canFinishLatch.countDown();
client.cancel(jobID).get();
while (!getRunningJobs(client).isEmpty()) {
Thread.sleep(50);
}
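            // The rescaled job keeps the non-partitioned source at its original, fixed
            // parallelism and only rescales the keyed part to parallelism2, so the
            // non-partitioned state is restored 1:1.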
JobGraph scaledJobGraph =
createJobGraphWithKeyedAndNonPartitionedOperatorState(
parallelism2,
maxParallelism,
parallelism,
numberKeys,
numberElements + numberElements2,
true,
100);
scaledJobGraph.setSavepointRestoreSettings(
SavepointRestoreSettings.forPath(savepointPath));
submitJobAndWaitForResult(client, scaledJobGraph, getClass().getClassLoader());
Set<Tuple2<Integer, Integer>> actualResult2 = CollectionSink.getElementsSet();
Set<Tuple2<Integer, Integer>> expectedResult2 = new HashSet<>();
for (int key = 0; key < numberKeys; key++) {
int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
expectedResult2.add(
Tuple2.of(
KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
maxParallelism, parallelism2, keyGroupIndex),
key * (numberElements + numberElements2)));
}
assertEquals(expectedResult2, actualResult2);
} finally {
// clear the CollectionSink set for the restarted job
CollectionSink.clearElementsSet();
}
}
@Test
public void testSavepointRescalingInPartitionedOperatorState() throws Exception {
testSavepointRescalingPartitionedOperatorState(
false, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION);
}
@Test
public void testSavepointRescalingOutPartitionedOperatorState() throws Exception {
testSavepointRescalingPartitionedOperatorState(
true, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION);
}
@Test
public void testSavepointRescalingInBroadcastOperatorState() throws Exception {
testSavepointRescalingPartitionedOperatorState(
false, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST);
}
@Test
public void testSavepointRescalingOutBroadcastOperatorState() throws Exception {
testSavepointRescalingPartitionedOperatorState(
true, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST);
}
@Test
public void testSavepointRescalingInPartitionedOperatorStateList() throws Exception {
testSavepointRescalingPartitionedOperatorState(
false, OperatorCheckpointMethod.LIST_CHECKPOINTED);
}
@Test
public void testSavepointRescalingOutPartitionedOperatorStateList() throws Exception {
testSavepointRescalingPartitionedOperatorState(
true, OperatorCheckpointMethod.LIST_CHECKPOINTED);
}
    /**
     * Tests rescaling of partitioned operator state. More specifically, we test the mechanism
     * with {@link ListCheckpointed} as it subsumes {@link
     * org.apache.flink.streaming.api.checkpoint.CheckpointedFunction}.
     */
public void testSavepointRescalingPartitionedOperatorState(
boolean scaleOut, OperatorCheckpointMethod checkpointMethod) throws Exception {
final int parallelism = scaleOut ? numSlots : numSlots / 2;
final int parallelism2 = scaleOut ? numSlots / 2 : numSlots;
final int maxParallelism = 13;
Duration timeout = Duration.ofMinutes(3);
Deadline deadline = Deadline.now().plus(timeout);
ClusterClient<?> client = cluster.getClusterClient();
int counterSize = Math.max(parallelism, parallelism2);
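        // The static check arrays record each subtask's counter at snapshot time and after
        // restore; they are sized for the larger of the two parallelisms so that both the
        // original and the rescaled run can write into them.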
if (checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION
|| checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST) {
PartitionedStateSource.checkCorrectSnapshot = new int[counterSize];
PartitionedStateSource.checkCorrectRestore = new int[counterSize];
} else {
PartitionedStateSourceListCheckpointed.checkCorrectSnapshot = new int[counterSize];
PartitionedStateSourceListCheckpointed.checkCorrectRestore = new int[counterSize];
}
try {
JobGraph jobGraph =
createJobGraphWithOperatorState(parallelism, maxParallelism, checkpointMethod);
// make sure the job does not finish before we take the savepoint
StateSourceBase.canFinishLatch = new CountDownLatch(1);
final JobID jobID = jobGraph.getJobID();
client.submitJob(jobGraph).get();
// wait until the operator is started
waitForAllTaskRunning(cluster.getMiniCluster(), jobGraph.getJobID());
// wait until the operator handles some data
StateSourceBase.workStartedLatch.await();
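            // Trigger the savepoint with retries (every 10 seconds, up to the deadline) to
            // tolerate transient trigger failures, e.g. if a trigger races with tasks that are
            // not yet ready to snapshot.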
CompletableFuture<String> savepointPathFuture =
FutureUtils.retryWithDelay(
() -> client.triggerSavepoint(jobID, null),
(int) deadline.timeLeft().getSeconds() / 10,
Time.seconds(10),
(throwable) -> true,
TestingUtils.defaultScheduledExecutor());
final String savepointPath =
savepointPathFuture.get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
// we took a savepoint, the job can finish now
StateSourceBase.canFinishLatch.countDown();
client.cancel(jobID).get();
while (!getRunningJobs(client).isEmpty()) {
Thread.sleep(50);
}
JobGraph scaledJobGraph =
createJobGraphWithOperatorState(parallelism2, maxParallelism, checkpointMethod);
scaledJobGraph.setSavepointRestoreSettings(
SavepointRestoreSettings.forPath(savepointPath));
submitJobAndWaitForResult(client, scaledJobGraph, getClass().getClassLoader());
int sumExp = 0;
int sumAct = 0;
if (checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION) {
for (int c : PartitionedStateSource.checkCorrectSnapshot) {
sumExp += c;
}
for (int c : PartitionedStateSource.checkCorrectRestore) {
sumAct += c;
}
} else if (checkpointMethod
== OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST) {
for (int c : PartitionedStateSource.checkCorrectSnapshot) {
sumExp += c;
}
for (int c : PartitionedStateSource.checkCorrectRestore) {
sumAct += c;
}
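                // Union (broadcast) list state is restored in full to every subtask, so each of
                // the parallelism2 restored subtasks counts the complete snapshotted sum once.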
sumExp *= parallelism2;
} else {
for (int c : PartitionedStateSourceListCheckpointed.checkCorrectSnapshot) {
sumExp += c;
}
for (int c : PartitionedStateSourceListCheckpointed.checkCorrectRestore) {
sumAct += c;
}
}
assertEquals(sumExp, sumAct);
} finally {
}
}
// ------------------------------------------------------------------------------------------------------------------
private static JobGraph createJobGraphWithOperatorState(
int parallelism, int maxParallelism, OperatorCheckpointMethod checkpointMethod) {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(parallelism);
env.getConfig().setMaxParallelism(maxParallelism);
env.enableCheckpointing(Long.MAX_VALUE);
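        // A checkpoint interval of Long.MAX_VALUE effectively disables periodic checkpoints;
        // state is only snapshotted through the explicitly triggered savepoint.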
env.setRestartStrategy(RestartStrategies.noRestart());
StateSourceBase.workStartedLatch = new CountDownLatch(parallelism);
SourceFunction<Integer> src;
switch (checkpointMethod) {
case CHECKPOINTED_FUNCTION:
src = new PartitionedStateSource(false);
break;
case CHECKPOINTED_FUNCTION_BROADCAST:
src = new PartitionedStateSource(true);
break;
case LIST_CHECKPOINTED:
src = new PartitionedStateSourceListCheckpointed();
break;
case NON_PARTITIONED:
src = new NonPartitionedStateSource();
break;
default:
throw new IllegalArgumentException();
}
DataStream<Integer> input = env.addSource(src);
input.addSink(new DiscardingSink<Integer>());
return env.getStreamGraph().getJobGraph();
}
private static JobGraph createJobGraphWithKeyedState(
int parallelism,
int maxParallelism,
int numberKeys,
int numberElements,
boolean terminateAfterEmission,
int checkpointingInterval) {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(parallelism);
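        // A non-positive maxParallelism (e.g. JobVertex.MAX_PARALLELISM_DEFAULT) is not set
        // explicitly, leaving it to be derived automatically and, on restore, taken from the
        // savepoint.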
if (0 < maxParallelism) {
env.getConfig().setMaxParallelism(maxParallelism);
}
env.enableCheckpointing(checkpointingInterval);
env.setRestartStrategy(RestartStrategies.noRestart());
env.getConfig().setUseSnapshotCompression(true);
DataStream<Integer> input =
env.addSource(
new SubtaskIndexSource(
numberKeys, numberElements, terminateAfterEmission))
.keyBy(
new KeySelector<Integer, Integer>() {
private static final long serialVersionUID =
-7952298871120320940L;
@Override
public Integer getKey(Integer value) throws Exception {
return value;
}
});
SubtaskIndexFlatMapper.workCompletedLatch = new CountDownLatch(numberKeys);
DataStream<Tuple2<Integer, Integer>> result =
input.flatMap(new SubtaskIndexFlatMapper(numberElements));
result.addSink(new CollectionSink<Tuple2<Integer, Integer>>());
return env.getStreamGraph().getJobGraph();
}
private static JobGraph createJobGraphWithKeyedAndNonPartitionedOperatorState(
int parallelism,
int maxParallelism,
int fixedParallelism,
int numberKeys,
int numberElements,
boolean terminateAfterEmission,
int checkpointingInterval) {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(parallelism);
env.getConfig().setMaxParallelism(maxParallelism);
env.enableCheckpointing(checkpointingInterval);
env.setRestartStrategy(RestartStrategies.noRestart());
DataStream<Integer> input =
env.addSource(
new SubtaskIndexNonPartitionedStateSource(
numberKeys, numberElements, terminateAfterEmission))
.setParallelism(fixedParallelism)
.keyBy(
new KeySelector<Integer, Integer>() {
private static final long serialVersionUID =
-7952298871120320940L;
@Override
public Integer getKey(Integer value) throws Exception {
return value;
}
});
SubtaskIndexFlatMapper.workCompletedLatch = new CountDownLatch(numberKeys);
DataStream<Tuple2<Integer, Integer>> result =
input.flatMap(new SubtaskIndexFlatMapper(numberElements));
result.addSink(new CollectionSink<Tuple2<Integer, Integer>>());
return env.getStreamGraph().getJobGraph();
}
private static class SubtaskIndexSource extends RichParallelSourceFunction<Integer> {
private static final long serialVersionUID = -400066323594122516L;
private final int numberKeys;
private final int numberElements;
private final boolean terminateAfterEmission;
protected int counter = 0;
        private volatile boolean running = true;
SubtaskIndexSource(int numberKeys, int numberElements, boolean terminateAfterEmission) {
this.numberKeys = numberKeys;
this.numberElements = numberElements;
this.terminateAfterEmission = terminateAfterEmission;
}
@Override
public void run(SourceContext<Integer> ctx) throws Exception {
final Object lock = ctx.getCheckpointLock();
final int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
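            // Every subtask emits the keys it owns (subtaskIndex, subtaskIndex + parallelism,
            // ...) once per loop iteration, numberElements times in total. Emission happens under
            // the checkpoint lock so the counter stays consistent with checkpoints.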
while (running) {
if (counter < numberElements) {
synchronized (lock) {
for (int value = subtaskIndex;
value < numberKeys;
value += getRuntimeContext().getNumberOfParallelSubtasks()) {
ctx.collect(value);
}
counter++;
}
} else {
if (terminateAfterEmission) {
running = false;
} else {
Thread.sleep(100);
}
}
}
}
@Override
public void cancel() {
running = false;
}
}
private static class SubtaskIndexNonPartitionedStateSource extends SubtaskIndexSource
implements ListCheckpointed<Integer> {
private static final long serialVersionUID = 8388073059042040203L;
SubtaskIndexNonPartitionedStateSource(
int numberKeys, int numberElements, boolean terminateAfterEmission) {
super(numberKeys, numberElements, terminateAfterEmission);
}
@Override
public List<Integer> snapshotState(long checkpointId, long timestamp) throws Exception {
return Collections.singletonList(this.counter);
}
@Override
public void restoreState(List<Integer> state) throws Exception {
if (state.isEmpty() || state.size() > 1) {
throw new RuntimeException(
"Test failed due to unexpected recovered state size " + state.size());
}
this.counter = state.get(0);
}
}
private static class SubtaskIndexFlatMapper
extends RichFlatMapFunction<Integer, Tuple2<Integer, Integer>>
implements CheckpointedFunction {
private static final long serialVersionUID = 5273172591283191348L;
private static CountDownLatch workCompletedLatch = new CountDownLatch(1);
private transient ValueState<Integer> counter;
private transient ValueState<Integer> sum;
private final int numberElements;
SubtaskIndexFlatMapper(int numberElements) {
this.numberElements = numberElements;
}
@Override
public void flatMap(Integer value, Collector<Tuple2<Integer, Integer>> out)
throws Exception {
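            // Per key: count how many elements have arrived and keep a running sum. Once a key
            // has received numberElements elements, emit (subtaskIndex, sum) and signal the
            // completion latch.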
int count = counter.value() + 1;
counter.update(count);
int s = sum.value() + value;
sum.update(s);
if (count % numberElements == 0) {
out.collect(Tuple2.of(getRuntimeContext().getIndexOfThisSubtask(), s));
workCompletedLatch.countDown();
}
}
@Override
public void snapshotState(FunctionSnapshotContext context) throws Exception {
// all managed, nothing to do.
}
@Override
public void initializeState(FunctionInitializationContext context) throws Exception {
counter =
context.getKeyedStateStore()
.getState(new ValueStateDescriptor<>("counter", Integer.class, 0));
sum =
context.getKeyedStateStore()
.getState(new ValueStateDescriptor<>("sum", Integer.class, 0));
}
}
private static class CollectionSink<IN> implements SinkFunction<IN> {
private static Set<Object> elements =
Collections.newSetFromMap(new ConcurrentHashMap<Object, Boolean>());
private static final long serialVersionUID = -1652452958040267745L;
public static <IN> Set<IN> getElementsSet() {
return (Set<IN>) elements;
}
public static void clearElementsSet() {
elements.clear();
}
@Override
public void invoke(IN value) throws Exception {
elements.add(value);
}
}
private static class StateSourceBase extends RichParallelSourceFunction<Integer> {
private static final long serialVersionUID = 7512206069681177940L;
private static CountDownLatch workStartedLatch = new CountDownLatch(1);
private static CountDownLatch canFinishLatch = new CountDownLatch(0);
protected volatile int counter = 0;
protected volatile boolean running = true;
@Override
public void run(SourceContext<Integer> ctx) throws Exception {
final Object lock = ctx.getCheckpointLock();
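            // Emit under the checkpoint lock until 500 elements have been produced, signal
            // workStartedLatch after the first 10 so the test can trigger a savepoint, then block
            // on canFinishLatch so the job does not finish before the savepoint is taken.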
while (running) {
synchronized (lock) {
++counter;
ctx.collect(1);
}
Thread.sleep(2);
if (counter == 10) {
workStartedLatch.countDown();
}
if (counter >= 500) {
break;
}
}
canFinishLatch.await();
}
@Override
public void cancel() {
running = false;
}
}
private static class NonPartitionedStateSource extends StateSourceBase
implements ListCheckpointed<Integer> {
private static final long serialVersionUID = -8108185918123186841L;
@Override
public List<Integer> snapshotState(long checkpointId, long timestamp) throws Exception {
return Collections.singletonList(this.counter);
}
@Override
public void restoreState(List<Integer> state) throws Exception {
if (!state.isEmpty()) {
this.counter = state.get(0);
}
}
}
private static class PartitionedStateSourceListCheckpointed extends StateSourceBase
implements ListCheckpointed<Integer> {
private static final long serialVersionUID = -4357864582992546L;
private static final int NUM_PARTITIONS = 7;
private static int[] checkCorrectSnapshot;
private static int[] checkCorrectRestore;
@Override
public List<Integer> snapshotState(long checkpointId, long timestamp) throws Exception {
checkCorrectSnapshot[getRuntimeContext().getIndexOfThisSubtask()] = counter;
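            // Split the counter into NUM_PARTITIONS nearly equal list entries (the first `mod`
            // entries get one extra) so the list state can be redistributed across a different
            // number of subtasks; the entries always sum back to counter.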
int div = counter / NUM_PARTITIONS;
int mod = counter % NUM_PARTITIONS;
List<Integer> split = new ArrayList<>();
for (int i = 0; i < NUM_PARTITIONS; ++i) {
int partitionValue = div;
if (mod > 0) {
--mod;
++partitionValue;
}
split.add(partitionValue);
}
return split;
}
@Override
public void restoreState(List<Integer> state) throws Exception {
for (Integer v : state) {
counter += v;
}
checkCorrectRestore[getRuntimeContext().getIndexOfThisSubtask()] = counter;
}
}
private static class PartitionedStateSource extends StateSourceBase
implements CheckpointedFunction {
private static final long serialVersionUID = -359715965103593462L;
private static final int NUM_PARTITIONS = 7;
private transient ListState<Integer> counterPartitions;
private boolean broadcast;
private static int[] checkCorrectSnapshot;
private static int[] checkCorrectRestore;
public PartitionedStateSource(boolean broadcast) {
this.broadcast = broadcast;
}
@Override
public void snapshotState(FunctionSnapshotContext context) throws Exception {
counterPartitions.clear();
checkCorrectSnapshot[getRuntimeContext().getIndexOfThisSubtask()] = counter;
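            // Same splitting scheme as PartitionedStateSourceListCheckpointed: the counter is
            // spread over NUM_PARTITIONS list entries that sum back to the original value.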
int div = counter / NUM_PARTITIONS;
int mod = counter % NUM_PARTITIONS;
for (int i = 0; i < NUM_PARTITIONS; ++i) {
int partitionValue = div;
if (mod > 0) {
--mod;
++partitionValue;
}
counterPartitions.add(partitionValue);
}
}
@Override
public void initializeState(FunctionInitializationContext context) throws Exception {
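            // Union list state hands every subtask the complete union of all snapshotted entries
            // on restore, whereas plain list state redistributes the entries across subtasks.
            // This is what makes the broadcast variant's restored sum parallelism2 times larger.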
if (broadcast) {
this.counterPartitions =
context.getOperatorStateStore()
.getUnionListState(
new ListStateDescriptor<>(
"counter_partitions", IntSerializer.INSTANCE));
} else {
this.counterPartitions =
context.getOperatorStateStore()
.getListState(
new ListStateDescriptor<>(
"counter_partitions", IntSerializer.INSTANCE));
}
if (context.isRestored()) {
for (int v : counterPartitions.get()) {
counter += v;
}
checkCorrectRestore[getRuntimeContext().getIndexOfThisSubtask()] = counter;
}
}
}
private static List<JobID> getRunningJobs(ClusterClient<?> client) throws Exception {
Collection<JobStatusMessage> statusMessages = client.listJobs().get();
return statusMessages.stream()
.filter(status -> !status.getJobState().isGloballyTerminalState())
.map(JobStatusMessage::getJobId)
.collect(Collectors.toList());
}
}