blob: b9c707f8e5d6f9d4121c06ae735ab72b875d727a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recovery;
import org.apache.flink.api.common.ExecutionMode;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.JobManagerOptions;
import org.apache.flink.configuration.UnmodifiableConfiguration;
import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.messages.webmonitor.JobIdsWithStatusOverview;
import org.apache.flink.runtime.messages.webmonitor.JobIdsWithStatusOverview.JobIdWithStatus;
import org.apache.flink.runtime.minicluster.MiniCluster;
import org.apache.flink.runtime.rest.RestClient;
import org.apache.flink.runtime.rest.RestClientConfiguration;
import org.apache.flink.runtime.rest.messages.EmptyMessageParameters;
import org.apache.flink.runtime.rest.messages.EmptyRequestBody;
import org.apache.flink.runtime.rest.messages.JobIdsWithStatusesOverviewHeaders;
import org.apache.flink.runtime.rest.messages.JobMessageParameters;
import org.apache.flink.runtime.rest.messages.JobVertexDetailsHeaders;
import org.apache.flink.runtime.rest.messages.JobVertexDetailsInfo;
import org.apache.flink.runtime.rest.messages.JobVertexMessageParameters;
import org.apache.flink.runtime.rest.messages.MessageHeaders;
import org.apache.flink.runtime.rest.messages.MessageParameters;
import org.apache.flink.runtime.rest.messages.ResponseBody;
import org.apache.flink.runtime.rest.messages.job.JobDetailsHeaders;
import org.apache.flink.runtime.rest.messages.job.JobDetailsInfo;
import org.apache.flink.runtime.rest.messages.job.SubtaskExecutionAttemptDetailsInfo;
import org.apache.flink.runtime.testutils.MiniClusterResource;
import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
import org.apache.flink.test.util.TestEnvironment;
import org.apache.flink.util.Collector;
import org.apache.flink.util.ConfigurationException;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.FlinkException;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.TemporaryClassLoaderContext;
import org.apache.flink.util.TestLogger;
import org.apache.flink.util.concurrent.ExecutorThreadFactory;
import org.apache.flink.util.concurrent.FutureUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.concurrent.GuardedBy;
import java.io.IOException;
import java.io.Serializable;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import static org.apache.flink.runtime.executiongraph.failover.flip1.FailoverStrategyFactoryLoader.PIPELINED_REGION_RESTART_STRATEGY_NAME;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
/**
* IT case for fine-grained recovery of batch jobs.
*
* <p>The test activates the region fail-over strategy to restart only failed producers. The test
* job is a sequence of non-parallel mappers. Each mapper writes to a blocking partition which means
* the next mapper starts when the previous is done. The mappers are not chained into one task which
* makes them separate fail-over regions.
*
* <p>The test verifies the fine-grained recovery by injecting, in every mapper, one failure per
* failure strategy after a random record and by comparing the expected number of restarts of each
* mapper against the actual number of restarts. There are multiple failure strategies:
*
* <ul>
* <li>The {@link ExceptionFailureStrategy} throws an exception in the user function code. Since
* all mappers are connected via blocking partitions, which should be re-used on failure, and
* the consumer of the mapper wasn't deployed yet, as the consumed partition was not fully
* produced yet, only the failed mapper should actually restart.
* <li>The {@link TaskExecutorFailureStrategy} abruptly shuts down the task executor. This leads
* to the loss of all previously completed and the in-progress mapper result partitions. The
* fail-over strategy should restart the current in-progress mapper which will get the {@link
* PartitionNotFoundException} because the previous result becomes unavailable and the
* previous mapper has to be restarted as well. The same should happen subsequently with all
* previous mappers. When the source is recomputed, all mappers have to be restarted again to
* recalculate their lost results.
* </ul>
*/
public class BatchFineGrainedRecoveryITCase extends TestLogger {
/** Logger of this test; used to report the failure strategy created for each mapper. */
private static final Logger LOG = LoggerFactory.getLogger(BatchFineGrainedRecoveryITCase.class);
/** Number of records in the generated input sequence of the test job. */
private static final int EMITTED_RECORD_NUMBER = 1000;
/** Number of map partition operators which the test job applies one after another. */
private static final int MAP_NUMBER = 3;
/** Prefix by which the test mappers are recognised among the names of the job vertices. */
private static final String MAP_PARTITION_TEST_PARTITION_MAPPER =
"MapPartition (Test partition mapper ";
/** Matches the full job vertex name of a test mapper; group 1 captures the mapper index. */
private static final Pattern MAPPER_NUMBER_IN_TASK_NAME_PATTERN =
Pattern.compile("MapPartition \\(Test partition mapper (\\d+)\\)");
/**
 * Total number of job failures, summed over all mappers, which are caused by backtracking after
 * the produced partitions got lost.
 *
 * <p>A failure of the mapper with index {@code i} causes {@code i + 1} backtracking failures
 * (partition not found), because all previous producers have to be restarted and each of them
 * at first does not find the result of its predecessor.
 */
private static final int ALL_MAPPERS_BACKTRACK_FAILURES =
        IntStream.rangeClosed(1, MAP_NUMBER).sum();
/**
 * Maximum number of job failures, used as the attempt limit of the fixed-delay restart strategy
 * of the test job.
 *
 * <p>It is the sum of all possible backtracking failures ({@link
 * #ALL_MAPPERS_BACKTRACK_FAILURES}) and the generated failures themselves: one failure per
 * mapper for each of the two failure types.
 */
private static final int MAX_JOB_RESTART_ATTEMPTS =
ALL_MAPPERS_BACKTRACK_FAILURES + 2 * MAP_NUMBER;
/**
 * Expected attempt number for each mapper, indexed by the position of the mapper in the job.
 *
 * <p>The value is the sum of the restarts caused by the exception failure of the mapper itself
 * and of the restarts caused by the task manager failures of this and of all subsequent mappers.
 */
private static final int[] EXPECTED_MAP_ATTEMPT_NUMBERS =
        IntStream.range(0, MAP_NUMBER)
                .map(
                        mapperIndex -> {
                            // exception failure: this mapper restarts once
                            final int afterOwnException = 1;
                            // TM failure of each subsequent mapper: this mapper is restarted
                            // after the PartitionNotFoundException
                            final int afterSubsequentMappersLostResults =
                                    MAP_NUMBER - mapperIndex - 1;
                            // TM failure of this mapper: it restarts once ...
                            final int afterOwnTaskManagerFailure = 1;
                            // ... and once more after the PartitionNotFoundException
                            final int afterOwnPartitionNotFound = 1;
                            return afterOwnException
                                    + afterSubsequentMappersLostResults
                                    + afterOwnTaskManagerFailure
                                    + afterOwnPartitionNotFound;
                        })
                .toArray();
/** Prefix of the name given to every test mapper; the mapper index is appended to it. */
private static final String TASK_NAME_PREFIX = "Test partition mapper ";

/**
 * Expected result of the job: every mapper adds one to each record, so the generated sequence
 * arrives shifted by {@link #MAP_NUMBER}.
 */
private static final List<Long> EXPECTED_JOB_OUTPUT =
        LongStream.rangeClosed(MAP_NUMBER, EMITTED_RECORD_NUMBER + MAP_NUMBER - 1)
                .boxed()
                .collect(Collectors.toList());
/**
 * Mini cluster shared by the test class. It is started with one task manager offering one slot
 * and with the configuration returned by {@link #createConfiguration()}.
 */
@ClassRule
public static final MiniClusterResource MINI_CLUSTER_RESOURCE =
new MiniClusterResource(
new MiniClusterResourceConfiguration.Builder()
.setConfiguration(createConfiguration())
.setNumberTaskManagers(1)
.setNumberSlotsPerTaskManager(1)
.build());
// The fields below are static because they are also accessed from the static helpers and from
// the failure strategies; they are (re-)assigned in setup() before each test.
/** Mini cluster taken from {@link #MINI_CLUSTER_RESOURCE}. */
private static MiniCluster miniCluster;
/** REST client used to query the attempt numbers of the tasks; closed in teardown(). */
private static MiniClusterClient client;
/** Index of the task manager which is terminated by the next task manager restart. */
private static AtomicInteger lastTaskManagerIndexInMiniCluster;
/** Source of the random record counts after which the failures are generated. */
private static final Random rnd = new Random();
/** Tracks which failure strategies have already failed which mapper. */
private static GlobalMapFailureTracker failureTracker;
/**
 * Initialises the static test state: the mini cluster reference, the REST client, the index of
 * the next task manager to terminate and the failure tracker.
 *
 * @throws Exception if the REST client cannot be created
 */
@SuppressWarnings("OverlyBroadThrowsClause")
@Before
public void setup() throws Exception {
    final MiniCluster sharedCluster = MINI_CLUSTER_RESOURCE.getMiniCluster();
    miniCluster = sharedCluster;
    client = new MiniClusterClient(sharedCluster);
    lastTaskManagerIndexInMiniCluster = new AtomicInteger();
    failureTracker = new GlobalMapFailureTracker(MAP_NUMBER);
}
/**
 * Closes the REST client if one has been created by {@link #setup()}.
 *
 * @throws Exception if closing the client fails
 */
@After
public void teardown() throws Exception {
    if (client == null) {
        return;
    }
    client.close();
}
/**
 * Runs the job consisting of {@link #MAP_NUMBER} consecutive failing mappers and checks both the
 * job output and the attempt number of every mapper.
 *
 * @throws Exception if the job execution fails
 */
@Test
public void testProgram() throws Exception {
    final ExecutionEnvironment env = createExecutionEnvironment();
    DataSet<Long> pipeline = env.generateSequence(0, EMITTED_RECORD_NUMBER - 1);

    int mapperIndex = 0;
    while (mapperIndex < MAP_NUMBER) {
        final FailureStrategy strategy = createFailureStrategy(mapperIndex);
        final TestPartitionMapper mapper = new TestPartitionMapper(mapperIndex, strategy);
        pipeline = pipeline.mapPartition(mapper).name(TASK_NAME_PREFIX + mapperIndex);
        mapperIndex++;
    }

    final List<Long> actualOutput = pipeline.collect();
    assertThat(actualOutput, is(EXPECTED_JOB_OUTPUT));
    failureTracker.verify(getMapperAttempts());
}
/**
 * Creates the cluster configuration which selects the pipelined region fail-over strategy.
 *
 * @return configuration for the mini cluster
 */
private static Configuration createConfiguration() {
    final Configuration regionFailoverConfig = new Configuration();
    regionFailoverConfig.setString(
            JobManagerOptions.EXECUTION_FAILOVER_STRATEGY,
            PIPELINED_REGION_RESTART_STRATEGY_NAME);
    return regionFailoverConfig;
}
/**
 * Builds the failure strategy of one mapper: an exception failure and a task executor failure,
 * each after a random number of processed records and each tracked globally.
 *
 * @param trackingIndex index of the mapper, used for logging only
 * @return the combined strategy which fails at most once during one mapper run
 */
private static FailureStrategy createFailureStrategy(int trackingIndex) {
    final int exceptionAfterRecords = rnd.nextInt(EMITTED_RECORD_NUMBER) + 1;
    final int taskExecutorLossAfterRecords = rnd.nextInt(EMITTED_RECORD_NUMBER) + 1;

    final FailureStrategy trackedException =
            new GloballyTrackingFailureStrategy(
                    new ExceptionFailureStrategy(exceptionAfterRecords));
    final FailureStrategy trackedTaskExecutorLoss =
            new GloballyTrackingFailureStrategy(
                    new TaskExecutorFailureStrategy(taskExecutorLossAfterRecords));

    // it has to fail only once during one mapper run so that different failure strategies do
    // not mess up each other stats
    final FailureStrategy failingOncePerRun =
            new OneTimeFailureStrategy(
                    new JoinedFailureStrategy(trackedException, trackedTaskExecutorLoss));

    LOG.info("FailureStrategy for the mapper {}: {}", trackingIndex, failingOncePerRun);
    return failingOncePerRun;
}
/**
 * Creates the execution environment of the test job on top of the mini cluster. The job may
 * restart up to {@link #MAX_JOB_RESTART_ATTEMPTS} times with a delay of 10 ms.
 *
 * @return the configured execution environment
 */
private static ExecutionEnvironment createExecutionEnvironment() {
    @SuppressWarnings("StaticVariableUsedBeforeInitialization")
    final ExecutionEnvironment batchEnv = new TestEnvironment(miniCluster, 1, true);

    final Time restartDelay = Time.milliseconds(10);
    batchEnv.setRestartStrategy(
            RestartStrategies.fixedDelayRestart(MAX_JOB_RESTART_ATTEMPTS, restartDelay));

    // forces all partitions to be blocking
    batchEnv.getConfig().setExecutionMode(ExecutionMode.BATCH_FORCED);
    return batchEnv;
}
/**
 * Terminates the current task manager of the mini cluster, waits for the termination and starts
 * a new task manager. The new task manager is started even if the termination fails.
 *
 * @throws Exception if terminating or starting the task manager fails
 */
@SuppressWarnings({"StaticVariableUsedBeforeInitialization", "OverlyBroadThrowsClause"})
private static void restartTaskManager() throws Exception {
    final int taskManagerToStop = lastTaskManagerIndexInMiniCluster.getAndIncrement();
    try {
        final CompletableFuture<?> termination =
                miniCluster.terminateTaskManager(taskManagerToStop);
        termination.get();
    } finally {
        miniCluster.startTaskManager();
    }
}
/**
 * Queries the cluster for the attempt numbers of the test mappers.
 *
 * @return attempt number of every mapper, indexed by the mapper number parsed from its task name
 */
private static int[] getMapperAttempts() {
    final int[] attemptsByMapper = new int[MAP_NUMBER];
    //noinspection StaticVariableUsedBeforeInitialization
    for (InternalTaskInfo task : client.getInternalTaskInfos()) {
        if (task.name.startsWith(MAP_PARTITION_TEST_PARTITION_MAPPER)) {
            attemptsByMapper[parseMapperNumberFromTaskName(task.name)] = task.attempt;
        }
    }
    return attemptsByMapper;
}
/**
 * Extracts the mapper number from the name of a test mapper task.
 *
 * @param name full task name of a test mapper
 * @return the mapper number contained in the name
 * @throws FlinkRuntimeException if the name does not match the expected pattern
 */
private static int parseMapperNumberFromTaskName(String name) {
    final Matcher nameMatcher = MAPPER_NUMBER_IN_TASK_NAME_PATTERN.matcher(name);
    if (!nameMatcher.matches()) {
        throw new FlinkRuntimeException(
                "Failed to find mapper number in its task name: " + name);
    }
    return Integer.parseInt(nameMatcher.group(1));
}
/**
 * Serializable strategy which decides whether a mapper task has to fail. The test mapper invokes
 * it once for every record before the record is emitted.
 */
@FunctionalInterface
private interface FailureStrategy extends Serializable {
/**
* Decides whether to fail and fails the task implicitly or by throwing an exception.
*
* @param trackingIndex index of the mapper task in the sequence
* @return {@code true} if task is failed implicitly or {@code false} if task is not failed
* @throws Exception To fail the task explicitly
*/
boolean failOrNot(int trackingIndex) throws Exception;
}
/**
 * Decorator which lets the wrapped strategy fail at most once per instance. The flag is
 * transient, so a deserialized copy starts again in the "not failed" state.
 */
private static class OneTimeFailureStrategy implements FailureStrategy {
    private static final long serialVersionUID = 1L;

