blob: 6cd7adaedc58e69fd49a51deb20f5fedd15a368e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.checkpointing;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.client.program.ClusterClient;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.HighAvailabilityOptions;
import org.apache.flink.configuration.JobManagerOptions;
import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory;
import org.apache.flink.runtime.checkpoint.CheckpointsCleaner;
import org.apache.flink.runtime.checkpoint.CompletedCheckpoint;
import org.apache.flink.runtime.checkpoint.PerJobCheckpointRecoveryFactory;
import org.apache.flink.runtime.checkpoint.StandaloneCheckpointIDCounter;
import org.apache.flink.runtime.checkpoint.StandaloneCompletedCheckpointStore;
import org.apache.flink.runtime.highavailability.HighAvailabilityServices;
import org.apache.flink.runtime.highavailability.HighAvailabilityServicesFactory;
import org.apache.flink.runtime.highavailability.nonha.embedded.EmbeddedHaServices;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
import org.apache.flink.streaming.api.CheckpointingMode;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
import org.apache.flink.streaming.api.datastream.DataStreamUtils;
import org.apache.flink.streaming.api.environment.CheckpointConfig;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.test.util.MiniClusterWithClientResource;
import org.apache.flink.util.TestLogger;
import org.apache.commons.collections.CollectionUtils;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.StreamSupport;
import static org.apache.flink.test.util.TestUtils.submitJobAndWaitForResult;
import static org.junit.Assert.assertEquals;
/** Tests for region failover with multi regions. */
public class RegionFailoverITCase extends TestLogger {
private static final int FAIL_BASE = 1000;
private static final int NUM_OF_REGIONS = 3;
private static final int MAX_PARALLELISM = 2 * NUM_OF_REGIONS;
private static final Set<Integer> EXPECTED_INDICES_MULTI_REGION =
IntStream.range(0, NUM_OF_REGIONS).boxed().collect(Collectors.toSet());
private static final Set<Integer> EXPECTED_INDICES_SINGLE_REGION = Collections.singleton(0);
private static final int NUM_OF_RESTARTS = 3;
private static final int NUM_ELEMENTS = FAIL_BASE * 10;
private static final String SINGLE_REGION_SOURCE_NAME = "single-source";
private static final String MULTI_REGION_SOURCE_NAME = "multi-source";
private static AtomicLong lastCompletedCheckpointId = new AtomicLong(0);
private static AtomicInteger numCompletedCheckpoints = new AtomicInteger(0);
private static AtomicInteger jobFailedCnt = new AtomicInteger(0);
private static Map<Long, Integer> snapshotIndicesOfSubTask = new HashMap<>();
private static MiniClusterWithClientResource cluster;
private static boolean restoredState = false;
@ClassRule public static final TemporaryFolder TEMPORARY_FOLDER = new TemporaryFolder();
/**
 * Creates a fresh mini cluster with region failover enabled for each test and resets all static
 * bookkeeping so consecutive test methods do not observe state leaked from a previous run.
 */
@Before
public void setup() throws Exception {
    Configuration configuration = new Configuration();
    configuration.setString(JobManagerOptions.EXECUTION_FAILOVER_STRATEGY, "region");
    // route HA services through the testing factory so completed checkpoints are observable
    configuration.setString(HighAvailabilityOptions.HA_MODE, TestingHAFactory.class.getName());
    cluster =
            new MiniClusterWithClientResource(
                    new MiniClusterResourceConfiguration.Builder()
                            .setConfiguration(configuration)
                            .setNumberTaskManagers(2)
                            .setNumberSlotsPerTaskManager(2)
                            .build());
    cluster.before();
    // Reset ALL cross-run bookkeeping kept in static fields. Previously only the two counters
    // below were cleared, so a second test method would have seen a stale 'restoredState' flag,
    // checkpoint id and snapshot-index map from the previous execution.
    jobFailedCnt.set(0);
    numCompletedCheckpoints.set(0);
    lastCompletedCheckpointId.set(0);
    snapshotIndicesOfSubTask.clear();
    restoredState = false;
}
/** Tears down the shared mini cluster once all tests in this class have finished. */
@AfterClass
public static void shutDownExistingCluster() {
    final MiniClusterWithClientResource existing = cluster;
    cluster = null;
    if (existing != null) {
        existing.after();
    }
}
/**
 * Tests that a simple job (Source -> Map) with multi regions could restore with operator state.
 *
 * <p>The last subtask of Map function in the 1st stream graph would fail {@code
 * NUM_OF_RESTARTS} times, and it will verify whether the restored state is identical to last
 * completed checkpoint's.
 */
@Test(timeout = 60000)
public void testMultiRegionFailover() throws Exception {
    // Let failures propagate so JUnit reports the full exception and stack trace. The previous
    // catch-and-Assert.fail(e.getMessage()) pattern lost the cause and produced a null failure
    // message for exceptions without one.
    JobGraph jobGraph = createJobGraph();
    ClusterClient<?> client = cluster.getClusterClient();
    submitJobAndWaitForResult(client, jobGraph, getClass().getClassLoader());
    verifyAfterJobExecuted();
}
/**
 * Checks that the multi-region job actually restored state at least once and that every key seen
 * by the sinks carries the expected aggregated value.
 */
private void verifyAfterJobExecuted() {
    Assert.assertTrue(
            "The test multi-region job has never ever restored state.", restoredState);
    int observedKeys = 0;
    for (Map<Integer, Integer> sinkResult : ValidatingSink.maps) {
        for (Map.Entry<Integer, Integer> entry : sinkResult.entrySet()) {
            // key k is emitted as (k, 2k) and (k, 2k+1); the mapper sums them to 4k+1
            assertEquals(4 * entry.getKey() + 1, (int) entry.getValue());
            observedKeys++;
        }
    }
    // two source indices collapse onto each key, so half of NUM_ELEMENTS keys are expected
    assertEquals(NUM_ELEMENTS / 2, observedKeys);
}
/**
 * Builds a job with two fully disconnected stream graphs: one with {@code NUM_OF_REGIONS}
 * parallel pipeline regions (source -> map -> sink) and a second single-region pipeline
 * (source -> map), so region failover is exercised for both multi- and single-region topologies.
 */
private JobGraph createJobGraph() {
    StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    env.setParallelism(NUM_OF_REGIONS);
    env.setMaxParallelism(MAX_PARALLELISM);
    // checkpoint frequently (200ms) so several checkpoints complete before the induced failures
    env.enableCheckpointing(200, CheckpointingMode.EXACTLY_ONCE);
    env.getCheckpointConfig()
            .enableExternalizedCheckpoints(
                    CheckpointConfig.ExternalizedCheckpointCleanup.RETAIN_ON_CANCELLATION);
    // keep every operator in its own task so chaining does not merge regions
    env.disableOperatorChaining();
    // Use DataStreamUtils#reinterpretAsKeyed to avoid merge regions and this stream graph would
    // exist num of 'NUM_OF_REGIONS' individual regions.
    DataStreamUtils.reinterpretAsKeyedStream(
                    env.addSource(
                                    new StringGeneratingSourceFunction(
                                            NUM_ELEMENTS, NUM_ELEMENTS / NUM_OF_RESTARTS))
                            .name(MULTI_REGION_SOURCE_NAME)
                            .setParallelism(NUM_OF_REGIONS),
                    (KeySelector<Tuple2<Integer, Integer>, Integer>) value -> value.f0,
                    TypeInformation.of(Integer.class))
            .map(new FailingMapperFunction(NUM_OF_RESTARTS))
            .setParallelism(NUM_OF_REGIONS)
            .addSink(new ValidatingSink())
            .setParallelism(NUM_OF_REGIONS);
    // another stream graph totally disconnected with the above one.
    env.addSource(
                    new StringGeneratingSourceFunction(
                            NUM_ELEMENTS, NUM_ELEMENTS / NUM_OF_RESTARTS))
            .name(SINGLE_REGION_SOURCE_NAME)
            .setParallelism(1)
            .map((MapFunction<Tuple2<Integer, Integer>, Object>) value -> value)
            .setParallelism(1);
    return env.getStreamGraph().getJobGraph();
}
/**
 * Parallel source emitting {@code Tuple2<key, index>} records where {@code key = index / 2}, so
 * each key is emitted twice (values 2k and 2k+1). Output is pre-partitioned: a subtask only
 * emits the keys that key-group assignment would route to it, which keeps the downstream
 * reinterpreted keyed stream correct. The source throttles itself until enough checkpoints have
 * completed and enough induced failures have happened, and it snapshots both a per-subtask list
 * state and a union list state that are verified on restore.
 */
private static class StringGeneratingSourceFunction
        extends RichParallelSourceFunction<Tuple2<Integer, Integer>>
        implements CheckpointedFunction {
    private static final long serialVersionUID = 1L;

    // total number of indices to iterate over
    private final long numElements;
    // before this index the source is only mildly throttled; at/after it the source blocks
    // until at least 3 checkpoints have completed
    private final long checkpointLatestAt;

    // next index to emit; -1 means "not yet initialized nor restored"
    private int index = -1;
    // last snapshotted index on the last region's subtask (only written there)
    private int lastRegionIndex = -1;

    private volatile boolean isRunning = true;

    // per-subtask list state holding the current index (deliberately kept empty on subtask-0)
    private ListState<Integer> listState;
    private static final ListStateDescriptor<Integer> stateDescriptor =
            new ListStateDescriptor<>("list-1", Integer.class);

    // union list state holding every subtask's index, used to verify redistribution on restore
    private ListState<Integer> unionListState;
    private static final ListStateDescriptor<Integer> unionStateDescriptor =
            new ListStateDescriptor<>("list-2", Integer.class);

    StringGeneratingSourceFunction(long numElements, long checkpointLatestAt) {
        this.numElements = numElements;
        this.checkpointLatestAt = checkpointLatestAt;
    }

    @Override
    public void run(SourceContext<Tuple2<Integer, Integer>> ctx) throws Exception {
        if (index < 0) {
            // not been restored, so initialize
            index = 0;
        }
        int subTaskIndex = getRuntimeContext().getIndexOfThisSubtask();
        while (isRunning && index < numElements) {
            // emit and advance under the checkpoint lock so 'index' is consistent with output
            synchronized (ctx.getCheckpointLock()) {
                int key = index / 2;
                int forwardTaskIndex =
                        KeyGroupRangeAssignment.assignKeyToParallelOperator(
                                key, MAX_PARALLELISM, NUM_OF_REGIONS);
                // pre-partition output keys
                if (forwardTaskIndex == subTaskIndex) {
                    // we would send data with the same key twice.
                    ctx.collect(Tuple2.of(key, index));
                }
                index += 1;
            }
            if (numCompletedCheckpoints.get() < 3) {
                // not yet completed enough checkpoints, so slow down
                if (index < checkpointLatestAt) {
                    // mild slow down
                    Thread.sleep(1);
                } else {
                    // wait until the checkpoints are completed
                    while (isRunning && numCompletedCheckpoints.get() < 3) {
                        Thread.sleep(300);
                    }
                }
            }
            if (jobFailedCnt.get() < NUM_OF_RESTARTS) {
                // slow down if job has not failed for 'NUM_OF_RESTARTS' times.
                Thread.sleep(1);
            }
        }
    }

    @Override
    public void cancel() {
        isRunning = false;
    }

    @Override
    public void snapshotState(FunctionSnapshotContext context) throws Exception {
        int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
        // subtask-0 intentionally snapshots nothing into listState (checked on restore)
        if (indexOfThisSubtask != 0) {
            listState.clear();
            listState.add(index);
            if (indexOfThisSubtask == NUM_OF_REGIONS - 1) {
                // remember which index went into this checkpoint so the restore path can
                // compare against the last completed checkpoint id
                lastRegionIndex = index;
                snapshotIndicesOfSubTask.put(context.getCheckpointId(), lastRegionIndex);
            }
        }
        unionListState.clear();
        unionListState.add(indexOfThisSubtask);
    }

    @Override
    public void initializeState(FunctionInitializationContext context) throws Exception {
        int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
        if (context.isRestored()) {
            restoredState = true;
            unionListState =
                    context.getOperatorStateStore().getUnionListState(unionStateDescriptor);
            // the union state must contain the subtask index of every source subtask
            Set<Integer> actualIndices =
                    StreamSupport.stream(unionListState.get().spliterator(), false)
                            .collect(Collectors.toSet());
            if (getRuntimeContext().getTaskName().contains(SINGLE_REGION_SOURCE_NAME)) {
                Assert.assertTrue(
                        CollectionUtils.isEqualCollection(
                                EXPECTED_INDICES_SINGLE_REGION, actualIndices));
            } else {
                Assert.assertTrue(
                        CollectionUtils.isEqualCollection(
                                EXPECTED_INDICES_MULTI_REGION, actualIndices));
            }
            if (indexOfThisSubtask == 0) {
                listState = context.getOperatorStateStore().getListState(stateDescriptor);
                Assert.assertTrue(
                        "list state should be empty for subtask-0",
                        ((List<Integer>) listState.get()).isEmpty());
            } else {
                listState = context.getOperatorStateStore().getListState(stateDescriptor);
                Assert.assertTrue(
                        "list state should not be empty for subtask-" + indexOfThisSubtask,
                        ((List<Integer>) listState.get()).size() > 0);
                if (indexOfThisSubtask == NUM_OF_REGIONS - 1) {
                    index = listState.get().iterator().next();
                    // the recovered index must match what was snapshotted into the last
                    // completed checkpoint (recorded by TestingCompletedCheckpointStore)
                    if (index
                            != snapshotIndicesOfSubTask.get(lastCompletedCheckpointId.get())) {
                        throw new RuntimeException(
                                "Test failed due to unexpected recovered index: "
                                        + index
                                        + ", while last completed checkpoint record index: "
                                        + snapshotIndicesOfSubTask.get(
                                                lastCompletedCheckpointId.get()));
                    }
                }
            }
        } else {
            unionListState =
                    context.getOperatorStateStore().getUnionListState(unionStateDescriptor);
            if (indexOfThisSubtask != 0) {
                listState = context.getOperatorStateStore().getListState(stateDescriptor);
            }
        }
    }
}
/**
 * Map function whose subtasks throw a {@link TestException} to trigger region failover: region-0
 * fails first, then the last region keeps failing until {@code restartTimes} total failures have
 * occurred. It also sums the two values observed per key via keyed {@link ValueState}.
 */
private static class FailingMapperFunction
        extends RichMapFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> {
    // Flink functions are Serializable; pin the serial form like the sibling source function does.
    private static final long serialVersionUID = 1L;

    /** Total number of induced failures across all subtasks. */
    private final int restartTimes;

    /** Keyed state holding the first value observed for the current key. */
    private ValueState<Integer> valueState;

    FailingMapperFunction(int restartTimes) {
        this.restartTimes = restartTimes;
    }

    @Override
    public void open(Configuration parameters) throws Exception {
        super.open(parameters);
        valueState =
                getRuntimeContext()
                        .getState(new ValueStateDescriptor<>("value", Integer.class));
    }

    @Override
    public Tuple2<Integer, Integer> map(Tuple2<Integer, Integer> input) throws Exception {
        int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
        // only fail once enough records have been processed since the previous restart
        if (input.f1 > FAIL_BASE * (jobFailedCnt.get() + 1)) {
            // we would let region-0 to failover first
            if (jobFailedCnt.get() < 1 && indexOfThisSubtask == 0) {
                jobFailedCnt.incrementAndGet();
                throw new TestException();
            }
            // then let last region to failover
            if (jobFailedCnt.get() < restartTimes && indexOfThisSubtask == NUM_OF_REGIONS - 1) {
                jobFailedCnt.incrementAndGet();
                throw new TestException();
            }
        }
        // take input (1, 2) and (1, 3) for example, we would finally emit (1, 5) out with the
        // usage of keyed state.
        Integer value = valueState.value();
        if (value == null) {
            valueState.update(input.f1);
            return input;
        } else {
            return Tuple2.of(input.f0, value + input.f1);
        }
    }
}
/**
 * Sink that records, per key, the maximum value it has seen; on close each subtask publishes its
 * map into a static array indexed by subtask so {@code verifyAfterJobExecuted} can inspect the
 * results. The per-key map is checkpointed via the {@link ListCheckpointed} contract.
 */
private static class ValidatingSink extends RichSinkFunction<Tuple2<Integer, Integer>>
        implements ListCheckpointed<HashMap<Integer, Integer>> {
    // Flink functions are Serializable; pin the serial form like the sibling source function does.
    private static final long serialVersionUID = 1L;

    // one result map per sink subtask, published in close()
    @SuppressWarnings("unchecked")
    private static Map<Integer, Integer>[] maps =
            (Map<Integer, Integer>[]) new Map<?, ?>[NUM_OF_REGIONS];

    private HashMap<Integer, Integer> counts = new HashMap<>();

    @Override
    public void invoke(Tuple2<Integer, Integer> input) {
        // keep the largest value observed for each key
        counts.merge(input.f0, input.f1, Math::max);
    }

    @Override
    public void close() throws Exception {
        maps[getRuntimeContext().getIndexOfThisSubtask()] = counts;
    }

    @Override
    public List<HashMap<Integer, Integer>> snapshotState(long checkpointId, long timestamp)
            throws Exception {
        return Collections.singletonList(this.counts);
    }

    @Override
    public void restoreState(List<HashMap<Integer, Integer>> state) throws Exception {
        // snapshotState always emits exactly one map, so restore must see exactly one
        if (state.size() != 1) {
            throw new RuntimeException(
                    "Test failed due to unexpected recovered state size " + state.size());
        }
        this.counts.putAll(state.get(0));
    }
}
/** Marker exception thrown by {@link FailingMapperFunction} to trigger a region failover. */
private static class TestException extends IOException {
    private static final long serialVersionUID = 1L;
}
/**
 * {@link EmbeddedHaServices} variant that hands out a custom {@link CheckpointRecoveryFactory},
 * allowing the test to intercept completed checkpoints.
 */
private static class TestingHaServices extends EmbeddedHaServices {
    private final CheckpointRecoveryFactory checkpointRecoveryFactory;

