| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.runners.flink; |
| |
| import static org.hamcrest.MatcherAssert.assertThat; |
| |
| import java.io.Serializable; |
| import java.lang.reflect.Method; |
| import java.net.URI; |
| import java.util.Collections; |
| import java.util.Objects; |
| import java.util.concurrent.CompletableFuture; |
| import java.util.concurrent.CountDownLatch; |
| import java.util.concurrent.ExecutionException; |
| import java.util.concurrent.Executors; |
| import java.util.concurrent.TimeUnit; |
| import org.apache.beam.model.pipeline.v1.RunnerApi; |
| import org.apache.beam.runners.core.construction.Environments; |
| import org.apache.beam.runners.core.construction.PipelineTranslation; |
| import org.apache.beam.runners.fnexecution.jobsubmission.JobInvocation; |
| import org.apache.beam.sdk.Pipeline; |
| import org.apache.beam.sdk.io.GenerateSequence; |
| import org.apache.beam.sdk.options.PipelineOptionsFactory; |
| import org.apache.beam.sdk.options.PortablePipelineOptions; |
| import org.apache.beam.sdk.state.BagState; |
| import org.apache.beam.sdk.state.StateSpec; |
| import org.apache.beam.sdk.state.StateSpecs; |
| import org.apache.beam.sdk.state.TimeDomain; |
| import org.apache.beam.sdk.state.Timer; |
| import org.apache.beam.sdk.state.TimerSpec; |
| import org.apache.beam.sdk.state.TimerSpecs; |
| import org.apache.beam.sdk.state.ValueState; |
| import org.apache.beam.sdk.transforms.DoFn; |
| import org.apache.beam.sdk.transforms.Impulse; |
| import org.apache.beam.sdk.transforms.InferableFunction; |
| import org.apache.beam.sdk.transforms.MapElements; |
| import org.apache.beam.sdk.transforms.ParDo; |
| import org.apache.beam.sdk.values.KV; |
| import org.apache.beam.sdk.values.PCollection; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.ListeningExecutorService; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.MoreExecutors; |
| import org.apache.flink.api.common.JobID; |
| import org.apache.flink.configuration.CheckpointingOptions; |
| import org.apache.flink.configuration.Configuration; |
| import org.apache.flink.configuration.RestOptions; |
| import org.apache.flink.runtime.client.JobStatusMessage; |
| import org.apache.flink.runtime.jobgraph.JobGraph; |
| import org.apache.flink.runtime.jobgraph.JobStatus; |
| import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings; |
| import org.apache.flink.runtime.minicluster.MiniCluster; |
| import org.apache.flink.runtime.minicluster.MiniClusterConfiguration; |
| import org.hamcrest.Matchers; |
| import org.hamcrest.core.IsIterableContaining; |
| import org.joda.time.Duration; |
| import org.junit.After; |
| import org.junit.AfterClass; |
| import org.junit.BeforeClass; |
| import org.junit.ClassRule; |
| import org.junit.Rule; |
| import org.junit.Test; |
| import org.junit.rules.TemporaryFolder; |
| import org.junit.rules.Timeout; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| /** |
| * Tests that Flink savepoints work with the Flink Runner, for both the legacy and the portable |
| * translation path. This includes taking a savepoint of a running pipeline, cancelling the |
| * pipeline, and restarting it from the savepoint with a different parallelism. |
| */ |
| public class FlinkSavepointTest implements Serializable { |
| |
| private static final Logger LOG = LoggerFactory.getLogger(FlinkSavepointTest.class); |
| |
| /** Flink cluster that runs for the lifespan of the tests. */ |
| private static transient MiniCluster flinkCluster; |
| |
| /** Static latch used to synchronize the pipeline under test with the test method. */ |
| private static volatile CountDownLatch oneShotLatch; |
| |
| /** Temporary folder for savepoints. */ |
| @ClassRule public static transient TemporaryFolder tempFolder = new TemporaryFolder(); |
| |
| /** Each test has a timeout of 60 seconds to guard against hanging jobs. */ |
| @Rule public Timeout timeout = new Timeout(60, TimeUnit.SECONDS); |
| |
| @BeforeClass |
| public static void beforeClass() throws Exception { |
| Configuration config = new Configuration(); |
| // Request a random free REST port to avoid collisions when tests run in parallel |
| config.setInteger(RestOptions.PORT, 0); |
| config.setString(CheckpointingOptions.STATE_BACKEND, "filesystem"); |
| |
| String savepointPath = "file://" + tempFolder.getRoot().getAbsolutePath(); |
| LOG.info("Savepoints will be written to {}", savepointPath); |
| // The state backend requires a configured checkpoint directory, |
| // even though we only create savepoints in this test. |
| config.setString(CheckpointingOptions.CHECKPOINTS_DIRECTORY, savepointPath); |
| // Savepoints will go into a subdirectory of this directory |
| config.setString(CheckpointingOptions.SAVEPOINT_DIRECTORY, savepointPath); |
| |
| MiniClusterConfiguration clusterConfig = |
| new MiniClusterConfiguration.Builder() |
| .setConfiguration(config) |
| .setNumTaskManagers(2) |
| .setNumSlotsPerTaskManager(2) |
| .build(); |
| |
| flinkCluster = new MiniCluster(clusterConfig); |
| flinkCluster.start(); |
| } |
| |
| @AfterClass |
| public static void afterClass() throws Exception { |
| flinkCluster.close(); |
| flinkCluster = null; |
| } |
| |
| @After |
| public void afterTest() throws Exception { |
| for (JobStatusMessage jobStatusMessage : flinkCluster.listJobs().get()) { |
| if (jobStatusMessage.getJobState() == JobStatus.RUNNING) { |
| flinkCluster.cancelJob(jobStatusMessage.getJobId()).get(); |
| } |
| } |
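| // Wait until every job has reached a terminal state so the next test starts |
| // against a clean cluster. |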
| while (!flinkCluster.listJobs().get().stream() |
| .allMatch(job -> job.getJobState().isTerminalState())) { |
| Thread.sleep(50); |
| } |
| } |
| |
| @Test |
| public void testSavepointRestoreLegacy() throws Exception { |
| runSavepointAndRestore(false); |
| } |
| |
| @Test |
| public void testSavepointRestorePortable() throws Exception { |
| runSavepointAndRestore(true); |
| } |
| |
| private void runSavepointAndRestore(boolean isPortablePipeline) throws Exception { |
| FlinkPipelineOptions options = PipelineOptionsFactory.as(FlinkPipelineOptions.class); |
| options.setStreaming(true); |
| // Initial parallelism |
| options.setParallelism(2); |
| options.setRunner(FlinkRunner.class); |
| |
| oneShotLatch = new CountDownLatch(1); |
| Pipeline pipeline = Pipeline.create(options); |
| createStreamingJob(pipeline, false, isPortablePipeline); |
| |
| final JobID jobID; |
| if (isPortablePipeline) { |
| jobID = executePortable(pipeline); |
| } else { |
| jobID = executeLegacy(pipeline); |
| } |
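| // Wait until the pipeline has written its state and can be savepointed. |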
| oneShotLatch.await(); |
| String savepointDir = takeSavepointAndCancelJob(jobID); |
| |
| oneShotLatch = new CountDownLatch(1); |
| // Increase parallelism |
| options.setParallelism(4); |
| pipeline = Pipeline.create(options); |
| createStreamingJob(pipeline, true, isPortablePipeline); |
| |
| if (isPortablePipeline) { |
| restoreFromSavepointPortable(pipeline, savepointDir); |
| } else { |
| restoreFromSavepointLegacy(pipeline, savepointDir); |
| } |
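| // Wait until the restored pipeline has verified the state from the savepoint. |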
| oneShotLatch.await(); |
| } |
| |
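| /** Submits the pipeline through the legacy (non-portable) FlinkRunner translation path. */ |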
| private JobID executeLegacy(Pipeline pipeline) throws Exception { |
| JobGraph jobGraph = getJobGraph(pipeline); |
| flinkCluster.submitJob(jobGraph).get(); |
| return jobGraph.getJobID(); |
| } |
| |
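| /** |
| * Submits the pipeline through the portable translation path: the pipeline is translated to its |
| * proto representation and handed to a {@link JobInvocation}, using the embedded environment so |
| * that no external SDK harness is required. |
| */ |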
| private JobID executePortable(Pipeline pipeline) throws Exception { |
| pipeline |
| .getOptions() |
| .as(PortablePipelineOptions.class) |
| .setDefaultEnvironmentType(Environments.ENVIRONMENT_EMBEDDED); |
| pipeline.getOptions().as(FlinkPipelineOptions.class).setFlinkMaster(getFlinkMaster()); |
| |
| RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(pipeline); |
| |
| ListeningExecutorService executorService = |
| MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1)); |
| FlinkPipelineOptions pipelineOptions = pipeline.getOptions().as(FlinkPipelineOptions.class); |
| try { |
| JobInvocation jobInvocation = |
| FlinkJobInvoker.createJobInvocation( |
| "id", |
| "none", |
| executorService, |
| pipelineProto, |
| pipelineOptions, |
| new FlinkPipelineRunner(pipelineOptions, null, Collections.emptyList())); |
| |
| jobInvocation.start(); |
| |
| return waitForJobToBeReady(); |
| } finally { |
| executorService.shutdown(); |
| } |
| } |
| |
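| /** |
| * Determines the REST address of the MiniCluster reflectively, because the return type of |
| * {@code getRestAddress()} differs between Flink versions (a plain {@link URI} in Flink 1.5, a |
| * {@link CompletableFuture} in later versions). |
| */ |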
| private String getFlinkMaster() throws Exception { |
| final URI uri; |
| Method getRestAddress = flinkCluster.getClass().getMethod("getRestAddress"); |
| if (getRestAddress.getReturnType().equals(URI.class)) { |
| // Flink 1.5 way |
| uri = (URI) getRestAddress.invoke(flinkCluster); |
| } else if (getRestAddress.getReturnType().equals(CompletableFuture.class)) { |
| @SuppressWarnings("unchecked") |
| CompletableFuture<URI> future = (CompletableFuture<URI>) getRestAddress.invoke(flinkCluster); |
| uri = future.get(); |
| } else { |
| throw new RuntimeException("Could not determine Rest address for this Flink version."); |
| } |
| return uri.getHost() + ":" + uri.getPort(); |
| } |
| |
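| /** Polls the cluster until a job is reported as RUNNING and returns its job id. */ |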
| private JobID waitForJobToBeReady() throws InterruptedException, ExecutionException { |
| while (true) { |
| JobStatusMessage jobStatus = Iterables.getFirst(flinkCluster.listJobs().get(), null); |
| if (jobStatus != null && jobStatus.getJobState() == JobStatus.RUNNING) { |
| return jobStatus.getJobId(); |
| } |
| Thread.sleep(100); |
| } |
| } |
| |
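| /** Triggers a cancel-with-savepoint and returns the path of the resulting savepoint. */ |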
| private String takeSavepointAndCancelJob(JobID jobID) throws Exception { |
| Exception exception = null; |
| // try multiple times because the job might not be ready yet |
| for (int i = 0; i < 10; i++) { |
| try { |
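| // A null target directory falls back to the savepoint directory configured for the |
| // cluster; the boolean flag requests that the job be cancelled with the savepoint. |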
| return flinkCluster.triggerSavepoint(jobID, null, true).get(); |
| } catch (Exception e) { |
| exception = e; |
| Thread.sleep(100); |
| } |
| } |
| throw exception; |
| } |
| |
| private void restoreFromSavepointLegacy(Pipeline pipeline, String savepointDir) |
| throws ExecutionException, InterruptedException { |
| JobGraph jobGraph = getJobGraph(pipeline); |
| SavepointRestoreSettings savepointSettings = SavepointRestoreSettings.forPath(savepointDir); |
| jobGraph.setSavepointRestoreSettings(savepointSettings); |
| flinkCluster.submitJob(jobGraph).get(); |
| } |
| |
| private void restoreFromSavepointPortable(Pipeline pipeline, String savepointDir) |
| throws Exception { |
| FlinkPipelineOptions flinkOptions = pipeline.getOptions().as(FlinkPipelineOptions.class); |
| flinkOptions.setSavepointPath(savepointDir); |
| executePortable(pipeline); |
| } |
| |
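| /** Translates the pipeline into a Flink {@link JobGraph} via the legacy FlinkRunner. */ |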
| private JobGraph getJobGraph(Pipeline pipeline) { |
| FlinkRunner flinkRunner = FlinkRunner.fromOptions(pipeline.getOptions()); |
| return flinkRunner.getJobGraph(pipeline); |
| } |
| |
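| /** |
| * Builds the test pipeline. The portable variant generates an unbounded stream from an |
| * {@link Impulse} followed by a self-rearming processing-time timer, while the legacy variant |
| * uses {@link GenerateSequence}. If {@code restored} is set, the final stateful ParDo verifies |
| * the state written before the savepoint instead of writing it. |
| */ |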
| private static PCollection<String> createStreamingJob( |
| Pipeline pipeline, boolean restored, boolean isPortablePipeline) { |
| final PCollection<KV<String, Long>> key; |
| if (isPortablePipeline) { |
| key = |
| pipeline |
| .apply(Impulse.create()) |
| .apply( |
| MapElements.via( |
| new InferableFunction<byte[], KV<String, Void>>() { |
| @Override |
| public KV<String, Void> apply(byte[] input) throws Exception { |
| // Using a single key means data is only written to one of the two initial |
| // partitions; we want to cover that case because of |
| // https://jira.apache.org/jira/browse/BEAM-7144 |
| return KV.of("key", null); |
| } |
| })) |
| .apply( |
| ParDo.of( |
| new DoFn<KV<String, Void>, KV<String, Long>>() { |
| @StateId("nextInteger") |
| private final StateSpec<ValueState<Long>> valueStateSpec = |
| StateSpecs.value(); |
| |
| @TimerId("timer") |
| private final TimerSpec timer = |
| TimerSpecs.timer(TimeDomain.PROCESSING_TIME); |
| |
| @ProcessElement |
| public void processElement( |
| ProcessContext context, @TimerId("timer") Timer timer) { |
| |
| timer.offset(Duration.ZERO).setRelative(); |
| } |
| |
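| // Emit an increasing sequence by re-arming the processing-time timer on |
| // every firing, keeping the portable pipeline producing elements. |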
| @OnTimer("timer") |
| public void onTimer( |
| OnTimerContext context, |
| @StateId("nextInteger") ValueState<Long> nextInteger, |
| @TimerId("timer") Timer timer) { |
| Long current = nextInteger.read(); |
| if (current == null) { |
| current = -1L; |
| } |
| long next = current + 1; |
| nextInteger.write(next); |
| context.output(KV.of("key", next)); |
| timer.offset(Duration.millis(100)).setRelative(); |
| } |
| })); |
| } else { |
| key = |
| pipeline |
| .apply(GenerateSequence.from(0)) |
| .apply( |
| ParDo.of( |
| new DoFn<Long, KV<String, Long>>() { |
| @ProcessElement |
| public void processElement(ProcessContext context) { |
| context.output(KV.of("key", context.element())); |
| } |
| })); |
| } |
| if (restored) { |
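| // After restoring from the savepoint, the state written by the original |
| // pipeline (see the else branch below) must be visible again. |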
| return key.apply( |
| ParDo.of( |
| new DoFn<KV<String, Long>, String>() { |
| |
| @StateId("valueState") |
| private final StateSpec<ValueState<Integer>> valueStateSpec = StateSpecs.value(); |
| |
| @StateId("bagState") |
| private final StateSpec<BagState<Integer>> bagStateSpec = StateSpecs.bag(); |
| |
| @ProcessElement |
| public void processElement( |
| ProcessContext context, |
| @StateId("valueState") ValueState<Integer> intValueState, |
| @StateId("bagState") BagState<Integer> intBagState) { |
| assertThat(intValueState.read(), Matchers.is(42)); |
| assertThat(intBagState.read(), IsIterableContaining.hasItems(40, 1, 1)); |
| oneShotLatch.countDown(); |
| } |
| })); |
| } else { |
| return key.apply( |
| ParDo.of( |
| new DoFn<KV<String, Long>, String>() { |
| |
| @StateId("valueState") |
| private final StateSpec<ValueState<Integer>> valueStateSpec = StateSpecs.value(); |
| |
| @StateId("bagState") |
| private final StateSpec<BagState<Integer>> bagStateSpec = StateSpecs.bag(); |
| |
| @ProcessElement |
| public void processElement( |
| ProcessContext context, |
| @StateId("valueState") ValueState<Integer> intValueState, |
| @StateId("bagState") BagState<Integer> intBagState) { |
| Long value = Objects.requireNonNull(context.element().getValue()); |
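| // Write the state only once, for the first element, so the savepoint |
| // is guaranteed to capture exactly these values. |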
| if (value == 0L) { |
| intValueState.write(42); |
| intBagState.add(40); |
| intBagState.add(1); |
| intBagState.add(1); |
| oneShotLatch.countDown(); |
| } |
| } |
| })); |
| } |
| } |
| } |