    private final FailureStrategy wrappedFailureStrategy;
    private transient boolean failed;

    private OneTimeFailureStrategy(FailureStrategy wrappedFailureStrategy) {
        this.wrappedFailureStrategy = wrappedFailureStrategy;
    }

    /**
     * Delegates to the wrapped strategy until it has failed once; afterwards never fails.
     *
     * @param trackingIndex index of the mapper task in the sequence
     * @return whether the wrapped strategy failed the task implicitly during this call
     * @throws Exception the exception of the wrapped strategy, which also counts as a failure
     */
    @Override
    public boolean failOrNot(int trackingIndex) throws Exception {
        if (failed) {
            return false;
        }
        boolean failedNow;
        try {
            failedNow = wrappedFailureStrategy.failOrNot(trackingIndex);
        } catch (Exception e) {
            failed = true;
            throw e;
        }
        failed = failedNow;
        return failedNow;
    }

    @Override
    public String toString() {
        return "FailingOnce{" + wrappedFailureStrategy + '}';
    }
}
/** Composite strategy which asks its member strategies in order and stops at the first failure. */
private static class JoinedFailureStrategy implements FailureStrategy {
    private static final long serialVersionUID = 1L;

    private final FailureStrategy[] failureStrategies;

    private JoinedFailureStrategy(FailureStrategy... failureStrategies) {
        this.failureStrategies = failureStrategies;
    }

