| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.flink.streaming.tests; |
| |
| import org.apache.flink.api.common.functions.RichFlatMapFunction; |
| import org.apache.flink.api.common.functions.RuntimeContext; |
| import org.apache.flink.api.common.restartstrategy.RestartStrategies; |
| import org.apache.flink.api.common.state.CheckpointListener; |
| import org.apache.flink.api.common.state.ListState; |
| import org.apache.flink.api.common.state.ListStateDescriptor; |
| import org.apache.flink.api.common.state.ValueState; |
| import org.apache.flink.api.common.state.ValueStateDescriptor; |
| import org.apache.flink.api.java.functions.KeySelector; |
| import org.apache.flink.api.java.utils.ParameterTool; |
| import org.apache.flink.contrib.streaming.state.EmbeddedRocksDBStateBackend; |
| import org.apache.flink.runtime.state.FunctionInitializationContext; |
| import org.apache.flink.runtime.state.FunctionSnapshotContext; |
| import org.apache.flink.runtime.state.hashmap.HashMapStateBackend; |
| import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; |
| import org.apache.flink.streaming.api.environment.CheckpointConfig; |
| import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; |
| import org.apache.flink.streaming.api.functions.sink.PrintSinkFunction; |
| import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; |
| import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; |
| import org.apache.flink.util.Collector; |
| import org.apache.flink.util.Preconditions; |
| |
| import org.apache.commons.lang3.RandomStringUtils; |
| |
| import java.io.IOException; |
| import java.io.Serializable; |
| import java.util.ArrayList; |
| import java.util.HashSet; |
| import java.util.Iterator; |
| import java.util.List; |
| import java.util.Set; |
| |
| /** |
| * Automatic end-to-end test for local recovery (including sticky allocation). |
| * |
 * <p>List of possible input parameters for this job (see the example invocation below):
| * |
| * <ul> |
| * <li>checkpointDir: the checkpoint directory, required. |
| * <li>parallelism: the parallelism of the job, default 1. |
 *   <li>maxParallelism: the maximum parallelism of the job, default is the value of
 *       <code>parallelism</code>.
| * <li>checkpointInterval: the checkpointing interval in milliseconds, default 1000. |
| * <li>restartDelay: the delay of the fixed delay restart strategy, default 0. |
 *   <li>externalizedCheckpoints: flag to activate externalized checkpoints, default
 *       <code>false</code>.
 *   <li>stateBackend: choice of state backend between <code>hashmap</code> and
 *       <code>rocks</code>, default <code>hashmap</code>.
 *   <li>killJvmOnFail: flag that determines whether an artificial failure induced by the
 *       test kills the JVM, default <code>false</code>.
 *   <li>incrementalCheckpoints: flag for incremental checkpoints with the rocks state
 *       backend, default <code>false</code>.
 *   <li>delay: sleep delay in milliseconds to throttle down the production of the source,
 *       default 0.
 *   <li>maxAttempts: the maximum number of run attempts before the job finishes with
 *       success, default 3.
| * <li>valueSize: size of the artificial value for each key in bytes, default 10. |
| * </ul> |
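 *
 * <p>Example invocation (jar name, paths, and parameter values are illustrative):
 *
 * <pre>{@code
 * bin/flink run StickyAllocationAndLocalRecoveryTestJob.jar \
 *   --checkpointDir file:///tmp/checkpoints \
 *   --parallelism 4 --maxParallelism 4 \
 *   --stateBackend rocks --incrementalCheckpoints true \
 *   --killJvmOnFail true
 * }</pre>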
| */ |
| public class StickyAllocationAndLocalRecoveryTestJob { |
| |
| public static void main(String[] args) throws Exception { |
| |
| final ParameterTool pt = ParameterTool.fromArgs(args); |
| |
| final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); |
| |
| env.setParallelism(pt.getInt("parallelism", 1)); |
| env.setMaxParallelism(pt.getInt("maxParallelism", pt.getInt("parallelism", 1))); |
| env.enableCheckpointing(pt.getInt("checkpointInterval", 1000)); |
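        // Allow an effectively unbounded number of restarts; termination is driven by the
        // source, which shuts down once maxAttempts is exceeded.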
| env.setRestartStrategy( |
| RestartStrategies.fixedDelayRestart( |
| Integer.MAX_VALUE, pt.getInt("restartDelay", 0))); |
| if (pt.getBoolean("externalizedCheckpoints", false)) { |
| env.getCheckpointConfig() |
| .enableExternalizedCheckpoints( |
| CheckpointConfig.ExternalizedCheckpointCleanup.RETAIN_ON_CANCELLATION); |
| } |
| |
| String checkpointDir = pt.getRequired("checkpointDir"); |
| env.getCheckpointConfig().setCheckpointStorage(checkpointDir); |
| |
| boolean killJvmOnFail = pt.getBoolean("killJvmOnFail", false); |
| |
| String stateBackend = pt.get("stateBackend", "hashmap"); |
| if ("hashmap".equals(stateBackend)) { |
| env.setStateBackend(new HashMapStateBackend()); |
| } else if ("rocks".equals(stateBackend)) { |
| boolean incrementalCheckpoints = pt.getBoolean("incrementalCheckpoints", false); |
| env.setStateBackend(new EmbeddedRocksDBStateBackend(incrementalCheckpoints)); |
| } else { |
| throw new IllegalArgumentException("Unknown backend: " + stateBackend); |
| } |
| |
| // make parameters available in the web interface |
| env.getConfig().setGlobalJobParameters(pt); |
| |
| // delay to throttle down the production of the source |
| long delay = pt.getLong("delay", 0L); |
| |
| // the maximum number of attempts, before the job finishes with success |
| int maxAttempts = pt.getInt("maxAttempts", 3); |
| |
| // size of one artificial value |
| int valueSize = pt.getInt("valueSize", 10); |
| |
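        // Pipeline: a parallel source emits disjoint, increasing key sequences; a keyed flatMap
        // creates artificial keyed state and injects the test failures; a print sink surfaces
        // any sticky-allocation violation messages.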
| env.addSource(new RandomLongSource(maxAttempts, delay)) |
| .keyBy((KeySelector<Long, Long>) aLong -> aLong) |
| .flatMap(new StateCreatingFlatMap(valueSize, killJvmOnFail)) |
| .addSink(new PrintSinkFunction<>()); |
| |
| env.execute("Sticky Allocation And Local Recovery Test"); |
| } |
| |
    /** Source function that produces, per subtask, a strictly increasing sequence of longs. */
| private static final class RandomLongSource extends RichParallelSourceFunction<Long> |
| implements CheckpointedFunction { |
| |
| private static final long serialVersionUID = 1L; |
| |
        /** Delay in milliseconds between two emitted events. */
| final long delay; |
| |
| /** Maximum restarts before shutting down this source. */ |
| final int maxAttempts; |
| |
| /** State that holds the current key for recovery. */ |
| transient ListState<Long> sourceCurrentKeyState; |
| |
| /** Generator's current key. */ |
| long currentKey; |
| |
| /** Generator runs while this is true. */ |
| volatile boolean running; |
| |
| RandomLongSource(int maxAttempts, long delay) { |
| this.delay = delay; |
| this.maxAttempts = maxAttempts; |
| this.running = true; |
| } |
| |
| @Override |
| public void run(SourceContext<Long> sourceContext) throws Exception { |
| |
| int numberOfParallelSubtasks = getRuntimeContext().getNumberOfParallelSubtasks(); |
| int subtaskIdx = getRuntimeContext().getIndexOfThisSubtask(); |
| |
| // the source emits one final event and shuts down once we have reached max attempts. |
| if (getRuntimeContext().getAttemptNumber() > maxAttempts) { |
| synchronized (sourceContext.getCheckpointLock()) { |
| sourceContext.collect(Long.MAX_VALUE - subtaskIdx); |
| } |
| return; |
| } |
| |
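            // Subtask i emits the keys i, i + p, i + 2p, ... (p = parallelism), so the key
            // sequences of parallel subtasks are disjoint and strictly increasing.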
| while (running) { |
| |
| synchronized (sourceContext.getCheckpointLock()) { |
| sourceContext.collect(currentKey); |
| currentKey += numberOfParallelSubtasks; |
| } |
| |
| if (delay > 0) { |
| Thread.sleep(delay); |
| } |
| } |
| } |
| |
| @Override |
| public void cancel() { |
| running = false; |
| } |
| |
| @Override |
| public void snapshotState(FunctionSnapshotContext context) throws Exception { |
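            // Keep only the latest key; on restore the source resumes the sequence from there.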
| sourceCurrentKeyState.clear(); |
| sourceCurrentKeyState.add(currentKey); |
| } |
| |
| @Override |
| public void initializeState(FunctionInitializationContext context) throws Exception { |
| |
| ListStateDescriptor<Long> currentKeyDescriptor = |
| new ListStateDescriptor<>("currentKey", Long.class); |
| sourceCurrentKeyState = |
| context.getOperatorStateStore().getListState(currentKeyDescriptor); |
| |
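            // Fresh start: begin at the subtask index so parallel instances emit disjoint key
            // ranges; a restored value (at most one element) overrides this below.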
| currentKey = getRuntimeContext().getIndexOfThisSubtask(); |
| Iterable<Long> iterable = sourceCurrentKeyState.get(); |
| if (iterable != null) { |
| Iterator<Long> iterator = iterable.iterator(); |
| if (iterator.hasNext()) { |
| currentKey = iterator.next(); |
| Preconditions.checkState(!iterator.hasNext()); |
| } |
| } |
| } |
| } |
| |
| /** Stateful map function. Failure creation and checks happen here. */ |
| private static final class StateCreatingFlatMap extends RichFlatMapFunction<Long, String> |
| implements CheckpointedFunction, CheckpointListener { |
| |
| private static final long serialVersionUID = 1L; |
| |
| /** User configured size of the generated artificial values in the keyed state. */ |
| final int valueSize; |
| |
        /** User configuration of whether the artificial test failure kills the JVM. */
| final boolean killTaskOnFailure; |
| |
| /** This state is used to create artificial keyed state in the backend. */ |
| transient ValueState<String> valueState; |
| |
        /** Operator state used to persist the schedulingAndFailureInfo across restarts. */
| transient ListState<MapperSchedulingAndFailureInfo> schedulingAndFailureState; |
| |
        /** This contains the current scheduling and failure metadata. */
| transient MapperSchedulingAndFailureInfo currentSchedulingAndFailureInfo; |
| |
        /** Message set when restore detects a violation of sticky allocation. */
| transient volatile String allocationFailureMessage; |
| |
| /** |
| * If this flag is true, the next invocation of the map function introduces a test failure. |
| */ |
| transient volatile boolean failTask; |
| |
| StateCreatingFlatMap(int valueSize, boolean killTaskOnFailure) { |
| this.valueSize = valueSize; |
| this.failTask = false; |
| this.killTaskOnFailure = killTaskOnFailure; |
| this.allocationFailureMessage = null; |
| } |
| |
| @Override |
| public void flatMap(Long key, Collector<String> collector) throws IOException { |
| |
| if (allocationFailureMessage != null) { |
| // Report the failure downstream, so that we can get the message from the output. |
| collector.collect(allocationFailureMessage); |
| allocationFailureMessage = null; |
| } |
| |
| if (failTask) { |
| // we fail the task, either by killing the JVM hard, or by throwing a user code |
| // exception. |
| if (killTaskOnFailure) { |
| Runtime.getRuntime().halt(-1); |
| } else { |
| throw new RuntimeException("Artificial user code exception."); |
| } |
| } |
| |
| // sanity check |
| if (null != valueState.value()) { |
| throw new IllegalStateException( |
| "This should never happen, keys are generated monotonously."); |
| } |
| |
| // store artificial data to blow up the state |
| valueState.update(RandomStringUtils.random(valueSize, true, true)); |
| } |
| |
| @Override |
| public void snapshotState(FunctionSnapshotContext functionSnapshotContext) {} |
| |
| @Override |
| public void initializeState(FunctionInitializationContext functionInitializationContext) |
| throws Exception { |
| ValueStateDescriptor<String> stateDescriptor = |
| new ValueStateDescriptor<>("state", String.class); |
| valueState = |
| functionInitializationContext.getKeyedStateStore().getState(stateDescriptor); |
| |
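            // Union list state replicates every subtask's info records to all subtasks on
            // restore, so each task can look up where it was scheduled before the failure.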
| ListStateDescriptor<MapperSchedulingAndFailureInfo> mapperInfoStateDescriptor = |
| new ListStateDescriptor<>("mapperState", MapperSchedulingAndFailureInfo.class); |
| schedulingAndFailureState = |
| functionInitializationContext |
| .getOperatorStateStore() |
| .getUnionListState(mapperInfoStateDescriptor); |
| |
| StreamingRuntimeContext runtimeContext = (StreamingRuntimeContext) getRuntimeContext(); |
| String allocationID = runtimeContext.getAllocationIDAsString(); |
| // Pattern of the name: "Flat Map -> Sink: Unnamed (4/4)#0". Remove "#0" part: |
| String taskNameWithSubtasks = runtimeContext.getTaskNameWithSubtasks().split("#")[0]; |
| |
| final int thisJvmPid = getJvmPid(); |
| final Set<Integer> killedJvmPids = new HashSet<>(); |
| |
| // here we check if the sticky scheduling worked as expected |
| if (functionInitializationContext.isRestored()) { |
| Iterable<MapperSchedulingAndFailureInfo> iterable = schedulingAndFailureState.get(); |
| |
| MapperSchedulingAndFailureInfo infoForThisTask = null; |
| List<MapperSchedulingAndFailureInfo> completeInfo = new ArrayList<>(); |
| if (iterable != null) { |
| for (MapperSchedulingAndFailureInfo testInfo : iterable) { |
| |
| completeInfo.add(testInfo); |
| |
| if (taskNameWithSubtasks.equals(testInfo.taskNameWithSubtask)) { |
| infoForThisTask = testInfo; |
| } |
| |
| if (testInfo.killedJvm) { |
| killedJvmPids.add(testInfo.jvmPid); |
| } |
| } |
| } |
| |
| Preconditions.checkNotNull(infoForThisTask, "Expected to find info here."); |
| |
| if (!isScheduledToCorrectAllocation(infoForThisTask, allocationID, killedJvmPids)) { |
| allocationFailureMessage = |
| String.format( |
| "Sticky allocation test failed: Subtask %s in attempt %d was rescheduled from allocation %s " |
| + "on JVM with PID %d to unexpected allocation %s on JVM with PID %d.\n" |
| + "Complete information from before the crash: %s.", |
| taskNameWithSubtasks, |
| runtimeContext.getAttemptNumber(), |
| infoForThisTask.allocationId, |
| infoForThisTask.jvmPid, |
| allocationID, |
| thisJvmPid, |
| completeInfo); |
| } |
| } |
| |
| // We determine which of the subtasks will produce the artificial failure |
| boolean failingTask = shouldTaskFailForThisAttempt(); |
| |
| // We take note of all the meta info that we require to check sticky scheduling in the |
| // next re-attempt |
| this.currentSchedulingAndFailureInfo = |
| new MapperSchedulingAndFailureInfo( |
| failingTask, |
| failingTask && killTaskOnFailure, |
| thisJvmPid, |
| taskNameWithSubtasks, |
| allocationID); |
| |
| schedulingAndFailureState.clear(); |
| schedulingAndFailureState.add(currentSchedulingAndFailureInfo); |
| } |
| |
| @Override |
| public void notifyCheckpointComplete(long checkpointId) { |
| // we can only fail the task after at least one checkpoint is completed to record |
| // progress. |
| failTask = currentSchedulingAndFailureInfo.failingTask; |
| } |
| |
| @Override |
| public void notifyCheckpointAborted(long checkpointId) {} |
| |
| private boolean shouldTaskFailForThisAttempt() { |
| RuntimeContext runtimeContext = getRuntimeContext(); |
| int numSubtasks = runtimeContext.getNumberOfParallelSubtasks(); |
| int subtaskIdx = runtimeContext.getIndexOfThisSubtask(); |
| int attempt = runtimeContext.getAttemptNumber(); |
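            // Round-robin failure injection: attempt n fails the subtask with index
            // n % parallelism, so successive restarts pick different subtasks.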
| return (attempt % numSubtasks) == subtaskIdx; |
| } |
| |
| private boolean isScheduledToCorrectAllocation( |
| MapperSchedulingAndFailureInfo infoForThisTask, |
| String allocationID, |
| Set<Integer> killedJvmPids) { |
| |
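            // Scheduling counts as sticky if the task returned to its previous allocation, or
            // if its previous JVM was deliberately killed, in which case the old slot is gone
            // and a new allocation is legitimate.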
| return (infoForThisTask.allocationId.equals(allocationID) |
| || killedJvmPids.contains(infoForThisTask.jvmPid)); |
| } |
| } |
| |
| /** |
| * This code is copied from Stack Overflow. |
| * |
| * <p><a |
| * href="https://stackoverflow.com/questions/35842">https://stackoverflow.com/questions/35842</a>, |
| * answer <a |
| * href="https://stackoverflow.com/a/12066696/9193881">https://stackoverflow.com/a/12066696/9193881</a> |
| * |
 * <p>Author: <a href="https://stackoverflow.com/users/446591/brad-mace">Brad Mace</a>
| */ |
| private static int getJvmPid() throws Exception { |
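        // Note: this reflection hack relies on sun.management internals and works only on
        // Java 8; on Java 9+, the standard ProcessHandle.current().pid() could be used instead.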
| java.lang.management.RuntimeMXBean runtime = |
| java.lang.management.ManagementFactory.getRuntimeMXBean(); |
| java.lang.reflect.Field jvm = runtime.getClass().getDeclaredField("jvm"); |
| jvm.setAccessible(true); |
| sun.management.VMManagement mgmt = (sun.management.VMManagement) jvm.get(runtime); |
| java.lang.reflect.Method pidMethod = mgmt.getClass().getDeclaredMethod("getProcessId"); |
| pidMethod.setAccessible(true); |
| |
| return (int) (Integer) pidMethod.invoke(mgmt); |
| } |
| |
| /** Records the information required to check sticky scheduling after a restart. */ |
| public static class MapperSchedulingAndFailureInfo implements Serializable { |
| |
| private static final long serialVersionUID = 1L; |
| |
| /** True iff this task inflicts a test failure. */ |
| final boolean failingTask; |
| |
| /** True iff this task kills its JVM. */ |
| final boolean killedJvm; |
| |
| /** PID of the task JVM. */ |
| final int jvmPid; |
| |
| /** Name and subtask index of the task. */ |
| final String taskNameWithSubtask; |
| |
| /** The current allocation id of this task. */ |
| final String allocationId; |
| |
| MapperSchedulingAndFailureInfo( |
| boolean failingTask, |
| boolean killedJvm, |
| int jvmPid, |
| String taskNameWithSubtask, |
| String allocationId) { |
| |
| this.failingTask = failingTask; |
| this.killedJvm = killedJvm; |
| this.jvmPid = jvmPid; |
| this.taskNameWithSubtask = taskNameWithSubtask; |
| this.allocationId = allocationId; |
| } |
| |
| @Override |
| public String toString() { |
| return "MapperTestInfo{" |
| + "failingTask=" |
| + failingTask |
| + ", killedJvm=" |
| + killedJvm |
| + ", jvmPid=" |
| + jvmPid |
| + ", taskNameWithSubtask='" |
| + taskNameWithSubtask |
| + '\'' |
| + ", allocationId='" |
| + allocationId |
| + '\'' |
| + '}'; |
| } |
| } |
| } |