    TestingHaServices(CheckpointRecoveryFactory checkpointRecoveryFactory, Executor executor) {
        super(executor);
        this.checkpointRecoveryFactory = checkpointRecoveryFactory;
    }

    @Override
    public CheckpointRecoveryFactory getCheckpointRecoveryFactory() {
        return checkpointRecoveryFactory;
    }
}
/**
 * An extension of {@link StandaloneCompletedCheckpointStore} which would record information of
 * last completed checkpoint id and the number of completed checkpoints.
 */
private static class TestingCompletedCheckpointStore
        extends StandaloneCompletedCheckpointStore {
    TestingCompletedCheckpointStore() {
        // retain only the most recent completed checkpoint
        super(1);
    }

    @Override
    public void addCheckpoint(
            CompletedCheckpoint checkpoint,
            CheckpointsCleaner checkpointsCleaner,
            Runnable postCleanup)
            throws Exception {
        super.addCheckpoint(checkpoint, checkpointsCleaner, postCleanup);
        // we record the information when adding completed checkpoint instead of
        // 'notifyCheckpointComplete' invoked
        // on task side to avoid race condition. See FLINK-13601.
        lastCompletedCheckpointId.set(checkpoint.getCheckpointID());
        numCompletedCheckpoints.incrementAndGet();
    }
}
/** Testing HA factory which needs to be public in order to be instantiatable. */
public static class TestingHAFactory implements HighAvailabilityServicesFactory {
    @Override
    public HighAvailabilityServices createHAServices(
            Configuration configuration, Executor executor) {
        // one shared checkpoint store / id counter for all jobs, so the static counters in the
        // test observe every completed checkpoint
        final CheckpointRecoveryFactory recoveryFactory =
                PerJobCheckpointRecoveryFactory.useSameServicesForAllJobs(
                        new TestingCompletedCheckpointStore(),
                        new StandaloneCheckpointIDCounter());
        return new TestingHaServices(recoveryFactory, executor);
    }
}
}