    /**
     * Invokes the member strategies one after another.
     *
     * @param trackingIndex index of the mapper task in the sequence
     * @return {@code true} as soon as one member reports a failure, {@code false} if none does
     * @throws Exception the exception of the first member which fails explicitly
     */
    @Override
    public boolean failOrNot(int trackingIndex) throws Exception {
        for (FailureStrategy candidate : failureStrategies) {
            final boolean failedNow = candidate.failOrNot(trackingIndex);
            if (failedNow) {
                return true;
            }
        }
        return false;
    }

    @Override
    public String toString() {
        return Arrays.stream(failureStrategies)
                .map(Object::toString)
                .collect(Collectors.joining(" or "));
    }
}
/**
 * Decorator which routes every call through the static failure tracker of the test, so that the
 * tracker decides whether the wrapped strategy is still consulted for the given mapper.
 */
private static class GloballyTrackingFailureStrategy implements FailureStrategy {
    private static final long serialVersionUID = 1L;

    private final FailureStrategy wrappedFailureStrategy;

    private GloballyTrackingFailureStrategy(FailureStrategy wrappedFailureStrategy) {
        this.wrappedFailureStrategy = wrappedFailureStrategy;
    }

    @Override
    public boolean failOrNot(int trackingIndex) throws Exception {
        final GlobalMapFailureTracker tracker = failureTracker;
        return tracker.failOrNot(trackingIndex, wrappedFailureStrategy);
    }

