| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.flink.test.checkpointing; |
| |
| import org.apache.flink.api.common.JobExecutionResult; |
| import org.apache.flink.api.common.JobSubmissionResult; |
| import org.apache.flink.api.common.accumulators.IntCounter; |
| import org.apache.flink.api.common.accumulators.LongCounter; |
| import org.apache.flink.api.common.functions.FilterFunction; |
| import org.apache.flink.api.common.functions.Partitioner; |
| import org.apache.flink.api.common.functions.RichMapFunction; |
| import org.apache.flink.api.common.restartstrategy.RestartStrategies; |
| import org.apache.flink.api.common.state.CheckpointListener; |
| import org.apache.flink.api.common.state.ListState; |
| import org.apache.flink.api.common.state.ListStateDescriptor; |
| import org.apache.flink.api.common.time.Deadline; |
| import org.apache.flink.api.common.time.Time; |
| import org.apache.flink.api.connector.source.Boundedness; |
| import org.apache.flink.api.connector.source.ReaderOutput; |
| import org.apache.flink.api.connector.source.Source; |
| import org.apache.flink.api.connector.source.SourceEvent; |
| import org.apache.flink.api.connector.source.SourceReader; |
| import org.apache.flink.api.connector.source.SourceReaderContext; |
| import org.apache.flink.api.connector.source.SourceSplit; |
| import org.apache.flink.api.connector.source.SplitEnumerator; |
| import org.apache.flink.api.connector.source.SplitEnumeratorContext; |
| import org.apache.flink.api.connector.source.SplitsAssignment; |
| import org.apache.flink.configuration.AkkaOptions; |
| import org.apache.flink.configuration.CheckpointingOptions; |
| import org.apache.flink.configuration.Configuration; |
| import org.apache.flink.configuration.MemorySize; |
| import org.apache.flink.configuration.NettyShuffleEnvironmentOptions; |
| import org.apache.flink.configuration.StateBackendOptions; |
| import org.apache.flink.configuration.TaskManagerOptions; |
| import org.apache.flink.core.io.InputStatus; |
| import org.apache.flink.core.io.SimpleVersionedSerializer; |
| import org.apache.flink.runtime.jobgraph.SavepointConfigOptions; |
| import org.apache.flink.runtime.state.FunctionInitializationContext; |
| import org.apache.flink.runtime.state.FunctionSnapshotContext; |
| import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration; |
| import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; |
| import org.apache.flink.streaming.api.environment.CheckpointConfig; |
| import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; |
| import org.apache.flink.streaming.api.functions.co.RichCoFlatMapFunction; |
| import org.apache.flink.streaming.api.functions.sink.RichSinkFunction; |
| import org.apache.flink.streaming.api.graph.StreamGraph; |
| import org.apache.flink.test.util.MiniClusterWithClientResource; |
| import org.apache.flink.testutils.junit.FailsWithAdaptiveScheduler; |
| import org.apache.flink.util.Collector; |
| import org.apache.flink.util.ExceptionUtils; |
| import org.apache.flink.util.TestLogger; |
| import org.apache.flink.util.concurrent.FutureUtils; |
| |
| import org.apache.flink.shaded.guava30.com.google.common.collect.Iterables; |
| import org.apache.flink.shaded.netty4.io.netty.util.internal.PlatformDependent; |
| |
| import org.junit.Rule; |
| import org.junit.experimental.categories.Category; |
| import org.junit.rules.ErrorCollector; |
| import org.junit.rules.TemporaryFolder; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| import javax.annotation.Nullable; |
| |
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.attribute.BasicFileAttributes;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
| |
| import static org.apache.flink.runtime.state.filesystem.AbstractFsCheckpointStorageAccess.CHECKPOINT_DIR_PREFIX; |
| import static org.apache.flink.runtime.state.filesystem.AbstractFsCheckpointStorageAccess.METADATA_FILE_NAME; |
| import static org.apache.flink.shaded.guava30.com.google.common.collect.Iterables.getOnlyElement; |
| import static org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions.CHECKPOINTING_TIMEOUT; |
| import static org.apache.flink.util.Preconditions.checkState; |
| |
| /** Base class for tests related to unaligned checkpoints. */ |
| @Category(FailsWithAdaptiveScheduler.class) // FLINK-21689 |
| public abstract class UnalignedCheckpointTestBase extends TestLogger { |
    protected static final Logger LOG = LoggerFactory.getLogger(UnalignedCheckpointTestBase.class);
    // Accumulator names used to transport verification counts out of the job.
    // NOTE(review): NUM_INPUTS carries a trailing underscore unlike the other names —
    // presumably used as a per-source prefix; confirm against the accumulator registration sites.
    protected static final String NUM_INPUTS = "inputs_";
    protected static final String NUM_OUTPUTS = "outputs";
    protected static final String NUM_OUT_OF_ORDER = "outOfOrder";
    protected static final String NUM_FAILURES = "failures";
    protected static final String NUM_DUPLICATES = "duplicates";
    protected static final String NUM_LOST = "lost";
    protected static final int BUFFER_PER_CHANNEL = 1;
    /** For multi-gate tests. */
    protected static final int NUM_SOURCES = 3;

    // Marker placed in the upper 32 bits of every emitted long; presumably stripped by
    // withoutHeader(...) (defined outside this view) to detect corrupted records — verify there.
    private static final long HEADER = 0xABCDEAFCL << 32;
    private static final long HEADER_MASK = 0xFFFFFFFFL << 32;

    // Fresh checkpoint directory per test run.
    @Rule public final TemporaryFolder temp = new TemporaryFolder();

    @Rule public ErrorCollector collector = new ErrorCollector();
| |
| @Nullable |
| protected File execute(UnalignedSettings settings) throws Exception { |
| final File checkpointDir = temp.newFolder(); |
| Configuration conf = settings.getConfiguration(checkpointDir); |
| |
| final StreamGraph streamGraph = getStreamGraph(settings, conf); |
| final int requiredSlots = |
| streamGraph.getStreamNodes().stream() |
| .mapToInt(node -> node.getParallelism()) |
| .reduce(0, settings.channelType.slotSharing ? Integer::max : Integer::sum); |
| int numberTaskmanagers = settings.channelType.slotsToTaskManagers.apply(requiredSlots); |
| final int slotsPerTM = (requiredSlots + numberTaskmanagers - 1) / numberTaskmanagers; |
| final MiniClusterWithClientResource miniCluster = |
| new MiniClusterWithClientResource( |
| new MiniClusterResourceConfiguration.Builder() |
| .setConfiguration(conf) |
| .setNumberTaskManagers(numberTaskmanagers) |
| .setNumberSlotsPerTaskManager(slotsPerTM) |
| .build()); |
| miniCluster.before(); |
| final StreamExecutionEnvironment env = |
| StreamExecutionEnvironment.getExecutionEnvironment(conf); |
| settings.configure(env); |
| try { |
| waitForCleanShutdown(); |
| final CompletableFuture<JobSubmissionResult> result = |
| miniCluster.getMiniCluster().submitJob(streamGraph.getJobGraph()); |
| |
| checkCounters( |
| miniCluster |
| .getMiniCluster() |
| .requestJobResult(result.get().getJobID()) |
| .get() |
| .toJobExecutionResult(getClass().getClassLoader())); |
| } catch (Exception e) { |
| if (!ExceptionUtils.findThrowable(e, TestException.class).isPresent()) { |
| throw e; |
| } |
| } finally { |
| miniCluster.after(); |
| } |
| if (settings.generateCheckpoint) { |
| return Files.find(checkpointDir.toPath(), 2, this::isCompletedCheckpoint) |
| .max(Comparator.comparing(Path::toString)) |
| .map(Path::toFile) |
| .orElseThrow(() -> new IllegalStateException("Cannot generate checkpoint")); |
| } |
| return null; |
| } |
| |
| private boolean isCompletedCheckpoint(Path path, BasicFileAttributes attr) { |
| return attr.isDirectory() |
| && path.getFileName().toString().startsWith(CHECKPOINT_DIR_PREFIX) |
| && hasMetadata(path); |
| } |
| |
| private boolean hasMetadata(Path file) { |
| try { |
| return Files.find( |
| file.toAbsolutePath(), |
| 1, |
| (path, attrs) -> |
| path.getFileName().toString().equals(METADATA_FILE_NAME)) |
| .findAny() |
| .isPresent(); |
| } catch (IOException e) { |
| ExceptionUtils.rethrow(e); |
| return false; // should never happen |
| } |
| } |
| |
| private StreamGraph getStreamGraph(UnalignedSettings settings, Configuration conf) { |
| // a dummy environment used to retrieve the DAG, mini cluster will be used later |
| final StreamExecutionEnvironment setupEnv = |
| StreamExecutionEnvironment.createLocalEnvironment(conf); |
| settings.configure(setupEnv); |
| |
| settings.dagCreator.create( |
| setupEnv, |
| settings.minCheckpoints, |
| settings.channelType.slotSharing, |
| settings.expectedFailures - settings.failuresAfterSourceFinishes); |
| |
| return setupEnv.getStreamGraph(); |
| } |
| |
| private void waitForCleanShutdown() throws InterruptedException { |
| // direct memory in netty will be freed through gc/finalization |
| // too many successive executions will lead to OOM by netty |
| // slow down when half the memory is taken and wait for gc |
| if (PlatformDependent.usedDirectMemory() > PlatformDependent.maxDirectMemory() / 2) { |
| final Duration waitTime = Duration.ofSeconds(10); |
| Deadline deadline = Deadline.fromNow(waitTime); |
| while (PlatformDependent.usedDirectMemory() > 0 && deadline.hasTimeLeft()) { |
| System.gc(); |
| Thread.sleep(100); |
| } |
| final Duration timeLeft = deadline.timeLeft(); |
| if (timeLeft.isNegative()) { |
| LOG.warn( |
| "Waited 10s for clean shutdown of previous runs but there is still direct memory in use: " |
| + PlatformDependent.usedDirectMemory()); |
| } else { |
| LOG.info( |
| "Needed to wait {} ms for full cleanup of previous runs.", |
| waitTime.minus(timeLeft).toMillis()); |
| } |
| } |
| } |
| |
    /** Verifies the accumulator values collected in the given job result; test-specific. */
    protected abstract void checkCounters(JobExecutionResult result);
| |
| /** A source that generates longs in a fixed number of splits. */ |
| protected static class LongSource |
| implements Source<Long, LongSource.LongSplit, LongSource.EnumeratorState> { |
| private final int minCheckpoints; |
| private final int numSplits; |
| private final int expectedRestarts; |
| private final long checkpointingInterval; |
| |
| protected LongSource( |
| int minCheckpoints, |
| int numSplits, |
| int expectedRestarts, |
| long checkpointingInterval) { |
| this.minCheckpoints = minCheckpoints; |
| this.numSplits = numSplits; |
| this.expectedRestarts = expectedRestarts; |
| this.checkpointingInterval = checkpointingInterval; |
| } |
| |
| @Override |
| public Boundedness getBoundedness() { |
| return Boundedness.CONTINUOUS_UNBOUNDED; |
| } |
| |
| @Override |
| public SourceReader<Long, LongSplit> createReader(SourceReaderContext readerContext) { |
| return new LongSourceReader( |
| readerContext.getIndexOfSubtask(), |
| minCheckpoints, |
| expectedRestarts, |
| checkpointingInterval); |
| } |
| |
| @Override |
| public SplitEnumerator<LongSplit, EnumeratorState> createEnumerator( |
| SplitEnumeratorContext<LongSplit> enumContext) { |
| List<LongSplit> splits = |
| IntStream.range(0, numSplits) |
| .mapToObj(i -> new LongSplit(i, numSplits)) |
| .collect(Collectors.toList()); |
| return new LongSplitSplitEnumerator(enumContext, new EnumeratorState(splits, 0, 0)); |
| } |
| |
| @Override |
| public SplitEnumerator<LongSplit, EnumeratorState> restoreEnumerator( |
| SplitEnumeratorContext<LongSplit> enumContext, EnumeratorState state) { |
| return new LongSplitSplitEnumerator(enumContext, state); |
| } |
| |
| @Override |
| public SimpleVersionedSerializer<LongSplit> getSplitSerializer() { |
| return new SplitVersionedSerializer(); |
| } |
| |
| @Override |
| public SimpleVersionedSerializer<EnumeratorState> getEnumeratorCheckpointSerializer() { |
| return new EnumeratorVersionedSerializer(); |
| } |
| |
| private static class LongSourceReader implements SourceReader<Long, LongSplit> { |
| private final int subtaskIndex; |
| private final int minCheckpoints; |
| private final int expectedRestarts; |
| private final LongCounter numInputsCounter = new LongCounter(); |
| private final List<LongSplit> splits = new ArrayList<>(); |
| private final Duration pumpInterval; |
| private int numAbortedCheckpoints; |
| private int numRestarts; |
| private int numCompletedCheckpoints; |
| private boolean finishing; |
| private boolean recovered; |
| @Nullable private Deadline pumpingUntil = null; |
| |
| public LongSourceReader( |
| int subtaskIndex, |
| int minCheckpoints, |
| int expectedRestarts, |
| long checkpointingInterval) { |
| this.subtaskIndex = subtaskIndex; |
| this.minCheckpoints = minCheckpoints; |
| this.expectedRestarts = expectedRestarts; |
| pumpInterval = Duration.ofMillis(checkpointingInterval); |
| } |
| |
| @Override |
| public void start() {} |
| |
            @Override
            public InputStatus pollNext(ReaderOutput<Long> output) throws InterruptedException {
                // Emit one record per assigned split and advance each split's cursor.
                for (LongSplit split : splits) {
                    output.collect(withHeader(split.nextNumber), split.nextNumber);
                    split.nextNumber += split.increment;
                }

                if (finishing) {
                    return InputStatus.END_OF_INPUT;
                }

                // Leave pumping mode once the deadline has passed.
                if (pumpingUntil != null && pumpingUntil.isOverdue()) {
                    pumpingUntil = null;
                }
                // Throttle unless pumping; pumping floods the pipeline on purpose to create
                // backpressure (see updatePollingState).
                if (pumpingUntil == null) {
                    Thread.sleep(1);
                }
                return InputStatus.MORE_AVAILABLE;
            }
| |
| @Override |
| public List<LongSplit> snapshotState(long checkpointId) { |
| LOG.info( |
| "Snapshotted {} @ {} subtask ({} attempt)", |
| splits, |
| subtaskIndex, |
| numRestarts); |
| // barrier passed, so no need to add more data for this test |
| pumpingUntil = null; |
| return splits; |
| } |
| |
| @Override |
| public void notifyCheckpointComplete(long checkpointId) { |
| LOG.info( |
| "notifyCheckpointComplete {} @ {} subtask ({} attempt)", |
| numCompletedCheckpoints, |
| subtaskIndex, |
| numRestarts); |
| // Update polling state before final checkpoint such that if there is an issue |
| // during finishing, after recovery the source immediately starts finishing |
| // again. In this way, we avoid a deadlock where some tasks need another |
| // checkpoint completed, while some tasks are finishing (and thus there are no |
| // new checkpoint). |
| updatePollingState(); |
| numCompletedCheckpoints++; |
| recovered = true; |
| numAbortedCheckpoints = 0; |
| } |
| |
| @Override |
| public void notifyCheckpointAborted(long checkpointId) { |
| if (numAbortedCheckpoints++ > 100) { |
| // aborted too many checkpoints in a row, which usually indicates that part of |
| // the pipeline is already completed |
| // here simply also advance completed checkpoints to avoid running into a live |
| // lock |
| numCompletedCheckpoints = minCheckpoints + 1; |
| } |
| updatePollingState(); |
| } |
| |
| @Override |
| public CompletableFuture<Void> isAvailable() { |
| return FutureUtils.completedVoidFuture(); |
| } |
| |
| @Override |
| public void addSplits(List<LongSplit> splits) { |
| this.splits.addAll(splits); |
| updatePollingState(); |
| LOG.info( |
| "Added splits {}, finishing={}, pumping until {} @ {} subtask ({} attempt)", |
| splits, |
| finishing, |
| pumpingUntil, |
| subtaskIndex, |
| numRestarts); |
| } |
| |
| @Override |
| public void notifyNoMoreSplits() { |
| updatePollingState(); |
| } |
| |
| private void updatePollingState() { |
| if (numCompletedCheckpoints >= minCheckpoints && numRestarts >= expectedRestarts) { |
| finishing = true; |
| LOG.info("Finishing @ {} subtask ({} attempt)", subtaskIndex, numRestarts); |
| } else if (recovered) { |
| // a successful checkpoint as a proxy for a finished recovery |
| // cause backpressure until next checkpoint is added |
| pumpingUntil = Deadline.fromNow(pumpInterval); |
| LOG.info( |
| "Pumping until {} @ {} subtask ({} attempt)", |
| pumpingUntil, |
| subtaskIndex, |
| numRestarts); |
| } |
| } |
| |
| @Override |
| public void handleSourceEvents(SourceEvent sourceEvent) { |
| if (sourceEvent instanceof SyncEvent) { |
| numRestarts = ((SyncEvent) sourceEvent).numRestarts; |
| numCompletedCheckpoints = ((SyncEvent) sourceEvent).numCheckpoints; |
| LOG.info( |
| "Set restarts={}, numCompletedCheckpoints={} @ {} subtask ({} attempt)", |
| numRestarts, |
| numCompletedCheckpoints, |
| subtaskIndex, |
| numRestarts); |
| updatePollingState(); |
| } |
| } |
| |
            @Override
            public void close() throws Exception {
                // nextNumber advances by increment per emitted record, so the quotient is the
                // number of values this split produced.
                for (LongSplit split : splits) {
                    numInputsCounter.add(split.nextNumber / split.increment);
                }
                // NOTE(review): numInputsCounter is not visibly registered with a runtime
                // context in this block — presumably handled elsewhere; verify it actually
                // reaches the job result.
            }
| } |
| |
| private static class SyncEvent implements SourceEvent { |
| final int numRestarts; |
| final int numCheckpoints; |
| |
| SyncEvent(int numRestarts, int numCheckpoints) { |
| this.numRestarts = numRestarts; |
| this.numCheckpoints = numCheckpoints; |
| } |
| } |
| |
        /** A split emitting the arithmetic sequence nextNumber, nextNumber + increment, ... */
        private static class LongSplit implements SourceSplit {
            // Distance between consecutive numbers (== total number of splits).
            private final int increment;
            // Next number to emit; also encodes how many values were emitted so far.
            private long nextNumber;

            public LongSplit(long nextNumber, int increment) {
                this.nextNumber = nextNumber;
                this.increment = increment;
            }

            // Residue class of this split; used as the assignment key in the enumerator.
            public int getBaseNumber() {
                return (int) (nextNumber % increment);
            }

            @Override
            public String splitId() {
                // NOTE(review): all splits share the same increment (numSplits), so this id is
                // not unique across splits — confirm whether unique split ids are required here.
                return String.valueOf(increment);
            }

            @Override
            public String toString() {
                return "LongSplit{" + "increment=" + increment + ", nextNumber=" + nextNumber + '}';
            }
        }
| |
| private static class LongSplitSplitEnumerator |
| implements SplitEnumerator<LongSplit, EnumeratorState> { |
| private final SplitEnumeratorContext<LongSplit> context; |
| private final EnumeratorState state; |
| private final Map<Integer, Integer> subtaskRestarts = new HashMap<>(); |
| |
| private LongSplitSplitEnumerator( |
| SplitEnumeratorContext<LongSplit> context, EnumeratorState state) { |
| this.context = context; |
| this.state = state; |
| } |
| |
| @Override |
| public void start() {} |
| |
| @Override |
| public void handleSplitRequest(int subtaskId, @Nullable String requesterHostname) {} |
| |
| @Override |
| public void addSplitsBack(List<LongSplit> splits, int subtaskId) { |
| LOG.info("addSplitsBack {}", splits); |
| // Called on recovery |
| subtaskRestarts.compute( |
| subtaskId, |
| (id, oldCount) -> oldCount == null ? state.numRestarts + 1 : oldCount + 1); |
| state.unassignedSplits.addAll(splits); |
| } |
| |
            @Override
            public void addReader(int subtaskId) {
                // Wait until all subtasks are registered, then assign all splits at once so
                // every reader starts with a consistent view.
                if (context.registeredReaders().size() == context.currentParallelism()) {
                    if (!state.unassignedSplits.isEmpty()) {
                        // Group splits by their residue class; the group key doubles as the
                        // subtask index receiving the group.
                        Map<Integer, List<LongSplit>> assignment =
                                state.unassignedSplits.stream()
                                        .collect(Collectors.groupingBy(LongSplit::getBaseNumber));
                        LOG.info("Assigning splits {}", assignment);
                        context.assignSplits(new SplitsAssignment<>(assignment));
                        state.unassignedSplits.clear();
                    }
                    context.registeredReaders().keySet().forEach(context::signalNoMoreSplits);
                    // Propagate the highest observed restart count to all readers so they agree
                    // on when the expected number of restarts has been reached.
                    Optional<Integer> restarts =
                            subtaskRestarts.values().stream().max(Comparator.naturalOrder());
                    if (restarts.isPresent() && restarts.get() > state.numRestarts) {
                        state.numRestarts = restarts.get();
                        // Implicitly sync the restart count of all subtasks with state.numRestarts
                        subtaskRestarts.clear();
                        final SyncEvent event =
                                new SyncEvent(state.numRestarts, state.numCompletedCheckpoints);
                        context.registeredReaders()
                                .keySet()
                                .forEach(index -> context.sendEventToSourceReader(index, event));
                    }
                }
            }
| |
| @Override |
| public void notifyCheckpointComplete(long checkpointId) { |
| state.numCompletedCheckpoints++; |
| } |
| |
| @Override |
| public EnumeratorState snapshotState(long checkpointId) throws Exception { |
| LOG.info("snapshotState {}", state); |
| return state; |
| } |
| |
| @Override |
| public void close() throws IOException {} |
| } |
| |
| private static class EnumeratorState { |
| final List<LongSplit> unassignedSplits; |
| int numRestarts; |
| int numCompletedCheckpoints; |
| |
| public EnumeratorState( |
| List<LongSplit> unassignedSplits, |
| int numRestarts, |
| int numCompletedCheckpoints) { |
| this.unassignedSplits = unassignedSplits; |
| this.numRestarts = numRestarts; |
| this.numCompletedCheckpoints = numCompletedCheckpoints; |
| } |
| |
| @Override |
| public String toString() { |
| return "EnumeratorState{" |
| + "unassignedSplits=" |
| + unassignedSplits |
| + ", numRestarts=" |
| + numRestarts |
| + ", numCompletedCheckpoints=" |
| + numCompletedCheckpoints |
| + '}'; |
| } |
| } |
| |
        /**
         * Serializer for {@link EnumeratorState}. Layout: numRestarts (int),
         * numCompletedCheckpoints (int), then the unassigned splits as fixed-size chunks of
         * {@link SplitVersionedSerializer#LENGTH} bytes each.
         */
        private static class EnumeratorVersionedSerializer
                implements SimpleVersionedSerializer<EnumeratorState> {
            private final SplitVersionedSerializer splitVersionedSerializer =
                    new SplitVersionedSerializer();

            @Override
            public int getVersion() {
                return 0;
            }

            @Override
            public byte[] serialize(EnumeratorState state) {
                // 8 bytes header (two ints) + one fixed-size chunk per split.
                final ByteBuffer byteBuffer =
                        ByteBuffer.allocate(
                                state.unassignedSplits.size() * SplitVersionedSerializer.LENGTH
                                        + 8);
                byteBuffer.putInt(state.numRestarts);
                byteBuffer.putInt(state.numCompletedCheckpoints);
                for (final LongSplit unassignedSplit : state.unassignedSplits) {
                    byteBuffer.put(splitVersionedSerializer.serialize(unassignedSplit));
                }
                return byteBuffer.array();
            }

            @Override
            public EnumeratorState deserialize(int version, byte[] serialized) {
                final ByteBuffer byteBuffer = ByteBuffer.wrap(serialized);
                final int numRestarts = byteBuffer.getInt();
                final int numCompletedCheckpoints = byteBuffer.getInt();

                // Capacity hint slightly overestimates (it includes the 8 header bytes) —
                // harmless, only a hint.
                final List<LongSplit> splits =
                        new ArrayList<>(serialized.length / SplitVersionedSerializer.LENGTH);

                final byte[] serializedSplit = new byte[SplitVersionedSerializer.LENGTH];
                while (byteBuffer.hasRemaining()) {
                    byteBuffer.get(serializedSplit);
                    splits.add(splitVersionedSerializer.deserialize(version, serializedSplit));
                }
                return new EnumeratorState(splits, numRestarts, numCompletedCheckpoints);
            }
        }
| |
| private static class SplitVersionedSerializer |
| implements SimpleVersionedSerializer<LongSplit> { |
| static final int LENGTH = 16; |
| |
| @Override |
| public int getVersion() { |
| return 0; |
| } |
| |
| @Override |
| public byte[] serialize(LongSplit split) { |
| final byte[] bytes = new byte[LENGTH]; |
| ByteBuffer.wrap(bytes).putLong(split.nextNumber).putInt(split.increment); |
| return bytes; |
| } |
| |
| @Override |
| public LongSplit deserialize(int version, byte[] serialized) { |
| final ByteBuffer byteBuffer = ByteBuffer.wrap(serialized); |
| return new LongSplit(byteBuffer.getLong(), byteBuffer.getInt()); |
| } |
| } |
| } |
| |
| interface DagCreator { |
| void create( |
| StreamExecutionEnvironment environment, |
| int minCheckpoints, |
| boolean slotSharing, |
| int expectedFailuresUntilSourceFinishes); |
| } |
| |
| /** Which channels are used to connect the tasks. */ |
| protected enum ChannelType { |
| LOCAL(true, n -> 1), |
| REMOTE(false, n -> n), |
| MIXED(true, n -> Math.min(n, 3)); |
| |
| final boolean slotSharing; |
| final Function<Integer, Integer> slotsToTaskManagers; |
| |
| ChannelType(boolean slotSharing, Function<Integer, Integer> slotsToTaskManagers) { |
| this.slotSharing = slotSharing; |
| this.slotsToTaskManagers = slotsToTaskManagers; |
| } |
| |
| @Override |
| public String toString() { |
| return name().toLowerCase(); |
| } |
| } |
| |
| /** Builder-like interface for all relevant unaligned settings. */ |
| protected static class UnalignedSettings { |
| private int parallelism; |
| private final int minCheckpoints = 10; |
| @Nullable private File restoreCheckpoint; |
| private boolean generateCheckpoint = false; |
| int expectedFailures = 0; |
| int tolerableCheckpointFailures = 0; |
| private final DagCreator dagCreator; |
| private int alignmentTimeout = 0; |
| private Duration checkpointTimeout = CHECKPOINTING_TIMEOUT.defaultValue(); |
| private int failuresAfterSourceFinishes = 0; |
| private ChannelType channelType = ChannelType.MIXED; |
| private int buffersPerChannel = 1; |
| |
| public UnalignedSettings(DagCreator dagCreator) { |
| this.dagCreator = dagCreator; |
| } |
| |
| public UnalignedSettings setParallelism(int parallelism) { |
| this.parallelism = parallelism; |
| return this; |
| } |
| |
| public UnalignedSettings setRestoreCheckpoint(File restoreCheckpoint) { |
| this.restoreCheckpoint = restoreCheckpoint; |
| return this; |
| } |
| |
| public UnalignedSettings setGenerateCheckpoint(boolean generateCheckpoint) { |
| this.generateCheckpoint = generateCheckpoint; |
| return this; |
| } |
| |
| public UnalignedSettings setExpectedFailures(int expectedFailures) { |
| this.expectedFailures = expectedFailures; |
| return this; |
| } |
| |
| public UnalignedSettings setCheckpointTimeout(Duration checkpointTimeout) { |
| this.checkpointTimeout = checkpointTimeout; |
| return this; |
| } |
| |
| public UnalignedSettings setAlignmentTimeout(int alignmentTimeout) { |
| this.alignmentTimeout = alignmentTimeout; |
| return this; |
| } |
| |
| public UnalignedSettings setFailuresAfterSourceFinishes(int failuresAfterSourceFinishes) { |
| this.failuresAfterSourceFinishes = failuresAfterSourceFinishes; |
| return this; |
| } |
| |
| public UnalignedSettings setChannelTypes(ChannelType channelType) { |
| this.channelType = channelType; |
| return this; |
| } |
| |
| public UnalignedSettings setTolerableCheckpointFailures(int tolerableCheckpointFailures) { |
| this.tolerableCheckpointFailures = tolerableCheckpointFailures; |
| return this; |
| } |
| |
| public UnalignedSettings setBuffersPerChannel(int buffersPerChannel) { |
| this.buffersPerChannel = buffersPerChannel; |
| return this; |
| } |
| |
        /** Applies these settings to the given execution environment. */
        public void configure(StreamExecutionEnvironment env) {
            // Scale the checkpoint interval with parallelism, but keep at least 100 ms.
            env.enableCheckpointing(Math.max(100L, parallelism * 50L));
            env.getCheckpointConfig().setAlignmentTimeout(Duration.ofMillis(alignmentTimeout));
            env.getCheckpointConfig().setCheckpointTimeout(checkpointTimeout.toMillis());
            env.getCheckpointConfig()
                    .setTolerableCheckpointFailureNumber(tolerableCheckpointFailures);
            env.setParallelism(parallelism);
            // When a checkpoint is generated for a later restore run, only half of the expected
            // failures happen in this run; the rest occur after restoring.
            env.setRestartStrategy(
                    RestartStrategies.fixedDelayRestart(
                            generateCheckpoint ? expectedFailures / 2 : expectedFailures,
                            Time.milliseconds(100)));
            env.getCheckpointConfig().enableUnalignedCheckpoints(true);
            // for custom partitioner
            env.getCheckpointConfig().setForceUnalignedCheckpoints(true);
            if (generateCheckpoint) {
                // Keep the checkpoint on cancellation so the next run can restore from it.
                env.getCheckpointConfig()
                        .enableExternalizedCheckpoints(
                                CheckpointConfig.ExternalizedCheckpointCleanup
                                        .RETAIN_ON_CANCELLATION);
            }
        }
| |
        /** Builds the cluster configuration, checkpointing into the given directory. */
        public Configuration getConfiguration(File checkpointDir) {
            Configuration conf = new Configuration();

            // High network fraction + tiny segments stress the unaligned checkpoint code paths.
            conf.setFloat(TaskManagerOptions.NETWORK_MEMORY_FRACTION, 0.9f);
            conf.set(TaskManagerOptions.MEMORY_SEGMENT_SIZE, MemorySize.parse("4kb"));
            conf.setString(StateBackendOptions.STATE_BACKEND, "filesystem");
            conf.setString(
                    CheckpointingOptions.CHECKPOINTS_DIRECTORY, checkpointDir.toURI().toString());
            if (restoreCheckpoint != null) {
                // Resume the job from the previously generated checkpoint.
                conf.set(
                        SavepointConfigOptions.SAVEPOINT_PATH,
                        restoreCheckpoint.toURI().toString());
            }

            conf.set(NettyShuffleEnvironmentOptions.NETWORK_BUFFERS_PER_CHANNEL, buffersPerChannel);
            conf.set(NettyShuffleEnvironmentOptions.NETWORK_REQUEST_BACKOFF_MAX, 60000);
            conf.set(AkkaOptions.ASK_TIMEOUT_DURATION, Duration.ofMinutes(1));
            return conf;
        }
| |
| @Override |
| public String toString() { |
| return "UnalignedSettings{" |
| + "parallelism=" |
| + parallelism |
| + ", minCheckpoints=" |
| + minCheckpoints |
| + ", restoreCheckpoint=" |
| + restoreCheckpoint |
| + ", generateCheckpoint=" |
| + generateCheckpoint |
| + ", expectedFailures=" |
| + expectedFailures |
| + ", dagCreator=" |
| + dagCreator |
| + ", alignmentTimeout=" |
| + alignmentTimeout |
| + ", failuresAfterSourceFinishes=" |
| + failuresAfterSourceFinishes |
| + ", channelType=" |
| + channelType |
| + '}'; |
| } |
| } |
| |
    /** Shifts the partitions one up. */
    protected static class ShiftingPartitioner implements Partitioner<Long> {
        @Override
        public int partition(Long key, int numPartitions) {
            // Strip the marker header, then route value v to partition (v + 1) % numPartitions.
            // withoutHeader is defined elsewhere in this file.
            return (int) ((withoutHeader(key) + 1) % numPartitions);
        }
    }
| |
    /** Distributes chunks of the size of numPartitions in a round robin fashion. */
    protected static class ChunkDistributingPartitioner implements Partitioner<Long> {
        @Override
        public int partition(Long key, int numPartitions) {
            // Consecutive runs of numPartitions values land on the same partition.
            return (int) ((withoutHeader(key) / numPartitions) % numPartitions);
        }
    }
| |
| /** A mapper that fails in particular situations/attempts. */ |
| protected static class FailingMapper extends RichMapFunction<Long, Long> |
| implements CheckpointedFunction, CheckpointListener { |
| private static final ListStateDescriptor<FailingMapperState> |
| FAILING_MAPPER_STATE_DESCRIPTOR = |
| new ListStateDescriptor<>("state", FailingMapperState.class); |
| private ListState<FailingMapperState> listState; |
| @Nullable private transient FailingMapperState state; |
| private final FilterFunction<FailingMapperState> failDuringMap; |
| private final FilterFunction<FailingMapperState> failDuringSnapshot; |
| private final FilterFunction<FailingMapperState> failDuringRecovery; |
| private final FilterFunction<FailingMapperState> failDuringClose; |
| private long lastValue; |
| |
| protected FailingMapper( |
| FilterFunction<FailingMapperState> failDuringMap, |
| FilterFunction<FailingMapperState> failDuringSnapshot, |
| FilterFunction<FailingMapperState> failDuringRecovery, |
| FilterFunction<FailingMapperState> failDuringClose) { |
| this.failDuringMap = failDuringMap; |
| this.failDuringSnapshot = failDuringSnapshot; |
| this.failDuringRecovery = failDuringRecovery; |
| this.failDuringClose = failDuringClose; |
| } |
| |
| @Override |
| public Long map(Long value) throws Exception { |
| lastValue = withoutHeader(value); |
| checkFail(failDuringMap, "map"); |
| return value; |
| } |
| |
| public void checkFail(FilterFunction<FailingMapperState> failFunction, String description) |
| throws Exception { |
| if (state != null && failFunction.filter(state)) { |
| failMapper(description); |
| } |
| } |
| |
        // Throws the injected failure; only reached when state != null (guarded by checkFail),
        // so the state dereferences below are safe.
        private void failMapper(String description) throws Exception {
            throw new TestException(
                    "Failing "
                            + description
                            + " @ "
                            + state.completedCheckpoints
                            + " ("
                            + state.runNumber
                            + " attempt); last value "
                            + lastValue);
        }
| |
| @Override |
| public void notifyCheckpointComplete(long checkpointId) { |
| if (state != null) { |
| state.completedCheckpoints++; |
| } |
| } |
| |
| @Override |
| public void notifyCheckpointAborted(long checkpointId) {} |
| |
| @Override |
| public void snapshotState(FunctionSnapshotContext context) throws Exception { |
| checkFail(failDuringSnapshot, "snapshotState"); |
| listState.clear(); |
| if (state != null) { |
| listState.add(state); |
| } |
| } |
| |
| @Override |
| public void close() throws Exception { |
| checkFail(failDuringClose, "close"); |
| super.close(); |
| } |
| |
        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            listState =
                    context.getOperatorStateStore().getListState(FAILING_MAPPER_STATE_DESCRIPTOR);
            // Only subtask 0 injects failures; all other subtasks keep state == null so that
            // checkFail is a no-op for them.
            if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                state = Iterables.get(listState.get(), 0, new FailingMapperState(0, 0));
                state.runNumber = getRuntimeContext().getAttemptNumber();
            }
            checkFail(failDuringRecovery, "initializeState");
        }
