Merge pull request #8720 from tvalentyn/cp_pr_8652

[BEAM-7203] Cherry-pick PR/8652 to the 2.13.0 release branch.
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java
index 0e8a649..b0a1063 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java
@@ -25,6 +25,7 @@
 import java.util.concurrent.Future;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline;
+import org.apache.beam.runners.core.construction.graph.ExecutableStage;
 import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
 import org.apache.beam.runners.core.construction.graph.PipelineTrimmer;
 import org.apache.beam.runners.core.metrics.MetricsPusher;
@@ -58,7 +59,14 @@
 
     // Don't let the fuser fuse any subcomponents of native transforms.
     Pipeline trimmedPipeline = PipelineTrimmer.trim(pipeline, translator.knownUrns());
-    Pipeline fusedPipeline = GreedyPipelineFuser.fuse(trimmedPipeline).toPipeline();
+
+    // Treat the pipeline as already fused if any transform is an executable stage, and skip the
+    // fuser in that case. TODO: Consider supporting partially-fused graphs.
+    boolean alreadyFused =
+        trimmedPipeline.getComponents().getTransformsMap().values().stream()
+            .anyMatch(proto -> ExecutableStage.URN.equals(proto.getSpec().getUrn()));
+    Pipeline fusedPipeline =
+        alreadyFused ? trimmedPipeline : GreedyPipelineFuser.fuse(trimmedPipeline).toPipeline();
 
     if (pipelineOptions.getFilesToStage() == null) {
       pipelineOptions.setFilesToStage(
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
index 8e7796f..0180496 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
@@ -64,6 +64,7 @@
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.broadcast.Broadcast;
+import org.apache.spark.storage.StorageLevel;
 import scala.Tuple2;
 
 /** Translates a bounded portable pipeline into a Spark job. */
@@ -112,6 +113,28 @@
         QueryablePipeline.forTransforms(
             pipeline.getRootTransformIdsList(), pipeline.getComponents());
-    for (PipelineNode.PTransformNode transformNode : p.getTopologicallyOrderedTransforms()) {
+    // Compute the topological order once; both passes below iterate over it.
+    Iterable<PipelineNode.PTransformNode> sortedTransforms = p.getTopologicallyOrderedTransforms();
+    for (PipelineNode.PTransformNode transformNode : sortedTransforms) {
+      // Pre-scan: count how often each PCollection is consumed as an input, so that RDDs read
+      // more than once can be cached when they are pushed to the context.
+      for (String inputId : transformNode.getTransform().getInputsMap().values()) {
+        context.incrementConsumptionCountBy(inputId, 1);
+      }
+      // An executable stage is translated into one intermediate RDD that is read once per output
+      // by an extraction function, so the intermediate RDD may need to be cached as well.
+      if (ExecutableStage.URN.equals(transformNode.getTransform().getSpec().getUrn())) {
+        context.incrementConsumptionCountBy(
+            getExecutableStageIntermediateId(transformNode),
+            transformNode.getTransform().getOutputsMap().size());
+      }
+      // Record each output's coder so the context can pass it to Dataset#cache.
+      for (String outputId : transformNode.getTransform().getOutputsMap().values()) {
+        WindowedValueCoder<?> outputCoder =
+            getWindowedValueCoder(outputId, pipeline.getComponents());
+        context.putCoder(outputId, outputCoder);
+      }
+    }
+    for (PipelineNode.PTransformNode transformNode : sortedTransforms) {
       urnToTransformTranslator
           .getOrDefault(
               transformNode.getTransform().getSpec().getUrn(),
@@ -141,18 +164,9 @@
 
     RunnerApi.Components components = pipeline.getComponents();
     String inputId = getInputId(transformNode);
-    PCollection inputPCollection = components.getPcollectionsOrThrow(inputId);
     Dataset inputDataset = context.popDataset(inputId);
     JavaRDD<WindowedValue<KV<K, V>>> inputRdd = ((BoundedDataset<KV<K, V>>) inputDataset).getRDD();
-    PCollectionNode inputPCollectionNode = PipelineNode.pCollection(inputId, inputPCollection);
-    WindowedValueCoder<KV<K, V>> inputCoder;
-    try {
-      inputCoder =
-          (WindowedValueCoder)
-              WireCoders.instantiateRunnerWireCoder(inputPCollectionNode, components);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
-    }
+    WindowedValueCoder<KV<K, V>> inputCoder = getWindowedValueCoder(inputId, components);
     KvCoder<K, V> inputKvCoder = (KvCoder<K, V>) inputCoder.getValueCoder();
     Coder<K> inputKeyCoder = inputKvCoder.getKeyCoder();
     Coder<V> inputValueCoder = inputKvCoder.getValueCoder();
@@ -200,18 +214,18 @@
     Dataset inputDataset = context.popDataset(inputPCollectionId);
     JavaRDD<WindowedValue<InputT>> inputRdd = ((BoundedDataset<InputT>) inputDataset).getRDD();
     Map<String, String> outputs = transformNode.getTransform().getOutputsMap();
-    BiMap<String, Integer> outputMap = createOutputMap(outputs.values());
+    BiMap<String, Integer> outputExtractionMap = createOutputMap(outputs.values());
 
     ImmutableMap.Builder<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
         broadcastVariablesBuilder = ImmutableMap.builder();
     for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
-      RunnerApi.Components components = stagePayload.getComponents();
+      RunnerApi.Components stagePayloadComponents = stagePayload.getComponents();
       String collectionId =
-          components
+          stagePayloadComponents
               .getTransformsOrThrow(sideInputId.getTransformId())
               .getInputsOrThrow(sideInputId.getLocalName());
       Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>> tuple2 =
-          broadcastSideInput(collectionId, components, context);
+          broadcastSideInput(collectionId, stagePayloadComponents, context);
       broadcastVariablesBuilder.put(collectionId, tuple2);
     }
 
@@ -219,14 +233,42 @@
         new SparkExecutableStageFunction<>(
             stagePayload,
             context.jobInfo,
-            outputMap,
+            outputExtractionMap,
             broadcastVariablesBuilder.build(),
             MetricsAccumulator.getInstance());
     JavaRDD<RawUnionValue> staged = inputRdd.mapPartitions(function);
+    // Register the staged RDD so that pushDataset can cache it when the stage has several
+    // outputs (and caching is enabled): each output extraction below reads it once.
+    String intermediateId = getExecutableStageIntermediateId(transformNode);
+    context.pushDataset(
+        intermediateId,
+        new Dataset() {
+          @Override
+          public void cache(String storageLevel, Coder<?> coder) {
+            // The coder is not used; the RDD is persisted at the parsed Spark storage level.
+            StorageLevel level = StorageLevel.fromString(storageLevel);
+            staged.persist(level);
+          }
+
+          @Override
+          public void action() {
+            // Empty function to force computation of RDD.
+            staged.foreach(TranslationUtils.emptyVoidFunction());
+          }
+
+          @Override
+          public void setName(String name) {
+            staged.setName(name);
+          }
+        });
+    // Pop the intermediate dataset immediately to mark the RDD as used: it is consumed by the
+    // output extractions created below rather than by another transform.
+    context.popDataset(intermediateId);
 
     for (String outputId : outputs.values()) {
       JavaRDD<WindowedValue<OutputT>> outputRdd =
-          staged.flatMap(new SparkExecutableStageExtractionFunction<>(outputMap.get(outputId)));
+          staged.flatMap(
+              new SparkExecutableStageExtractionFunction<>(outputExtractionMap.get(outputId)));
       context.pushDataset(outputId, new BoundedDataset<>(outputRdd));
     }
     if (outputs.isEmpty()) {
@@ -249,17 +291,9 @@
    */
   private static <T> Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<T>> broadcastSideInput(
       String collectionId, RunnerApi.Components components, SparkTranslationContext context) {
-    PCollection collection = components.getPcollectionsOrThrow(collectionId);
     @SuppressWarnings("unchecked")
     BoundedDataset<T> dataset = (BoundedDataset<T>) context.popDataset(collectionId);
-    PCollectionNode collectionNode = PipelineNode.pCollection(collectionId, collection);
-    WindowedValueCoder<T> coder;
-    try {
-      coder =
-          (WindowedValueCoder<T>) WireCoders.instantiateRunnerWireCoder(collectionNode, components);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
-    }
+    WindowedValueCoder<T> coder = getWindowedValueCoder(collectionId, components);
     List<byte[]> bytes = dataset.getBytes(coder);
     Broadcast<List<byte[]>> broadcast = context.getSparkContext().broadcast(bytes);
     return new Tuple2<>(broadcast, coder);
@@ -324,4 +358,33 @@
   private static String getOutputId(PTransformNode transformNode) {
     return Iterables.getOnlyElement(transformNode.getTransform().getOutputsMap().values());
   }
+
+  /**
+   * Instantiates the runner wire coder of the given PCollection as a {@link WindowedValueCoder}.
+   *
+   * @throws RuntimeException wrapping the {@link IOException} if the coder cannot be instantiated
+   */
+  @SuppressWarnings("unchecked")
+  private static <T> WindowedValueCoder<T> getWindowedValueCoder(
+      String pCollectionId, RunnerApi.Components components) {
+    PCollection pCollection = components.getPcollectionsOrThrow(pCollectionId);
+    PCollectionNode pCollectionNode = PipelineNode.pCollection(pCollectionId, pCollection);
+    try {
+      return (WindowedValueCoder<T>)
+          WireCoders.instantiateRunnerWireCoder(pCollectionNode, components);
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * Returns the key under which an executable stage's intermediate (pre-extraction) RDD is
+   * registered in the translation context; the stage transform's own ID is reused.
+   *
+   * <p>NOTE(review): transform IDs and PCollection IDs share the context's dataset keys here, so
+   * this assumes they never collide. Confirm.
+   */
+  private static String getExecutableStageIntermediateId(PTransformNode transformNode) {
+    return transformNode.getId();
+  }
 }
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkTranslationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkTranslationContext.java
index 772e0d2..8c2cee8 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkTranslationContext.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkTranslationContext.java
@@ -17,12 +17,16 @@
  */
 package org.apache.beam.runners.spark.translation;
 
+import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.LinkedHashSet;
 import java.util.Map;
 import java.util.Set;
+import javax.annotation.Nullable;
 import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
 import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
+import org.apache.beam.runners.spark.SparkPipelineOptions;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.spark.api.java.JavaSparkContext;
 
@@ -33,6 +37,10 @@
 public class SparkTranslationContext {
   private final JavaSparkContext jsc;
   final JobInfo jobInfo;
+  // Map pCollection IDs to the number of times they are consumed as inputs.
+  private final Map<String, Integer> consumptionCount = new HashMap<>();
+  // Map pCollection IDs to the coders recorded via putCoder; handed to Dataset#cache.
+  private final Map<String, Coder<?>> coderMap = new HashMap<>();
   private final Map<String, Dataset> datasets = new LinkedHashMap<>();
   private final Set<Dataset> leaves = new LinkedHashSet<>();
   final SerializablePipelineOptions serializablePipelineOptions;
@@ -51,7 +59,16 @@
   /** Add output of transform to context. */
   public void pushDataset(String pCollectionId, Dataset dataset) {
     dataset.setName(pCollectionId);
-    // TODO cache
+    // Cache datasets consumed more than once (per incrementConsumptionCountBy) so they are not
+    // recomputed for each consumer, unless caching is disabled in the Spark options.
+    SparkPipelineOptions sparkOptions =
+        serializablePipelineOptions.get().as(SparkPipelineOptions.class);
+    if (!sparkOptions.isCacheDisabled() && consumptionCount.getOrDefault(pCollectionId, 0) > 1) {
+      String storageLevel = sparkOptions.getStorageLevel();
+      // May be null when no coder was recorded for this ID via putCoder.
+      @Nullable Coder<?> coder = coderMap.get(pCollectionId);
+      dataset.cache(storageLevel, coder);
+    }
     datasets.put(pCollectionId, dataset);
     leaves.add(dataset);
   }
@@ -70,6 +87,16 @@
     }
   }
 
+  /** Adds {@code addend} to the number of times {@code pCollectionId} is consumed as an input. */
+  void incrementConsumptionCountBy(String pCollectionId, int addend) {
+    consumptionCount.merge(pCollectionId, addend, Integer::sum);
+  }
+
+  /** Records the coder of {@code pCollectionId}; it is passed to the dataset when cached. */
+  void putCoder(String pCollectionId, Coder<?> coder) {
+    coderMap.put(pCollectionId, coder);
+  }
+
   /** Generate a unique pCollection id number to identify runner-generated sinks. */
   public int nextSinkId() {
     return sinkId++;
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java
index d7d3428..eb9dce0 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java
@@ -17,9 +17,11 @@
  */
 package org.apache.beam.runners.spark;
 
+import java.io.File;
 import java.io.Serializable;
+import java.nio.file.FileSystems;
 import java.util.Collections;
-import java.util.concurrent.Executors;
+import java.util.UUID;
 import java.util.concurrent.TimeUnit;
 import org.apache.beam.model.jobmanagement.v1.JobApi.JobState.Enum;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
@@ -46,8 +48,11 @@
 import org.apache.beam.vendor.guava.v20_0.com.google.common.util.concurrent.ListeningExecutorService;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.util.concurrent.MoreExecutors;
 import org.junit.AfterClass;
+import org.junit.Assert;
 import org.junit.BeforeClass;
+import org.junit.ClassRule;
 import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -56,15 +61,16 @@
  */
 public class SparkPortableExecutionTest implements Serializable {
 
+  @ClassRule public static TemporaryFolder temporaryFolder = new TemporaryFolder();
   private static final Logger LOG = LoggerFactory.getLogger(SparkPortableExecutionTest.class);
-
   private static ListeningExecutorService sparkJobExecutor;
 
   @BeforeClass
   public static void setup() {
-    // Restrict this to only one thread to avoid multiple Spark clusters up at the same time
-    // which is not suitable for memory-constraint environments, i.e. Jenkins.
-    sparkJobExecutor = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1));
+    // Run jobs on the calling thread: the tests check the job state right after
+    // jobInvocation.start() returns, and only one Spark cluster is ever up at a time, which
+    // matters in memory-constrained environments such as Jenkins.
+    sparkJobExecutor = MoreExecutors.newDirectExecutorService();
   }
 
   @AfterClass
@@ -159,8 +165,117 @@
             pipelineProto,
             options.as(SparkPipelineOptions.class));
     jobInvocation.start();
-    while (jobInvocation.getState() != Enum.DONE) {
-      Thread.sleep(1000);
+    Assert.assertEquals(Enum.DONE, jobInvocation.getState());
+  }
+
+  /**
+   * Verifies that each executable stage runs exactly once, even if that executable stage has
+   * multiple immediate outputs. While re-computation may be necessary in the event of failure,
+   * re-computation of a whole executable stage is expensive and can cause unexpected behavior when
+   * the executable stage has side effects (BEAM-7131).
+   *
+   * <pre>
+   *    |-> B -> GBK
+   * A -|
+   *    |-> C -> GBK
+   * </pre>
+   */
+  @Test(timeout = 120_000)
+  public void testExecStageWithMultipleOutputs() throws Exception {
+    PipelineOptions options = PipelineOptionsFactory.create();
+    options.setRunner(CrashingRunner.class);
+    options
+        .as(PortablePipelineOptions.class)
+        .setDefaultEnvironmentType(Environments.ENVIRONMENT_EMBEDDED);
+    Pipeline pipeline = Pipeline.create(options);
+    PCollection<KV<String, String>> a =
+        pipeline
+            .apply("impulse", Impulse.create())
+            .apply("A", ParDo.of(new DoFnWithSideEffect<>("A")));
+    PCollection<KV<String, String>> b = a.apply("B", ParDo.of(new DoFnWithSideEffect<>("B")));
+    PCollection<KV<String, String>> c = a.apply("C", ParDo.of(new DoFnWithSideEffect<>("C")));
+    // Use GBKs to force re-computation of executable stage unless cached.
+    b.apply(GroupByKey.create());
+    c.apply(GroupByKey.create());
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(pipeline);
+    JobInvocation jobInvocation =
+        SparkJobInvoker.createJobInvocation(
+            "testExecStageWithMultipleOutputs",
+            "testExecStageWithMultipleOutputsRetrievalToken",
+            sparkJobExecutor,
+            pipelineProto,
+            options.as(SparkPipelineOptions.class));
+    jobInvocation.start();
+    Assert.assertEquals(Enum.DONE, jobInvocation.getState());
+  }
+
+  /**
+   * Verifies that each executable stage runs exactly once, even if that executable stage has
+   * multiple downstream consumers. While re-computation may be necessary in the event of failure,
+   * re-computation of a whole executable stage is expensive and can cause unexpected behavior when
+   * the executable stage has side effects (BEAM-7131).
+   *
+   * <pre>
+   *           |-> G
+   * F -> GBK -|
+   *           |-> H
+   * </pre>
+   */
+  @Test(timeout = 120_000)
+  public void testExecStageWithMultipleConsumers() throws Exception {
+    PipelineOptions options = PipelineOptionsFactory.create();
+    options.setRunner(CrashingRunner.class);
+    options
+        .as(PortablePipelineOptions.class)
+        .setDefaultEnvironmentType(Environments.ENVIRONMENT_EMBEDDED);
+    Pipeline pipeline = Pipeline.create(options);
+    PCollection<KV<String, Iterable<String>>> f =
+        pipeline
+            .apply("impulse", Impulse.create())
+            .apply("F", ParDo.of(new DoFnWithSideEffect<>("F")))
+            // use GBK to prevent fusion of F, G, and H
+            .apply(GroupByKey.create());
+    f.apply("G", ParDo.of(new DoFnWithSideEffect<>("G")));
+    f.apply("H", ParDo.of(new DoFnWithSideEffect<>("H")));
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(pipeline);
+    JobInvocation jobInvocation =
+        SparkJobInvoker.createJobInvocation(
+            "testExecStageWithMultipleConsumers",
+            "testExecStageWithMultipleConsumersRetrievalToken",
+            sparkJobExecutor,
+            pipelineProto,
+            options.as(SparkPipelineOptions.class));
+    jobInvocation.start();
+    Assert.assertEquals(Enum.DONE, jobInvocation.getState());
+  }
+
+  /** Non-idempotent DoFn: a second {@code process} call fails as its marker file exists. */
+  private static class DoFnWithSideEffect<InputT> extends DoFn<InputT, KV<String, String>> {
+
+    private final String name;
+    private final File file;
+
+    DoFnWithSideEffect(String name) {
+      this.name = name;
+      String path =
+          FileSystems.getDefault()
+              .getPath(
+                  temporaryFolder.getRoot().getAbsolutePath(),
+                  String.format("%s-%s", this.name, UUID.randomUUID().toString()))
+              .toString();
+      file = new File(path);
+    }
+
+    @ProcessElement
+    public void process(ProcessContext context) throws Exception {
+      context.output(KV.of(name, name));
+      // Verify this DoFn has not run more than once by enacting a side effect via the local file
+      // system.
+      Assert.assertTrue(
+          String.format(
+              "Create file %s failed (DoFn %s should only have been run once).",
+              file.getAbsolutePath(), name),
+          file.createNewFile());
     }
   }
 }