    @Override
    public String toString() {
        return "Tracked{" + wrappedFailureStrategy + '}';
    }
}
/** Failure strategy which fails the task explicitly by throwing a {@link FlinkException}. */
private static class ExceptionFailureStrategy
        extends AbstractOnceAfterCallNumberFailureStrategy {
    private static final long serialVersionUID = 1L;

    private ExceptionFailureStrategy(int failAfterCallNumber) {
        super(failAfterCallNumber);
    }

    /**
     * Always throws the generated test failure.
     *
     * @param trackingIndex index of the mapper task in the sequence (not used)
     * @throws FlinkException always
     */
    @Override
    void fail(int trackingIndex) throws FlinkException {
        final String message = "BAGA-BOOM!!! The user function generated test failure.";
        throw new FlinkException(message);
    }
}
/**
 * Failure strategy which fails the task implicitly by restarting the task manager of the mini
 * cluster.
 */
private static class TaskExecutorFailureStrategy
extends AbstractOnceAfterCallNumberFailureStrategy {
private static final long serialVersionUID = 1L;
private TaskExecutorFailureStrategy(int failAfterCallNumber) {
super(failAfterCallNumber);
}
/**
 * Restarts the task manager. An interruption during the restart is not treated as an error; any
 * other failure is reported to the failure tracker as unrelated to the test and rethrown.
 *
 * @param trackingIndex index of the mapper task in the sequence (not used)
 * @throws Exception if restarting the task manager fails unexpectedly
 */
@Override
void fail(int trackingIndex) throws Exception {
// NOTE(review): the restart runs inside a TemporaryClassLoaderContext for the system class
// loader, presumably so that the restart does not run under the task's own context class
// loader - confirm the reason. The inner try is kept separate so that its handlers run
// before the context is closed.
//noinspection OverlyBroadCatchBlock
try (TemporaryClassLoaderContext unused =
TemporaryClassLoaderContext.of(ClassLoader.getSystemClassLoader())) {
try {
restartTaskManager();
} catch (InterruptedException e) {
// ignore the exception, task should have been failed while stopping TM
Thread.currentThread().interrupt();
} catch (Throwable t) {
// anything else is not part of the test scenario: record it for verify() and rethrow
failureTracker.unrelatedFailure(t);
throw t;
}
}
}
}
/**
 * Base class of the failure strategies which fail exactly at the call whose ordinal equals the
 * configured call number. The call counter is transient, so a deserialized copy counts from zero
 * again. Instances are identified by a random id which is the only input of {@code equals} and
 * {@code hashCode}.
 */
private abstract static class AbstractOnceAfterCallNumberFailureStrategy
        implements FailureStrategy {
    private static final long serialVersionUID = 1L;

    private final UUID id;
    private final int failAfterCallNumber;
    private transient int callCounter;

    private AbstractOnceAfterCallNumberFailureStrategy(int failAfterCallNumber) {
        this.id = UUID.randomUUID();
        this.failAfterCallNumber = failAfterCallNumber;
    }

    /**
     * Counts the call and triggers {@link #fail(int)} when the configured call number is reached.
     *
     * @param trackingIndex index of the mapper task in the sequence
     * @return {@code true} if the failure was generated by this call and {@code fail} returned
     * @throws Exception the exception thrown by {@link #fail(int)}
     */
    @Override
    public boolean failOrNot(int trackingIndex) throws Exception {
        callCounter++;
        if (callCounter != failAfterCallNumber) {
            return false;
        }
        fail(trackingIndex);
        return true;
    }

    /**
     * Generates the failure.
     *
     * @param trackingIndex index of the mapper task in the sequence
     * @throws Exception to fail the task explicitly
     */
    abstract void fail(int trackingIndex) throws Exception;

    @Override
    public String toString() {
        final String strategyName = this.getClass().getSimpleName();
        return strategyName + " (fail after " + failAfterCallNumber + " calls)";
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o != null && o.getClass() == getClass()) {
            final AbstractOnceAfterCallNumberFailureStrategy other =
                    (AbstractOnceAfterCallNumberFailureStrategy) o;
            return Objects.equals(id, other.id);
        }
        return false;
    }