| |
        /** State for {@link FailingMapper}: tracks checkpoint progress across restarts. */
        protected static class FailingMapperState {
            // Number of checkpoints completed since this state was created.
            protected long completedCheckpoints;
            // Attempt number of the current execution; refreshed on every recovery.
            protected long runNumber;

            protected FailingMapperState(long completedCheckpoints, long runNumber) {
                this.completedCheckpoints = completedCheckpoints;
                this.runNumber = runNumber;
            }
        }
| } |
| |
| /** Base for state of the a specific {@link VerifyingSinkBase}. */ |
| public static class VerifyingSinkStateBase { |
| protected long numOutOfOrderness; |
| protected long numLostValues; |
| protected long numDuplicates; |
| protected long numOutput = 0; |
| protected long completedCheckpoints; |
| |
| @Override |
| public String toString() { |
| return "StateBase{" |
| + "numOutOfOrderness=" |
| + numOutOfOrderness |
| + ", numLostValues=" |
| + numLostValues |
| + ", numDuplicates=" |
| + numDuplicates |
| + ", numOutput=" |
| + numOutput |
| + ", completedCheckpoints=" |
| + completedCheckpoints |
| + '}'; |
| } |
| } |
| |
| /** |
| * A sink that checks if the members arrive in the expected order without any missing values. |
| */ |
| protected abstract static class VerifyingSinkBase<State extends VerifyingSinkStateBase> |
| extends RichSinkFunction<Long> implements CheckpointedFunction, CheckpointListener { |
| private final LongCounter numOutputCounter = new LongCounter(); |
| private final LongCounter outOfOrderCounter = new LongCounter(); |
| private final LongCounter lostCounter = new LongCounter(); |
| private final LongCounter duplicatesCounter = new LongCounter(); |
| private final IntCounter numFailures = new IntCounter(); |
| private final Duration backpressureInterval; |
| private ListState<State> stateList; |
| protected transient State state; |
| protected final long minCheckpoints; |
| private boolean recovered; |
| @Nullable private Deadline backpressureUntil; |
| |
| protected VerifyingSinkBase(long minCheckpoints, long checkpointingInterval) { |
| this.minCheckpoints = minCheckpoints; |
| this.backpressureInterval = Duration.ofMillis(checkpointingInterval); |
| } |
| |
| @Override |
| public void open(Configuration parameters) throws Exception { |
| super.open(parameters); |
| getRuntimeContext().addAccumulator(NUM_OUTPUTS, numOutputCounter); |
| getRuntimeContext().addAccumulator(NUM_OUT_OF_ORDER, outOfOrderCounter); |
| getRuntimeContext().addAccumulator(NUM_DUPLICATES, duplicatesCounter); |
| getRuntimeContext().addAccumulator(NUM_LOST, lostCounter); |
| getRuntimeContext().addAccumulator(NUM_FAILURES, numFailures); |
| } |
| |
| @Override |
| public void initializeState(FunctionInitializationContext context) throws Exception { |
| final State state = createState(); |
| stateList = |
| context.getOperatorStateStore() |
| .getListState( |
| new ListStateDescriptor<>( |
| "state", (Class<State>) state.getClass())); |
| this.state = getOnlyElement(stateList.get(), state); |
| LOG.info( |
| "Inducing no backpressure @ {} subtask ({} attempt)", |
| getRuntimeContext().getIndexOfThisSubtask(), |
| getRuntimeContext().getAttemptNumber()); |
| } |
| |
| protected abstract State createState(); |
| |
| protected void induceBackpressure() throws InterruptedException { |
| if (backpressureUntil != null) { |
| // induce heavy backpressure until enough checkpoints have been written |
| Thread.sleep(1); |
| if (backpressureUntil.isOverdue()) { |
| backpressureUntil = null; |
| } |
| } |
| // after all checkpoints have been completed, the remaining data should be flushed out |
| // fairly quickly |
| } |
| |
| @Override |
| public void snapshotState(FunctionSnapshotContext context) throws Exception { |
| stateList.clear(); |
| stateList.add(state); |
| if (recovered) { |
| backpressureUntil = Deadline.fromNow(backpressureInterval); |
| } |
| } |
| |
| @Override |
| public void notifyCheckpointComplete(long checkpointId) { |
| recovered = true; |
| state.completedCheckpoints++; |
| if (state.completedCheckpoints < minCheckpoints) { |
| this.backpressureUntil = Deadline.fromNow(backpressureInterval); |
| LOG.info( |
| "Inducing backpressure until {} @ {} subtask ({} attempt)", |
| backpressureUntil, |
| getRuntimeContext().getIndexOfThisSubtask(), |
| getRuntimeContext().getAttemptNumber()); |
| } else { |
| this.backpressureUntil = null; |
| LOG.info( |
| "Inducing no backpressure @ {} subtask ({} attempt)", |
| getRuntimeContext().getIndexOfThisSubtask(), |
| getRuntimeContext().getAttemptNumber()); |
| } |
| } |
| |
| @Override |
| public void close() throws Exception { |
| numOutputCounter.add(state.numOutput); |
| outOfOrderCounter.add(state.numOutOfOrderness); |
| duplicatesCounter.add(state.numDuplicates); |
| lostCounter.add(state.numLostValues); |
| if (getRuntimeContext().getIndexOfThisSubtask() == 0) { |
| numFailures.add(getRuntimeContext().getAttemptNumber()); |
| } |
| LOG.info( |
| "Last state {} @ {} subtask ({} attempt)", |
| state, |
| getRuntimeContext().getIndexOfThisSubtask(), |
| getRuntimeContext().getAttemptNumber()); |
| super.close(); |
| } |
| } |
| |
| static class MinEmittingFunction extends RichCoFlatMapFunction<Long, Long, Long> |
| implements CheckpointedFunction { |
| private ListState<State> stateList; |
| private State state; |
| |
| @Override |
| public void snapshotState(FunctionSnapshotContext context) throws Exception { |
| stateList.clear(); |
| stateList.add(state); |
| } |
| |
| @Override |
| public void initializeState(FunctionInitializationContext context) throws Exception { |
| stateList = |
| context.getOperatorStateStore() |
| .getListState(new ListStateDescriptor<>("state", State.class)); |
| state = getOnlyElement(stateList.get(), new State()); |
| } |
| |
| @Override |
| public void flatMap1(Long value, Collector<Long> out) { |
| long baseValue = withoutHeader(value); |
| state.lastLeft = baseValue; |
| if (state.lastRight >= baseValue) { |
| out.collect(value); |
| } |
| } |
| |
| @Override |
| public void flatMap2(Long value, Collector<Long> out) { |
| long baseValue = withoutHeader(value); |
| state.lastRight = baseValue; |
| if (state.lastLeft >= baseValue) { |
| out.collect(value); |
| } |
| } |
| |
| private static class State { |
| private long lastLeft = Long.MIN_VALUE; |
| private long lastRight = Long.MIN_VALUE; |
| } |
| } |
| |
| protected static long withHeader(long value) { |
| checkState( |
| value <= Integer.MAX_VALUE, |
| "Value too large for header, this indicates that the test is running too long."); |
| return value ^ HEADER; |
| } |
| |
| protected static long withoutHeader(long value) { |
| checkHeader(value); |
| return value ^ HEADER; |
| } |
| |
| protected static long checkHeader(long value) { |
| if ((value & HEADER_MASK) != HEADER) { |
| throw new IllegalArgumentException( |
| "Stream corrupted. Cannot find the header " |
| + Long.toHexString(HEADER) |
| + " in the value " |
| + Long.toHexString(value)); |
| } |
| return value; |
| } |
| |
| private static class TestException extends Exception { |
| public TestException(String s) { |
| super(s); |
| } |
| } |
| } |