    @Override
    public int hashCode() {
        return id.hashCode();
    }
}
/**
 * Remembers, per mapper, which failure strategies have already failed it, so that a strategy is
 * not consulted again for the same mapper. It also collects failures which are not part of the
 * test scenario and reports them during verification.
 */
private static class GlobalMapFailureTracker {
    private final List<Set<FailureStrategy>> mapFailures;
    private final Object classLock = new Object();

    @GuardedBy("classLock")
    private Throwable unexpectedFailure;

    private GlobalMapFailureTracker(int numberOfMappers) {
        mapFailures = new ArrayList<>(numberOfMappers);
        for (int mapper = 0; mapper < numberOfMappers; mapper++) {
            addNewMapper();
        }
    }

    /** Registers one more mapper and returns its index. */
    private int addNewMapper() {
        final int newMapperIndex = mapFailures.size();
        mapFailures.add(new HashSet<>(2));
        return newMapperIndex;
    }

    /**
     * Consults the strategy unless it has already failed the given mapper, and records the
     * strategy if it fails now, either implicitly or by throwing an exception.
     *
     * @param index index of the mapper
     * @param failureStrategy strategy to consult
     * @return {@code true} if the strategy failed the mapper implicitly during this call
     * @throws Exception the exception of the strategy
     */
    private boolean failOrNot(int index, FailureStrategy failureStrategy) throws Exception {
        final Set<FailureStrategy> alreadyFailedStrategies = mapFailures.get(index);
        if (alreadyFailedStrategies.contains(failureStrategy)) {
            return false;
        }
        boolean failedNow;
        try {
            failedNow = failureStrategy.failOrNot(index);
        } catch (Exception e) {
            alreadyFailedStrategies.add(failureStrategy);
            throw e;
        }
        if (failedNow) {
            alreadyFailedStrategies.add(failureStrategy);
        }
        return failedNow;
    }

    /** Records a failure which is not part of the test scenario. */
    private void unrelatedFailure(Throwable failure) {
        synchronized (classLock) {
            unexpectedFailure = ExceptionUtils.firstOrSuppressed(failure, unexpectedFailure);
        }
    }

    /**
     * Fails if an unrelated failure was recorded, otherwise compares the given attempt numbers
     * with the expected ones.
     *
     * @param mapAttemptNumbers actual attempt number of every mapper
     */
    private void verify(int[] mapAttemptNumbers) {
        final Throwable recordedFailure;
        synchronized (classLock) {
            recordedFailure = unexpectedFailure;
        }
        if (recordedFailure != null) {
            throw new AssertionError(
                    "Test failed due to unexpected exception.", recordedFailure);
        }
        assertThat(mapAttemptNumbers, is(EXPECTED_MAP_ATTEMPT_NUMBERS));
    }
}
/**
 * Mapper of the test job. For every record it first consults its failure strategy and then emits
 * the record incremented by one.
 */
private static class TestPartitionMapper extends RichMapPartitionFunction<Long, Long> {
    private static final long serialVersionUID = 1L;

    private final int trackingIndex;
    private final FailureStrategy failureStrategy;

    private TestPartitionMapper(int trackingIndex, FailureStrategy failureStrategy) {
        this.trackingIndex = trackingIndex;
        this.failureStrategy = failureStrategy;
    }

    @Override
    public void mapPartition(Iterable<Long> values, Collector<Long> out) throws Exception {
        for (Long record : values) {
            // the returned flag is not needed here: the strategy fails the task itself
            failureStrategy.failOrNot(trackingIndex);
            final long incremented = record + 1;
            out.collect(incremented);
        }
    }
}
/**
 * Small REST client which queries the mini cluster for the execution attempts of all subtasks of
 * all jobs.
 *
 * <p>The client owns the executor which is handed to the {@link RestClient}. The executor is
 * shut down by {@link #close()} and also when the construction of the REST client fails.
 */
private static class MiniClusterClient implements AutoCloseable {
    private final RestClient restClient;
    private final ExecutorService executorService;
    private final URI restAddress;

    /**
     * Creates the client for the REST endpoint of the given mini cluster.
     *
     * @param miniCluster cluster whose REST address is used for all requests
     * @throws ConfigurationException if the REST client cannot be created
     */
    private MiniClusterClient(MiniCluster miniCluster) throws ConfigurationException {
        restAddress = miniCluster.getRestAddress().join();
        executorService =
                Executors.newSingleThreadScheduledExecutor(
                        new ExecutorThreadFactory("Flink-RestClient-IO"));
        try {
            restClient = createRestClient();
        } catch (ConfigurationException | RuntimeException e) {
            // nobody will ever call close() on a half-constructed client, so release the
            // executor thread here
            executorService.shutdownNow();
            throw e;
        }
    }

    /** Creates the REST client with a default configuration on top of the owned executor. */
    private RestClient createRestClient() throws ConfigurationException {
        return new RestClient(
                RestClientConfiguration.fromConfiguration(
                        new UnmodifiableConfiguration(new Configuration())),
                executorService);
    }

    /**
     * Collects the name and the attempt number of every subtask of every job vertex of every job
     * known to the cluster. The requests are sent one after another and are waited for.
     *
     * @return one entry per subtask
     */
    private List<InternalTaskInfo> getInternalTaskInfos() {
        return getJobs().stream()
                .flatMap(
                        jobId ->
                                getJobDetails(jobId).join().getJobVertexInfos().stream()
                                        .map(info -> Tuple2.of(jobId, info)))
                .flatMap(
                        vertexInfoWithJobId ->
                                getJobVertexDetailsInfo(
                                                vertexInfoWithJobId.f0,
                                                vertexInfoWithJobId.f1.getJobVertexID())
                                        .getSubtasks().stream()
                                        .map(
                                                subtask ->
                                                        new InternalTaskInfo(
                                                                vertexInfoWithJobId.f1
                                                                        .getName(),
                                                                subtask)))
                .collect(Collectors.toList());
    }

    /** Requests the ids of all jobs of the cluster and waits for the response. */
    private Collection<JobID> getJobs() {
        JobIdsWithStatusOverview jobIds =
                sendRequest(
                                JobIdsWithStatusesOverviewHeaders.getInstance(),
                                EmptyMessageParameters.getInstance())
                        .join();
        return jobIds.getJobsWithStatus().stream()
                .map(JobIdWithStatus::getJobId)
                .collect(Collectors.toList());
    }

    /** Requests the details of the given job without waiting for the response. */
    private CompletableFuture<JobDetailsInfo> getJobDetails(JobID jobId) {
        JobMessageParameters params = new JobMessageParameters();
        params.jobPathParameter.resolve(jobId);
        return sendRequest(JobDetailsHeaders.getInstance(), params);
    }

    /** Requests the details of the given job vertex and waits for the response. */
    private JobVertexDetailsInfo getJobVertexDetailsInfo(JobID jobId, JobVertexID jobVertexID) {
        JobVertexDetailsHeaders detailsHeaders = JobVertexDetailsHeaders.getInstance();
        JobVertexMessageParameters params = new JobVertexMessageParameters();
        params.jobPathParameter.resolve(jobId);
        params.jobVertexIdPathParameter.resolve(jobVertexID);
        return sendRequest(detailsHeaders, params).join();
    }

    /**
     * Sends a request with an empty body to the REST endpoint of the cluster.
     *
     * @param messageHeaders headers describing the request
     * @param messageParameters resolved parameters of the request
     * @return future of the response; completed exceptionally if sending fails with an {@link
     *     IOException}
     */
    private <
                    M extends MessageHeaders<EmptyRequestBody, P, U>,
                    U extends MessageParameters,
                    P extends ResponseBody>
            CompletableFuture<P> sendRequest(M messageHeaders, U messageParameters) {
        try {
            return restClient.sendRequest(
                    restAddress.getHost(),
                    restAddress.getPort(),
                    messageHeaders,
                    messageParameters,
                    EmptyRequestBody.getInstance());
        } catch (IOException e) {
            return FutureUtils.completedExceptionally(e);
        }
    }

    /**
     * Closes the REST client and shuts down the owned executor. The executor is shut down even if
     * closing the REST client fails, so that its thread does not outlive the test.
     *
     * @throws Exception if closing the REST client fails
     */
    @Override
    public void close() throws Exception {
        try {
            restClient.close();
        } finally {
            executorService.shutdownNow();
        }
    }
}
/** Name of a job vertex together with the attempt number of one of its subtasks. */
private static class InternalTaskInfo {
    private final String name;
    private final int attempt;

    private InternalTaskInfo(String name, SubtaskExecutionAttemptDetailsInfo vertexTaskDetail) {
        final int reportedAttempt = vertexTaskDetail.getAttempt();
        this.name = name;
        this.attempt = reportedAttempt;
    }

    @Override
    public String toString() {
        return new StringBuilder()
                .append(name)
                .append(" (Attempt #")
                .append(attempt)
                .append(')')
                .toString();
    }
}
}