Merge pull request #10043 [BEAM-8597] Allow TestStream trigger tests to run on other runners.

diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
index 3ea3643..557f45f 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
@@ -1905,6 +1905,7 @@
           mustRunAfter = [
             ':runners:flink:1.9:job-server-container:docker',
             ':runners:flink:1.9:job-server:shadowJar',
+            ':runners:spark:job-server:shadowJar',
             ':sdks:python:container:py2:docker',
             ':sdks:python:container:py35:docker',
             ':sdks:python:container:py36:docker',
@@ -1958,6 +1959,7 @@
         addPortableWordCountTask(true, "PortableRunner")
         addPortableWordCountTask(false, "FlinkRunner")
         addPortableWordCountTask(true, "FlinkRunner")
+        addPortableWordCountTask(false, "SparkRunner")
       }
     }
   }
diff --git a/runners/flink/flink_runner.gradle b/runners/flink/flink_runner.gradle
index 3254a85..6281b94 100644
--- a/runners/flink/flink_runner.gradle
+++ b/runners/flink/flink_runner.gradle
@@ -200,6 +200,7 @@
       excludeCategories 'org.apache.beam.sdk.testing.UsesCommittedMetrics'
       if (config.streaming) {
         excludeCategories 'org.apache.beam.sdk.testing.UsesImpulse'
+        excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithMultipleStages'  // BEAM-8598
         excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithProcessingTime'
       } else {
         excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo'
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/FnApiWindowMappingFn.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/FnApiWindowMappingFn.java
index 7e298e7..dcf15f4 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/FnApiWindowMappingFn.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/FnApiWindowMappingFn.java
@@ -56,7 +56,6 @@
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.WindowingStrategy;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.Cache;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheBuilder;
 import org.slf4j.Logger;
@@ -253,7 +252,7 @@
       }
 
       // Check to see if processing the request failed.
-      throwIfFailure(processResponse);
+      MoreFutures.get(processResponse);
 
       waitForInboundTermination.awaitCompletion();
       WindowedValue<KV<byte[], TargetWindowT>> sideInputWindow = outputValue.poll();
@@ -300,22 +299,10 @@
                               processBundleDescriptor.toBuilder().setId(descriptorId).build())
                           .build())
                   .build());
-      throwIfFailure(response);
+      // Check if the bundle descriptor is registered successfully.
+      MoreFutures.get(response);
       processBundleDescriptorId = descriptorId;
     }
     return processBundleDescriptorId;
   }
-
-  private static InstructionResponse throwIfFailure(
-      CompletionStage<InstructionResponse> responseFuture)
-      throws ExecutionException, InterruptedException {
-    InstructionResponse response = MoreFutures.get(responseFuture);
-    if (!Strings.isNullOrEmpty(response.getError())) {
-      throw new IllegalStateException(
-          String.format(
-              "Client failed to process %s with error [%s].",
-              response.getInstructionId(), response.getError()));
-    }
-    return response;
-  }
 }
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperation.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperation.java
index 0a3346c..bf42c4d 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperation.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperation.java
@@ -370,12 +370,8 @@
    * elements consumed from the upstream read operation.
    *
    * <p>May be called at any time, including before start() and after finish().
-   *
-   * @throws InterruptedException
-   * @throws ExecutionException
    */
-  public CompletionStage<BeamFnApi.ProcessBundleProgressResponse> getProcessBundleProgress()
-      throws InterruptedException, ExecutionException {
+  public CompletionStage<BeamFnApi.ProcessBundleProgressResponse> getProcessBundleProgress() {
     // processBundleId may be reset if this bundle finishes asynchronously.
     String processBundleId = this.processBundleId;
 
@@ -393,13 +389,7 @@
 
     return instructionRequestHandler
         .handle(processBundleRequest)
-        .thenApply(
-            response -> {
-              if (!response.getError().isEmpty()) {
-                throw new IllegalStateException(response.getError());
-              }
-              return response.getProcessBundleProgress();
-            });
+        .thenApply(InstructionResponse::getProcessBundleProgress);
   }
 
   /** Returns the final metrics returned by the SDK harness when it completes the bundle. */
@@ -636,53 +626,36 @@
     return true;
   }
 
-  private static CompletionStage<BeamFnApi.InstructionResponse> throwIfFailure(
+  private static CompletionStage<BeamFnApi.ProcessBundleResponse> getProcessBundleResponse(
       CompletionStage<InstructionResponse> responseFuture) {
     return responseFuture.thenApply(
         response -> {
-          if (!response.getError().isEmpty()) {
-            throw new IllegalStateException(
-                String.format(
-                    "Client failed to process %s with error [%s].",
-                    response.getInstructionId(), response.getError()));
+          switch (response.getResponseCase()) {
+            case PROCESS_BUNDLE:
+              return response.getProcessBundle();
+            default:
+              throw new IllegalStateException(
+                  String.format(
+                      "SDK harness returned wrong kind of response to ProcessBundleRequest: %s",
+                      TextFormat.printToString(response)));
           }
-          return response;
         });
   }
 
-  private static CompletionStage<BeamFnApi.ProcessBundleResponse> getProcessBundleResponse(
-      CompletionStage<InstructionResponse> responseFuture) {
-    return throwIfFailure(responseFuture)
-        .thenApply(
-            response -> {
-              switch (response.getResponseCase()) {
-                case PROCESS_BUNDLE:
-                  return response.getProcessBundle();
-                default:
-                  throw new IllegalStateException(
-                      String.format(
-                          "SDK harness returned wrong kind of response to ProcessBundleRequest: %s",
-                          TextFormat.printToString(response)));
-              }
-            });
-  }
-
   private static CompletionStage<BeamFnApi.RegisterResponse> getRegisterResponse(
-      CompletionStage<InstructionResponse> responseFuture)
-      throws ExecutionException, InterruptedException {
-    return throwIfFailure(responseFuture)
-        .thenApply(
-            response -> {
-              switch (response.getResponseCase()) {
-                case REGISTER:
-                  return response.getRegister();
-                default:
-                  throw new IllegalStateException(
-                      String.format(
-                          "SDK harness returned wrong kind of response to RegisterRequest: %s",
-                          TextFormat.printToString(response)));
-              }
-            });
+      CompletionStage<InstructionResponse> responseFuture) {
+    return responseFuture.thenApply(
+        response -> {
+          switch (response.getResponseCase()) {
+            case REGISTER:
+              return response.getRegister();
+            default:
+              throw new IllegalStateException(
+                  String.format(
+                      "SDK harness returned wrong kind of response to RegisterRequest: %s",
+                      TextFormat.printToString(response)));
+          }
+        });
   }
 
   private static void cancelIfNotNull(CompletionStage<?> future) {
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleCheckpointHandler.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleCheckpointHandler.java
new file mode 100644
index 0000000..1e5fa53
--- /dev/null
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleCheckpointHandler.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.control;
+
+import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+
+/**
+ * A handler which is invoked when the SDK returns {@link BeamFnApi.DelayedBundleApplication}s as
+ * part of the bundle completion.
+ *
+ * <p>These bundle applications must be resumed; otherwise data loss will occur.
+ *
+ * <p>See <a href="https://s.apache.org/beam-breaking-fusion">breaking the fusion barrier</a> for
+ * further details.
+ */
+public interface BundleCheckpointHandler {
+  void onCheckpoint(BeamFnApi.ProcessBundleResponse response);
+}
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleFinalizationHandler.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleFinalizationHandler.java
new file mode 100644
index 0000000..849663b
--- /dev/null
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleFinalizationHandler.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.control;
+
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
+
+/**
+ * A handler for the runner when a finalization request has been received.
+ *
+ * <p>The runner is responsible for finalizing the bundle when all output from the bundle has been
+ * durably persisted.
+ *
+ * <p>See <a href="https://s.apache.org/beam-finalizing-bundles">finalizing bundles</a> for further
+ * details.
+ */
+public interface BundleFinalizationHandler {
+  void requestsFinalization(ProcessBundleResponse response);
+}
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/InstructionRequestHandler.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/InstructionRequestHandler.java
index b655732..8a9dc75 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/InstructionRequestHandler.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/InstructionRequestHandler.java
@@ -20,7 +20,10 @@
 import java.util.concurrent.CompletionStage;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi;
 
-/** Interface for any function that can handle a Fn API {@link BeamFnApi.InstructionRequest}. */
+/**
+ * Interface for any function that can handle a Fn API {@link BeamFnApi.InstructionRequest}. Any
+ * error responses will be converted to exceptionally completed futures.
+ */
 public interface InstructionRequestHandler extends AutoCloseable {
   CompletionStage<BeamFnApi.InstructionResponse> handle(BeamFnApi.InstructionRequest request);
 }
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java
index 1ee5184..86232dd 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java
@@ -95,6 +95,9 @@
      *   // send all main input elements ...
      * }
      * }</pre>
+     *
+     * <p>An exception during {@link #close()} will be thrown if the bundle requests finalization or
+     * attempts to checkpoint by returning a {@link BeamFnApi.DelayedBundleApplication}.
      */
     public ActiveBundle newBundle(
         Map<String, RemoteOutputReceiver<?>> outputReceivers,
@@ -122,6 +125,47 @@
      * try (ActiveBundle bundle = SdkHarnessClient.newBundle(...)) {
      *   FnDataReceiver<InputT> inputReceiver =
      *       (FnDataReceiver) bundle.getInputReceivers().get(mainPCollectionId);
+     *   // send all main input elements ...
+     * }
+     * }</pre>
+     *
+     * <p>An exception during {@link #close()} will be thrown if the bundle requests finalization or
+     * attempts to checkpoint by returning a {@link BeamFnApi.DelayedBundleApplication}.
+     */
+    public ActiveBundle newBundle(
+        Map<String, RemoteOutputReceiver<?>> outputReceivers,
+        StateRequestHandler stateRequestHandler,
+        BundleProgressHandler progressHandler) {
+      return newBundle(
+          outputReceivers,
+          stateRequestHandler,
+          progressHandler,
+          request -> {
+            throw new UnsupportedOperationException(
+                String.format(
+                    "The %s does not have a registered bundle checkpoint handler.",
+                    ActiveBundle.class.getSimpleName()));
+          },
+          request -> {
+            throw new UnsupportedOperationException(
+                String.format(
+                    "The %s does not have a registered bundle finalization handler.",
+                    ActiveBundle.class.getSimpleName()));
+          });
+    }
+
+    /**
+     * Start a new bundle for the given {@link BeamFnApi.ProcessBundleDescriptor} identifier.
+     *
+     * <p>The input channels for the returned {@link ActiveBundle} are derived from the instructions
+     * in the {@link BeamFnApi.ProcessBundleDescriptor}.
+     *
+     * <p>NOTE: It is important to {@link #close()} each bundle after all elements are emitted.
+     *
+     * <pre>{@code
+     * try (ActiveBundle bundle = SdkHarnessClient.newBundle(...)) {
+     *   FnDataReceiver<InputT> inputReceiver =
+     *       (FnDataReceiver) bundle.getInputReceivers().get(mainPCollectionId);
      *   // send all elements ...
      * }
      * }</pre>
@@ -129,7 +173,9 @@
     public ActiveBundle newBundle(
         Map<String, RemoteOutputReceiver<?>> outputReceivers,
         StateRequestHandler stateRequestHandler,
-        BundleProgressHandler progressHandler) {
+        BundleProgressHandler progressHandler,
+        BundleCheckpointHandler checkpointHandler,
+        BundleFinalizationHandler finalizationHandler) {
       String bundleId = idGenerator.getId();
 
       final CompletionStage<BeamFnApi.InstructionResponse> genericResponse =
@@ -175,7 +221,9 @@
           dataReceiversBuilder.build(),
           outputClients,
           stateDelegator.registerForProcessBundleInstructionId(bundleId, stateRequestHandler),
-          progressHandler);
+          progressHandler,
+          checkpointHandler,
+          finalizationHandler);
     }
 
     private <OutputT> InboundDataClient attachReceiver(
@@ -193,6 +241,8 @@
     private final Map<String, InboundDataClient> outputClients;
     private final StateDelegator.Registration stateRegistration;
     private final BundleProgressHandler progressHandler;
+    private final BundleCheckpointHandler checkpointHandler;
+    private final BundleFinalizationHandler finalizationHandler;
 
     private ActiveBundle(
         String bundleId,
@@ -200,13 +250,17 @@
         Map<String, CloseableFnDataReceiver> inputReceivers,
         Map<String, InboundDataClient> outputClients,
         StateDelegator.Registration stateRegistration,
-        BundleProgressHandler progressHandler) {
+        BundleProgressHandler progressHandler,
+        BundleCheckpointHandler checkpointHandler,
+        BundleFinalizationHandler finalizationHandler) {
       this.bundleId = bundleId;
       this.response = response;
       this.inputReceivers = inputReceivers;
       this.outputClients = outputClients;
       this.stateRegistration = stateRegistration;
       this.progressHandler = progressHandler;
+      this.checkpointHandler = checkpointHandler;
+      this.finalizationHandler = finalizationHandler;
     }
 
     /** Returns an id used to represent this bundle. */
@@ -256,13 +310,15 @@
           BeamFnApi.ProcessBundleResponse completedResponse = MoreFutures.get(response);
           progressHandler.onCompleted(completedResponse);
           if (completedResponse.getResidualRootsCount() > 0) {
-            throw new IllegalStateException(
-                "TODO: [BEAM-2939] residual roots in process bundle response not yet supported.");
+            checkpointHandler.onCheckpoint(completedResponse);
+          }
+          if (completedResponse.getRequiresFinalization()) {
+            finalizationHandler.requestsFinalization(completedResponse);
           }
         } else {
           // TODO: [BEAM-3962] Handle aborting the bundle being processed.
           throw new IllegalStateException(
-              "Processing bundle failed, " + "TODO: [BEAM-3962] abort bundle.");
+              "Processing bundle failed, TODO: [BEAM-3962] abort bundle.");
         }
       } catch (Exception e) {
         if (exception == null) {
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientTest.java
index 6a28735..089c8d1 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientTest.java
@@ -31,6 +31,7 @@
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.verifyZeroInteractions;
 import static org.mockito.Mockito.when;
 
 import java.util.ArrayList;
@@ -41,6 +42,7 @@
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
@@ -623,6 +625,110 @@
     assertThat(requests.get(1).getProcessBundle().getCacheTokensList(), is(cacheTokens));
   }
 
+  @Test
+  public void testBundleCheckpointCallback() throws Exception {
+    Exception testException = new Exception();
+
+    InboundDataClient mockOutputReceiver = mock(InboundDataClient.class);
+    CloseableFnDataReceiver mockInputSender = mock(CloseableFnDataReceiver.class);
+
+    CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
+    when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
+        .thenReturn(createRegisterResponse())
+        .thenReturn(processBundleResponseFuture);
+
+    FullWindowedValueCoder<String> coder =
+        FullWindowedValueCoder.of(StringUtf8Coder.of(), Coder.INSTANCE);
+    BundleProcessor processor =
+        sdkHarnessClient.getProcessor(
+            descriptor,
+            Collections.singletonMap(
+                "inputPC",
+                RemoteInputDestination.of(
+                    (FullWindowedValueCoder) coder, SDK_GRPC_READ_TRANSFORM)));
+    when(dataService.receive(any(), any(), any())).thenReturn(mockOutputReceiver);
+    when(dataService.send(any(), eq(coder))).thenReturn(mockInputSender);
+
+    RemoteOutputReceiver mockRemoteOutputReceiver = mock(RemoteOutputReceiver.class);
+    BundleProgressHandler mockProgressHandler = mock(BundleProgressHandler.class);
+    BundleCheckpointHandler mockCheckpointHandler = mock(BundleCheckpointHandler.class);
+    BundleFinalizationHandler mockFinalizationHandler = mock(BundleFinalizationHandler.class);
+
+    ProcessBundleResponse response =
+        ProcessBundleResponse.newBuilder()
+            .addResidualRoots(DelayedBundleApplication.getDefaultInstance())
+            .build();
+    ArrayList<ProcessBundleResponse> checkpoints = new ArrayList<>();
+
+    try (ActiveBundle activeBundle =
+        processor.newBundle(
+            ImmutableMap.of(SDK_GRPC_WRITE_TRANSFORM, mockRemoteOutputReceiver),
+            (request) -> {
+              throw new UnsupportedOperationException();
+            },
+            mockProgressHandler,
+            mockCheckpointHandler,
+            mockFinalizationHandler)) {
+      processBundleResponseFuture.complete(
+          InstructionResponse.newBuilder().setProcessBundle(response).build());
+    }
+
+    verify(mockProgressHandler).onCompleted(response);
+    verify(mockCheckpointHandler).onCheckpoint(response);
+    verifyZeroInteractions(mockFinalizationHandler);
+  }
+
+  @Test
+  public void testBundleFinalizationCallback() throws Exception {
+    Exception testException = new Exception();
+
+    InboundDataClient mockOutputReceiver = mock(InboundDataClient.class);
+    CloseableFnDataReceiver mockInputSender = mock(CloseableFnDataReceiver.class);
+
+    CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
+    when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
+        .thenReturn(createRegisterResponse())
+        .thenReturn(processBundleResponseFuture);
+
+    FullWindowedValueCoder<String> coder =
+        FullWindowedValueCoder.of(StringUtf8Coder.of(), Coder.INSTANCE);
+    BundleProcessor processor =
+        sdkHarnessClient.getProcessor(
+            descriptor,
+            Collections.singletonMap(
+                "inputPC",
+                RemoteInputDestination.of(
+                    (FullWindowedValueCoder) coder, SDK_GRPC_READ_TRANSFORM)));
+    when(dataService.receive(any(), any(), any())).thenReturn(mockOutputReceiver);
+    when(dataService.send(any(), eq(coder))).thenReturn(mockInputSender);
+
+    RemoteOutputReceiver mockRemoteOutputReceiver = mock(RemoteOutputReceiver.class);
+    BundleProgressHandler mockProgressHandler = mock(BundleProgressHandler.class);
+    BundleCheckpointHandler mockCheckpointHandler = mock(BundleCheckpointHandler.class);
+    BundleFinalizationHandler mockFinalizationHandler = mock(BundleFinalizationHandler.class);
+
+    ProcessBundleResponse response =
+        ProcessBundleResponse.newBuilder().setRequiresFinalization(true).build();
+    ArrayList<ProcessBundleResponse> checkpoints = new ArrayList<>();
+
+    try (ActiveBundle activeBundle =
+        processor.newBundle(
+            ImmutableMap.of(SDK_GRPC_WRITE_TRANSFORM, mockRemoteOutputReceiver),
+            (request) -> {
+              throw new UnsupportedOperationException();
+            },
+            mockProgressHandler,
+            mockCheckpointHandler,
+            mockFinalizationHandler)) {
+      processBundleResponseFuture.complete(
+          InstructionResponse.newBuilder().setProcessBundle(response).build());
+    }
+
+    verify(mockProgressHandler).onCompleted(response);
+    verify(mockFinalizationHandler).requestsFinalization(response);
+    verifyZeroInteractions(mockCheckpointHandler);
+  }
+
   private CompletableFuture<InstructionResponse> createRegisterResponse() {
     return CompletableFuture.completedFuture(
         InstructionResponse.newBuilder()
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesTestStreamWithMultipleStages.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesTestStreamWithMultipleStages.java
new file mode 100644
index 0000000..55999ce
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesTestStreamWithMultipleStages.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.testing;
+
+/**
+ * Subcategory for {@link UsesTestStream} tests which use {@link TestStream} across multiple
+ * stages. Some runners do not properly support quiescence in the way that {@link TestStream}
+ * demands.
+ */
+public interface UsesTestStreamWithMultipleStages extends UsesTestStream {}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java
index 5e4cdcb..e48b6b2 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.sdk.testing;
 
+import static org.apache.beam.sdk.transforms.windowing.Window.into;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.allOf;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -28,12 +29,22 @@
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.state.TimeDomain;
+import org.apache.beam.sdk.state.Timer;
+import org.apache.beam.sdk.state.TimerSpec;
+import org.apache.beam.sdk.state.TimerSpecs;
+import org.apache.beam.sdk.state.ValueState;
 import org.apache.beam.sdk.testing.TestStream.Builder;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.Keys;
 import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.transforms.Values;
 import org.apache.beam.sdk.transforms.WithKeys;
@@ -44,6 +55,7 @@
 import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
 import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
 import org.apache.beam.sdk.transforms.windowing.Never;
 import org.apache.beam.sdk.transforms.windowing.Window;
@@ -51,8 +63,10 @@
 import org.apache.beam.sdk.util.CoderUtils;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionList;
 import org.apache.beam.sdk.values.TimestampedValue;
 import org.apache.beam.sdk.values.TypeDescriptors;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
 import org.junit.Rule;
@@ -263,7 +277,7 @@
     FixedWindows windows = FixedWindows.of(Duration.standardHours(6));
     PCollection<String> windowedValues =
         p.apply(stream)
-            .apply(Window.into(windows))
+            .apply(into(windows))
             .apply(WithKeys.of(1))
             .apply(GroupByKey.create())
             .apply(Values.create())
@@ -387,6 +401,74 @@
   }
 
   @Test
+  @Category({ValidatesRunner.class, UsesTestStream.class, UsesTestStreamWithMultipleStages.class})
+  public void testMultiStage() throws Exception {
+    TestStream<String> testStream =
+        TestStream.create(StringUtf8Coder.of())
+            .addElements("before") // yields "before"
+            .advanceWatermarkTo(Instant.ofEpochSecond(0)) // expected to yield "BEFORE"
+            .addElements(TimestampedValue.of("after", Instant.ofEpochSecond(10))) // yields "after"
+            .advanceWatermarkToInfinity(); // expected to yield "AFTER"
+
+    PCollection<String> input = p.apply(testStream);
+
+    PCollection<String> grouped =
+        input
+            .apply(Window.into(FixedWindows.of(Duration.standardSeconds(1))))
+            .apply(
+                MapElements.into(
+                        TypeDescriptors.kvs(TypeDescriptors.strings(), TypeDescriptors.strings()))
+                    .via(e -> KV.of(e, e)))
+            .apply(GroupByKey.create())
+            .apply(Keys.create())
+            .apply("Upper", MapElements.into(TypeDescriptors.strings()).via(String::toUpperCase))
+            .apply("Rewindow", Window.into(new GlobalWindows()));
+
+    PCollection<String> result =
+        PCollectionList.of(ImmutableList.of(input, grouped))
+            .apply(Flatten.pCollections())
+            .apply(
+                "Key",
+                MapElements.into(
+                        TypeDescriptors.kvs(TypeDescriptors.strings(), TypeDescriptors.strings()))
+                    .via(e -> KV.of("key", e)))
+            .apply(
+                ParDo.of(
+                    new DoFn<KV<String, String>, String>() {
+                      @StateId("seen")
+                      private final StateSpec<ValueState<String>> seenSpec =
+                          StateSpecs.value(StringUtf8Coder.of());
+
+                      @TimerId("emit")
+                      private final TimerSpec emitSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME);
+
+                      @ProcessElement
+                      public void process(
+                          ProcessContext context,
+                          @StateId("seen") ValueState<String> seenState,
+                          @TimerId("emit") Timer emitTimer) {
+                        String element = context.element().getValue();
+                        if (seenState.read() == null) {
+                          seenState.write(element);
+                        } else {
+                          seenState.write(seenState.read() + "," + element);
+                        }
+                        emitTimer.set(Instant.ofEpochSecond(100)); // re-set on every element
+                      }
+
+                      @OnTimer("emit")
+                      public void onEmit(
+                          OnTimerContext context, @StateId("seen") ValueState<String> seenState) {
+                        context.output(seenState.read()); // emit the comma-joined values
+                      }
+                    }));
+
+    PAssert.that(result).containsInAnyOrder("before,BEFORE,after,AFTER");
+
+    p.run().waitUntilFinish();
+  }
+
+  @Test
   @Category(UsesTestStreamWithProcessingTime.class)
   public void testCoder() throws Exception {
     TestStream<String> testStream =
diff --git a/sdks/python/apache_beam/metrics/cells.pxd b/sdks/python/apache_beam/metrics/cells.pxd
new file mode 100644
index 0000000..0204da8
--- /dev/null
+++ b/sdks/python/apache_beam/metrics/cells.pxd
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+cimport cython
+cimport libc.stdint
+
+
+cdef class MetricCell(object):
+  cdef object _lock
+  cpdef bint update(self, value) except -1
+
+
+cdef class CounterCell(MetricCell):
+  cdef readonly libc.stdint.int64_t value
+
+  @cython.locals(ivalue=libc.stdint.int64_t)
+  cpdef bint update(self, value) except -1
+
+
+cdef class DistributionCell(MetricCell):
+  cdef readonly DistributionData data
+
+  @cython.locals(ivalue=libc.stdint.int64_t)
+  cdef inline bint _update(self, value) except -1
+
+
+cdef class GaugeCell(MetricCell):
+  cdef readonly object data
+
+
+cdef class DistributionData(object):
+  cdef readonly libc.stdint.int64_t sum
+  cdef readonly libc.stdint.int64_t count
+  cdef readonly libc.stdint.int64_t min
+  cdef readonly libc.stdint.int64_t max
diff --git a/sdks/python/apache_beam/metrics/cells.py b/sdks/python/apache_beam/metrics/cells.py
index e7336e4..d30dd2a 100644
--- a/sdks/python/apache_beam/metrics/cells.py
+++ b/sdks/python/apache_beam/metrics/cells.py
@@ -30,12 +30,16 @@
 
 from google.protobuf import timestamp_pb2
 
-from apache_beam.metrics.metricbase import Counter
-from apache_beam.metrics.metricbase import Distribution
-from apache_beam.metrics.metricbase import Gauge
 from apache_beam.portability.api import beam_fn_api_pb2
 from apache_beam.portability.api import metrics_pb2
 
+try:
+  import cython
+except ImportError:
+  class fake_cython:
+    compiled = False
+  globals()['cython'] = fake_cython
+
 __all__ = ['DistributionResult', 'GaugeResult']
 
 
@@ -52,11 +56,17 @@
   def __init__(self):
     self._lock = threading.Lock()
 
+  def update(self, value):
+    raise NotImplementedError
+
   def get_cumulative(self):
     raise NotImplementedError
 
+  def __reduce__(self):
+    raise NotImplementedError
 
-class CounterCell(Counter, MetricCell):
+
+class CounterCell(MetricCell):
   """For internal use only; no backwards-compatibility guarantees.
 
   Tracks the current value and delta of a counter metric.
@@ -80,27 +90,41 @@
     return result
 
   def inc(self, n=1):
-    with self._lock:
-      self.value += n
+    self.update(n)
+
+  def dec(self, n=1):
+    self.update(-n)
+
+  def update(self, value):
+    if cython.compiled:
+      ivalue = value
+      # We hold the GIL, no need for another lock.
+      self.value += ivalue
+    else:
+      with self._lock:
+        self.value += value
 
   def get_cumulative(self):
     with self._lock:
       return self.value
 
-  def to_runner_api_monitoring_info(self):
-    """Returns a Metric with this counter value for use in a MonitoringInfo."""
-    # TODO(ajamato): Update this code to be consistent with Gauges
-    # and Distributions. Since there is no CounterData class this method
-    # was added to CounterCell. Consider adding a CounterData class or
-    # removing the GaugeData and DistributionData classes.
-    return metrics_pb2.Metric(
-        counter_data=metrics_pb2.CounterData(
-            int64_value=self.get_cumulative()
-        )
-    )
+  def to_runner_api_user_metric(self, metric_name):
+    return beam_fn_api_pb2.Metrics.User(
+        metric_name=metric_name.to_runner_api(),
+        counter_data=beam_fn_api_pb2.Metrics.User.CounterData(
+            value=self.value))
+
+  def to_runner_api_monitoring_info(self, name, transform_id):
+    from apache_beam.metrics import monitoring_infos
+    return monitoring_infos.int64_user_counter(
+        name.namespace, name.name,
+        metrics_pb2.Metric(
+            counter_data=metrics_pb2.CounterData(
+                int64_value=self.get_cumulative())),
+        ptransform=transform_id)
 
 
-class DistributionCell(Distribution, MetricCell):
+class DistributionCell(MetricCell):
   """For internal use only; no backwards-compatibility guarantees.
 
   Tracks the current value and delta for a distribution metric.
@@ -124,26 +148,43 @@
     return result
 
   def update(self, value):
-    with self._lock:
+    if cython.compiled:
+      # We will hold the GIL throughout the entire _update.
       self._update(value)
+    else:
+      with self._lock:
+        self._update(value)
 
   def _update(self, value):
-    value = int(value)
-    self.data.count += 1
-    self.data.sum += value
-    self.data.min = (value
-                     if self.data.min is None or self.data.min > value
-                     else self.data.min)
-    self.data.max = (value
-                     if self.data.max is None or self.data.max < value
-                     else self.data.max)
+    if cython.compiled:
+      ivalue = value
+    else:
+      ivalue = int(value)
+    self.data.count = self.data.count + 1
+    self.data.sum = self.data.sum + ivalue
+    if ivalue < self.data.min:
+      self.data.min = ivalue
+    if ivalue > self.data.max:
+      self.data.max = ivalue
 
   def get_cumulative(self):
     with self._lock:
       return self.data.get_cumulative()
 
+  def to_runner_api_user_metric(self, metric_name):
+    return beam_fn_api_pb2.Metrics.User(
+        metric_name=metric_name.to_runner_api(),
+        distribution_data=self.get_cumulative().to_runner_api())
 
-class GaugeCell(Gauge, MetricCell):
+  def to_runner_api_monitoring_info(self, name, transform_id):
+    from apache_beam.metrics import monitoring_infos
+    return monitoring_infos.int64_user_distribution(
+        name.namespace, name.name,
+        self.get_cumulative().to_runner_api_monitoring_info(),
+        ptransform=transform_id)
+
+
+class GaugeCell(MetricCell):
   """For internal use only; no backwards-compatibility guarantees.
 
   Tracks the current value and delta for a gauge metric.
@@ -167,6 +208,9 @@
     return result
 
   def set(self, value):
+    self.update(value)
+
+  def update(self, value):
     value = int(value)
     with self._lock:
       # Set the value directly without checking timestamp, because
@@ -178,6 +222,18 @@
     with self._lock:
       return self.data.get_cumulative()
 
+  def to_runner_api_user_metric(self, metric_name):
+    return beam_fn_api_pb2.Metrics.User(
+        metric_name=metric_name.to_runner_api(),
+        gauge_data=self.get_cumulative().to_runner_api())
+
+  def to_runner_api_monitoring_info(self, name, transform_id):
+    from apache_beam.metrics import monitoring_infos
+    return monitoring_infos.int64_user_gauge(
+        name.namespace, name.name,
+        self.get_cumulative().to_runner_api_monitoring_info(),
+        ptransform=transform_id)
+
 
 class DistributionResult(object):
   """The result of a Distribution metric."""
@@ -198,7 +254,7 @@
     return not self == other
 
   def __repr__(self):
-    return '<DistributionResult(sum={}, count={}, min={}, max={})>'.format(
+    return 'DistributionResult(sum={}, count={}, min={}, max={})'.format(
         self.sum,
         self.count,
         self.min,
@@ -206,11 +262,11 @@
 
   @property
   def max(self):
-    return self.data.max
+    return self.data.max if self.data.count else None
 
   @property
   def min(self):
-    return self.data.min
+    return self.data.min if self.data.count else None
 
   @property
   def count(self):
@@ -340,10 +396,15 @@
   by other than the DistributionCell that contains it.
   """
   def __init__(self, sum, count, min, max):
-    self.sum = sum
-    self.count = count
-    self.min = min
-    self.max = max
+    if count:
+      self.sum = sum
+      self.count = count
+      self.min = min
+      self.max = max
+    else:
+      self.sum = self.count = 0
+      self.min = 2**63 - 1
+      self.max = -2**63
 
   def __eq__(self, other):
     return (self.sum == other.sum and
@@ -359,7 +420,7 @@
     return not self == other
 
   def __repr__(self):
-    return '<DistributionData(sum={}, count={}, min={}, max={})>'.format(
+    return 'DistributionData(sum={}, count={}, min={}, max={})'.format(
         self.sum,
         self.count,
         self.min,
@@ -372,15 +433,11 @@
     if other is None:
       return self
 
-    new_min = (None if self.min is None and other.min is None else
-               min(x for x in (self.min, other.min) if x is not None))
-    new_max = (None if self.max is None and other.max is None else
-               max(x for x in (self.max, other.max) if x is not None))
     return DistributionData(
         self.sum + other.sum,
         self.count + other.count,
-        new_min,
-        new_max)
+        self.min if self.min < other.min else other.min,
+        self.max if self.max > other.max else other.max)
 
   @staticmethod
   def singleton(value):
@@ -449,7 +506,7 @@
   """
   @staticmethod
   def identity_element():
-    return DistributionData(0, 0, None, None)
+    return DistributionData(0, 0, 2**63 - 1, -2**63)
 
   def combine(self, x, y):
     return x.combine(y)
diff --git a/sdks/python/apache_beam/metrics/execution.pxd b/sdks/python/apache_beam/metrics/execution.pxd
index 74b34fb..6e1cbb0 100644
--- a/sdks/python/apache_beam/metrics/execution.pxd
+++ b/sdks/python/apache_beam/metrics/execution.pxd
@@ -16,10 +16,30 @@
 #
 
 cimport cython
+cimport libc.stdint
+
+from apache_beam.metrics.cells cimport MetricCell
+
+
+cdef object get_current_tracker
+
+
+cdef class _TypedMetricName(object):
+  cdef readonly object cell_type
+  cdef readonly object metric_name
+  cdef readonly object fast_name
+  cdef libc.stdint.int64_t _hash
+
+
+cdef object _DEFAULT
+
+
+cdef class MetricUpdater(object):
+  cdef _TypedMetricName typed_metric_name
+  cdef object default
 
 
 cdef class MetricsContainer(object):
   cdef object step_name
-  cdef public object counters
-  cdef public object distributions
-  cdef public object gauges
+  cdef public dict metrics
+  cpdef MetricCell get_metric_cell(self, metric_key)
diff --git a/sdks/python/apache_beam/metrics/execution.py b/sdks/python/apache_beam/metrics/execution.py
index 91fe2f8..6918914 100644
--- a/sdks/python/apache_beam/metrics/execution.py
+++ b/sdks/python/apache_beam/metrics/execution.py
@@ -33,14 +33,13 @@
 from __future__ import absolute_import
 
 from builtins import object
-from collections import defaultdict
 
 from apache_beam.metrics import monitoring_infos
 from apache_beam.metrics.cells import CounterCell
 from apache_beam.metrics.cells import DistributionCell
 from apache_beam.metrics.cells import GaugeCell
-from apache_beam.portability.api import beam_fn_api_pb2
 from apache_beam.runners.worker import statesampler
+from apache_beam.runners.worker.statesampler import get_current_tracker
 
 
 class MetricKey(object):
@@ -150,88 +149,117 @@
 MetricsEnvironment = _MetricsEnvironment()
 
 
+class _TypedMetricName(object):
+  """Like MetricName, but also stores the cell type of the metric."""
+  def __init__(self, cell_type, metric_name):
+    self.cell_type = cell_type
+    self.metric_name = metric_name
+    if isinstance(metric_name, str):
+      self.fast_name = metric_name
+    else:
+      self.fast_name = '%d_%s%s' % (
+          len(metric_name.name), metric_name.name, metric_name.namespace)
+    # Cached for speed, as this is used as a key for every counter update.
+    self._hash = hash((cell_type, self.fast_name))
+
+  def __eq__(self, other):
+    return self is other or (
+        self.cell_type == other.cell_type and self.fast_name == other.fast_name)
+
+  def __ne__(self, other):
+    return not self == other
+
+  def __hash__(self):
+    return self._hash
+
+  def __reduce__(self):
+    return _TypedMetricName, (self.cell_type, self.metric_name)
+
+
+_DEFAULT = None
+
+
+class MetricUpdater(object):
+  """A callable that updates one metric cell via the current tracker."""
+  def __init__(self, cell_type, metric_name, default=None):
+    self.typed_metric_name = _TypedMetricName(cell_type, metric_name)
+    self.default = default
+
+  def __call__(self, value=_DEFAULT):
+    if value is _DEFAULT:
+      if self.default is _DEFAULT:
+        # The name lives on typed_metric_name; there is no self.metric_name.
+        raise ValueError('Missing value for update of %s' % (
+            self.typed_metric_name.metric_name,))
+      value = self.default
+    tracker = get_current_tracker()
+    if tracker is not None:
+      tracker.update_metric(self.typed_metric_name, value)
+
+  def __reduce__(self):
+    # Pickle as the constructor arguments; the typed name is rebuilt on load.
+    return MetricUpdater, (self.typed_metric_name.cell_type,
+                           self.typed_metric_name.metric_name, self.default)
+
+
 class MetricsContainer(object):
   """Holds the metrics of a single step and a single bundle."""
   def __init__(self, step_name):
     self.step_name = step_name
-    self.counters = defaultdict(lambda: CounterCell())
-    self.distributions = defaultdict(lambda: DistributionCell())
-    self.gauges = defaultdict(lambda: GaugeCell())
+    self.metrics = dict()
 
   def get_counter(self, metric_name):
-    return self.counters[metric_name]
+    return self.get_metric_cell(_TypedMetricName(CounterCell, metric_name))
 
   def get_distribution(self, metric_name):
-    return self.distributions[metric_name]
+    return self.get_metric_cell(_TypedMetricName(DistributionCell, metric_name))
 
   def get_gauge(self, metric_name):
-    return self.gauges[metric_name]
+    return self.get_metric_cell(_TypedMetricName(GaugeCell, metric_name))
+
+  def get_metric_cell(self, typed_metric_name):
+    cell = self.metrics.get(typed_metric_name, None)
+    if cell is None:
+      cell = self.metrics[typed_metric_name] = typed_metric_name.cell_type()
+    return cell
 
   def get_cumulative(self):
     """Return MetricUpdates with cumulative values of all metrics in container.
 
     This returns all the cumulative values for all metrics.
     """
-    counters = {MetricKey(self.step_name, k): v.get_cumulative()
-                for k, v in self.counters.items()}
+    counters = {MetricKey(self.step_name, k.metric_name): v.get_cumulative()
+                for k, v in self.metrics.items()
+                if k.cell_type == CounterCell}
 
-    distributions = {MetricKey(self.step_name, k): v.get_cumulative()
-                     for k, v in self.distributions.items()}
+    distributions = {
+        MetricKey(self.step_name, k.metric_name): v.get_cumulative()
+        for k, v in self.metrics.items()
+        if k.cell_type == DistributionCell}
 
-    gauges = {MetricKey(self.step_name, k): v.get_cumulative()
-              for k, v in self.gauges.items()}
+    gauges = {MetricKey(self.step_name, k.metric_name): v.get_cumulative()
+              for k, v in self.metrics.items()
+              if k.cell_type == GaugeCell}
 
     return MetricUpdates(counters, distributions, gauges)
 
   def to_runner_api(self):
-    return (
-        [beam_fn_api_pb2.Metrics.User(
-            metric_name=k.to_runner_api(),
-            counter_data=beam_fn_api_pb2.Metrics.User.CounterData(
-                value=v.get_cumulative()))
-         for k, v in self.counters.items()] +
-        [beam_fn_api_pb2.Metrics.User(
-            metric_name=k.to_runner_api(),
-            distribution_data=v.get_cumulative().to_runner_api())
-         for k, v in self.distributions.items()] +
-        [beam_fn_api_pb2.Metrics.User(
-            metric_name=k.to_runner_api(),
-            gauge_data=v.get_cumulative().to_runner_api())
-         for k, v in self.gauges.items()]
-    )
+    return [cell.to_runner_api_user_metric(key.metric_name)
+            for key, cell in self.metrics.items()]
 
   def to_runner_api_monitoring_infos(self, transform_id):
     """Returns a list of MonitoringInfos for the metrics in this container."""
-    all_user_metrics = []
-    for k, v in self.counters.items():
-      all_user_metrics.append(monitoring_infos.int64_user_counter(
-          k.namespace, k.name,
-          v.to_runner_api_monitoring_info(),
-          ptransform=transform_id
-      ))
-
-    for k, v in self.distributions.items():
-      all_user_metrics.append(monitoring_infos.int64_user_distribution(
-          k.namespace, k.name,
-          v.get_cumulative().to_runner_api_monitoring_info(),
-          ptransform=transform_id
-      ))
-
-    for k, v in self.gauges.items():
-      all_user_metrics.append(monitoring_infos.int64_user_gauge(
-          k.namespace, k.name,
-          v.get_cumulative().to_runner_api_monitoring_info(),
-          ptransform=transform_id
-      ))
+    all_user_metrics = [
+        cell.to_runner_api_monitoring_info(key.metric_name, transform_id)
+        for key, cell in self.metrics.items()]
     return {monitoring_infos.to_key(mi) : mi for mi in all_user_metrics}
 
   def reset(self):
-    for counter in self.counters.values():
-      counter.reset()
-    for distribution in self.distributions.values():
-      distribution.reset()
-    for gauge in self.gauges.values():
-      gauge.reset()
+    for metric in self.metrics.values():
+      metric.reset()
+
+  def __reduce__(self):
+    raise NotImplementedError
 
 
 class MetricUpdates(object):
diff --git a/sdks/python/apache_beam/metrics/execution_test.py b/sdks/python/apache_beam/metrics/execution_test.py
index 9af1696..fc363a4 100644
--- a/sdks/python/apache_beam/metrics/execution_test.py
+++ b/sdks/python/apache_beam/metrics/execution_test.py
@@ -73,12 +73,6 @@
 
 
 class TestMetricsContainer(unittest.TestCase):
-  def test_create_new_counter(self):
-    mc = MetricsContainer('astep')
-    self.assertFalse(MetricName('namespace', 'name') in mc.counters)
-    mc.get_counter(MetricName('namespace', 'name'))
-    self.assertTrue(MetricName('namespace', 'name') in mc.counters)
-
   def test_add_to_counter(self):
     mc = MetricsContainer('astep')
     counter = mc.get_counter(MetricName('namespace', 'name'))
diff --git a/sdks/python/apache_beam/metrics/metric.py b/sdks/python/apache_beam/metrics/metric.py
index acd4771..8bbe191 100644
--- a/sdks/python/apache_beam/metrics/metric.py
+++ b/sdks/python/apache_beam/metrics/metric.py
@@ -29,7 +29,8 @@
 import inspect
 from builtins import object
 
-from apache_beam.metrics.execution import MetricsEnvironment
+from apache_beam.metrics import cells
+from apache_beam.metrics.execution import MetricUpdater
 from apache_beam.metrics.metricbase import Counter
 from apache_beam.metrics.metricbase import Distribution
 from apache_beam.metrics.metricbase import Gauge
@@ -101,11 +102,7 @@
     def __init__(self, metric_name):
       super(Metrics.DelegatingCounter, self).__init__()
       self.metric_name = metric_name
-
-    def inc(self, n=1):
-      container = MetricsEnvironment.current_container()
-      if container is not None:
-        container.get_counter(self.metric_name).inc(n)
+      self.inc = MetricUpdater(cells.CounterCell, metric_name, default=1)
 
   class DelegatingDistribution(Distribution):
     """Metrics Distribution Delegates functionality to MetricsEnvironment."""
@@ -113,11 +110,7 @@
     def __init__(self, metric_name):
       super(Metrics.DelegatingDistribution, self).__init__()
       self.metric_name = metric_name
-
-    def update(self, value):
-      container = MetricsEnvironment.current_container()
-      if container is not None:
-        container.get_distribution(self.metric_name).update(value)
+      self.update = MetricUpdater(cells.DistributionCell, metric_name)
 
   class DelegatingGauge(Gauge):
     """Metrics Gauge that Delegates functionality to MetricsEnvironment."""
@@ -125,11 +118,7 @@
     def __init__(self, metric_name):
       super(Metrics.DelegatingGauge, self).__init__()
       self.metric_name = metric_name
-
-    def set(self, value):
-      container = MetricsEnvironment.current_container()
-      if container is not None:
-        container.get_gauge(self.metric_name).set(value)
+      self.set = MetricUpdater(cells.GaugeCell, metric_name)
 
 
 class MetricResults(object):
diff --git a/sdks/python/apache_beam/metrics/metric_test.py b/sdks/python/apache_beam/metrics/metric_test.py
index 6e8ee08..cb18dc7 100644
--- a/sdks/python/apache_beam/metrics/metric_test.py
+++ b/sdks/python/apache_beam/metrics/metric_test.py
@@ -130,31 +130,36 @@
     statesampler.set_current_tracker(sampler)
     state1 = sampler.scoped_state('mystep', 'myState',
                                   metrics_container=MetricsContainer('mystep'))
-    sampler.start()
-    with state1:
-      counter_ns = 'aCounterNamespace'
-      distro_ns = 'aDistributionNamespace'
-      name = 'a_name'
-      counter = Metrics.counter(counter_ns, name)
-      distro = Metrics.distribution(distro_ns, name)
-      counter.inc(10)
-      counter.dec(3)
-      distro.update(10)
-      distro.update(2)
-      self.assertTrue(isinstance(counter, Metrics.DelegatingCounter))
-      self.assertTrue(isinstance(distro, Metrics.DelegatingDistribution))
 
-      del distro
-      del counter
+    try:
+      sampler.start()
+      with state1:
+        counter_ns = 'aCounterNamespace'
+        distro_ns = 'aDistributionNamespace'
+        name = 'a_name'
+        counter = Metrics.counter(counter_ns, name)
+        distro = Metrics.distribution(distro_ns, name)
+        counter.inc(10)
+        counter.dec(3)
+        distro.update(10)
+        distro.update(2)
+        self.assertTrue(isinstance(counter, Metrics.DelegatingCounter))
+        self.assertTrue(isinstance(distro, Metrics.DelegatingDistribution))
 
-      container = MetricsEnvironment.current_container()
-      self.assertEqual(
-          container.counters[MetricName(counter_ns, name)].get_cumulative(),
-          7)
-      self.assertEqual(
-          container.distributions[MetricName(distro_ns, name)].get_cumulative(),
-          DistributionData(12, 2, 2, 10))
-    sampler.stop()
+        del distro
+        del counter
+
+        container = MetricsEnvironment.current_container()
+        self.assertEqual(
+            container.get_counter(
+                MetricName(counter_ns, name)).get_cumulative(),
+            7)
+        self.assertEqual(
+            container.get_distribution(
+                MetricName(distro_ns, name)).get_cumulative(),
+            DistributionData(12, 2, 2, 10))
+    finally:
+      sampler.stop()
 
 
 if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/runners/portability/spark_runner.py b/sdks/python/apache_beam/runners/portability/spark_runner.py
new file mode 100644
index 0000000..ca03310
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/spark_runner.py
@@ -0,0 +1,84 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A runner for executing portable pipelines on Spark."""
+
+from __future__ import absolute_import
+from __future__ import print_function
+
+import re
+
+from apache_beam.options import pipeline_options
+from apache_beam.runners.portability import job_server
+from apache_beam.runners.portability import portable_runner
+
+# https://spark.apache.org/docs/latest/submitting-applications.html#master-urls
+LOCAL_MASTER_PATTERN = r'^local(\[.+\])?$'
+
+
+class SparkRunner(portable_runner.PortableRunner):
+  def run_pipeline(self, pipeline, options):
+    spark_options = options.view_as(SparkRunnerOptions)
+    portable_options = options.view_as(pipeline_options.PortableOptions)
+    if (re.match(LOCAL_MASTER_PATTERN, spark_options.spark_master_url)
+        and not portable_options.environment_type
+        and not portable_options.output_executable_path):
+      portable_options.environment_type = 'LOOPBACK'
+    return super(SparkRunner, self).run_pipeline(pipeline, options)
+
+  def default_job_server(self, options):
+    # TODO(BEAM-8139) submit a Spark jar to a cluster
+    return job_server.StopOnExitJobServer(SparkJarJobServer(options))
+
+
+class SparkRunnerOptions(pipeline_options.PipelineOptions):
+  @classmethod
+  def _add_argparse_args(cls, parser):
+    parser.add_argument('--spark_master_url',
+                        default='local[4]',
+                        help='Spark master URL (spark://HOST:PORT). '
+                             'Use "local" (single-threaded) or "local[*]" '
+                             '(multi-threaded) to start a local cluster for '
+                             'the execution.')
+    parser.add_argument('--spark_job_server_jar',
+                        help='Path or URL to a Beam Spark jobserver jar.')
+    parser.add_argument('--artifacts_dir', default=None)
+
+
+class SparkJarJobServer(job_server.JavaJarJobServer):
+  def __init__(self, options):
+    super(SparkJarJobServer, self).__init__()
+    options = options.view_as(SparkRunnerOptions)
+    self._jar = options.spark_job_server_jar
+    self._master_url = options.spark_master_url
+    self._artifacts_dir = options.artifacts_dir
+
+  def path_to_jar(self):
+    if self._jar:
+      return self._jar
+    else:
+      return self.path_to_beam_jar('runners:spark:job-server:shadowJar')
+
+  def java_arguments(self, job_port, artifacts_dir):
+    return [
+        '--spark-master-url', self._master_url,
+        '--artifacts-dir', (self._artifacts_dir
+                            if self._artifacts_dir else artifacts_dir),
+        '--job-port', job_port,
+        '--artifact-port', 0,
+        '--expansion-port', 0
+    ]
diff --git a/sdks/python/apache_beam/runners/runner.py b/sdks/python/apache_beam/runners/runner.py
index c17ab08..fe9c492 100644
--- a/sdks/python/apache_beam/runners/runner.py
+++ b/sdks/python/apache_beam/runners/runner.py
@@ -38,6 +38,7 @@
     'apache_beam.runners.interactive.interactive_runner.InteractiveRunner',
     'apache_beam.runners.portability.flink_runner.FlinkRunner',
     'apache_beam.runners.portability.portable_runner.PortableRunner',
+    'apache_beam.runners.portability.spark_runner.SparkRunner',
     'apache_beam.runners.test.TestDirectRunner',
     'apache_beam.runners.test.TestDataflowRunner',
 )
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_fast.pxd b/sdks/python/apache_beam/runners/worker/statesampler_fast.pxd
index 799bd0d..aebf9f6 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_fast.pxd
+++ b/sdks/python/apache_beam/runners/worker/statesampler_fast.pxd
@@ -43,6 +43,9 @@
 
   cdef int32_t current_state_index
 
+  cpdef ScopedState current_state(self)
+  cdef inline ScopedState current_state_c(self)
+
   cpdef _scoped_state(
       self, counter_name, name_context, output_counter, metrics_container)
 
@@ -56,7 +59,7 @@
   cdef readonly object name_context
   cdef readonly int64_t _nsecs
   cdef int32_t old_state_index
-  cdef readonly MetricsContainer _metrics_container
+  cdef readonly MetricsContainer metrics_container
 
   cpdef __enter__(self)
 
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx b/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
index 325ec99..8d2346a 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
+++ b/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
@@ -159,8 +159,12 @@
       (<ScopedState>state)._nsecs = 0
     self.started = self.finished = False
 
-  def current_state(self):
-    return self.scoped_states_by_index[self.current_state_index]
+  cpdef ScopedState current_state(self):
+    return self.current_state_c()
+
+  cdef inline ScopedState current_state_c(self):
+    # Faster than cpdef due to self always being a Python subclass.
+    return <ScopedState>self.scoped_states_by_index[self.current_state_index]
 
   cpdef _scoped_state(self, counter_name, name_context, output_counter,
                       metrics_container):
@@ -189,6 +193,11 @@
     pythread.PyThread_release_lock(self.lock)
     return scoped_state
 
+  def update_metric(self, typed_metric_name, value):
+    # Each step in the call chain below is a cdef lookup.
+    self.current_state_c().metrics_container.get_metric_cell(
+        typed_metric_name).update(value)
+
 
 cdef class ScopedState(object):
   """Context manager class managing transitions for a given sampler state."""
@@ -205,7 +214,7 @@
     self.name_context = step_name_context
     self.state_index = state_index
     self.counter = counter
-    self._metrics_container = metrics_container
+    self.metrics_container = metrics_container
 
   @property
   def nsecs(self):
@@ -232,7 +241,3 @@
     self.sampler.current_state_index = self.old_state_index
     self.sampler.state_transition_count += 1
     pythread.PyThread_release_lock(self.sampler.lock)
-
-  @property
-  def metrics_container(self):
-    return self._metrics_container
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_slow.py b/sdks/python/apache_beam/runners/worker/statesampler_slow.py
index 0091828..fb2592c 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_slow.py
+++ b/sdks/python/apache_beam/runners/worker/statesampler_slow.py
@@ -50,6 +50,10 @@
     return ScopedState(
         self, counter_name, name_context, output_counter, metrics_container)
 
+  def update_metric(self, typed_metric_name, value):
+    self.current_state().metrics_container.get_metric_cell(
+        typed_metric_name).update(value)
+
   def _enter_state(self, state):
     self.state_transition_count += 1
     self._state_stack.append(state)
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index ccf90f6..7eea64c 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -214,6 +214,7 @@
     ext_modules=cythonize([
         'apache_beam/**/*.pyx',
         'apache_beam/coders/coder_impl.py',
+        'apache_beam/metrics/cells.py',
         'apache_beam/metrics/execution.py',
         'apache_beam/runners/common.py',
         'apache_beam/runners/worker/logger.py',
diff --git a/sdks/python/test-suites/portable/py2/build.gradle b/sdks/python/test-suites/portable/py2/build.gradle
index 3c1548d..5d967e4 100644
--- a/sdks/python/test-suites/portable/py2/build.gradle
+++ b/sdks/python/test-suites/portable/py2/build.gradle
@@ -39,6 +39,8 @@
   dependsOn ':runners:flink:1.9:job-server:shadowJar'
   dependsOn portableWordCountFlinkRunnerBatch
   dependsOn portableWordCountFlinkRunnerStreaming
+  dependsOn ':runners:spark:job-server:shadowJar'
+  dependsOn portableWordCountSparkRunnerBatch
 }
 
 // TODO: Move the rest of this file into ../common.gradle.
diff --git a/sdks/python/test-suites/portable/py35/build.gradle b/sdks/python/test-suites/portable/py35/build.gradle
index 1b2cb4f..88b4e2f 100644
--- a/sdks/python/test-suites/portable/py35/build.gradle
+++ b/sdks/python/test-suites/portable/py35/build.gradle
@@ -36,4 +36,6 @@
     dependsOn ':runners:flink:1.9:job-server:shadowJar'
     dependsOn portableWordCountFlinkRunnerBatch
     dependsOn portableWordCountFlinkRunnerStreaming
+    dependsOn ':runners:spark:job-server:shadowJar'
+    dependsOn portableWordCountSparkRunnerBatch
 }
diff --git a/sdks/python/test-suites/portable/py36/build.gradle b/sdks/python/test-suites/portable/py36/build.gradle
index 475e110..496777d 100644
--- a/sdks/python/test-suites/portable/py36/build.gradle
+++ b/sdks/python/test-suites/portable/py36/build.gradle
@@ -36,4 +36,6 @@
     dependsOn ':runners:flink:1.9:job-server:shadowJar'
     dependsOn portableWordCountFlinkRunnerBatch
     dependsOn portableWordCountFlinkRunnerStreaming
+    dependsOn ':runners:spark:job-server:shadowJar'
+    dependsOn portableWordCountSparkRunnerBatch
 }
diff --git a/sdks/python/test-suites/portable/py37/build.gradle b/sdks/python/test-suites/portable/py37/build.gradle
index 912b316..924de81 100644
--- a/sdks/python/test-suites/portable/py37/build.gradle
+++ b/sdks/python/test-suites/portable/py37/build.gradle
@@ -36,4 +36,6 @@
     dependsOn ':runners:flink:1.9:job-server:shadowJar'
     dependsOn portableWordCountFlinkRunnerBatch
     dependsOn portableWordCountFlinkRunnerStreaming
+    dependsOn ':runners:spark:job-server:shadowJar'
+    dependsOn portableWordCountSparkRunnerBatch
 }
diff --git a/website/Gemfile b/website/Gemfile
index 4a08725..1050303 100644
--- a/website/Gemfile
+++ b/website/Gemfile
@@ -20,7 +20,7 @@
 
 source 'https://rubygems.org'
 
-gem 'jekyll', '3.2'
+gem 'jekyll', '3.6.3'
 
 # Jekyll plugins
 group :jekyll_plugins do
diff --git a/website/Gemfile.lock b/website/Gemfile.lock
index e94f132..9db2ebe 100644
--- a/website/Gemfile.lock
+++ b/website/Gemfile.lock
@@ -13,7 +13,7 @@
     concurrent-ruby (1.1.4)
     ethon (0.11.0)
       ffi (>= 1.3.0)
-    ffi (1.9.25)
+    ffi (1.11.1)
     forwardable-extended (2.6.0)
     html-proofer (3.9.3)
       activesupport (>= 4.2, < 6.0)
@@ -26,15 +26,16 @@
       yell (~> 2.0)
     i18n (0.9.5)
       concurrent-ruby (~> 1.0)
-    jekyll (3.2.0)
+    jekyll (3.6.3)
+      addressable (~> 2.4)
       colorator (~> 1.0)
       jekyll-sass-converter (~> 1.0)
       jekyll-watch (~> 1.1)
-      kramdown (~> 1.3)
-      liquid (~> 3.0)
+      kramdown (~> 1.14)
+      liquid (~> 4.0)
       mercenary (~> 0.3.3)
       pathutil (~> 0.9)
-      rouge (~> 1.7)
+      rouge (>= 1.7, < 3)
       safe_yaml (~> 1.0)
     jekyll-redirect-from (0.11.0)
       jekyll (>= 2.0)
@@ -45,29 +46,27 @@
     jekyll_github_sample (0.3.1)
       activesupport (~> 4.0)
       jekyll (~> 3.0)
-    kramdown (1.16.2)
-    liquid (3.0.6)
-    listen (3.1.5)
-      rb-fsevent (~> 0.9, >= 0.9.4)
-      rb-inotify (~> 0.9, >= 0.9.7)
-      ruby_dep (~> 1.2)
+    kramdown (1.17.0)
+    liquid (4.0.3)
+    listen (3.2.0)
+      rb-fsevent (~> 0.10, >= 0.10.3)
+      rb-inotify (~> 0.9, >= 0.9.10)
     mercenary (0.3.6)
     mini_portile2 (2.3.0)
     minitest (5.11.3)
     nokogiri (1.8.5)
       mini_portile2 (~> 2.3.0)
     parallel (1.12.1)
-    pathutil (0.16.1)
+    pathutil (0.16.2)
       forwardable-extended (~> 2.6)
     public_suffix (3.0.3)
     rake (12.3.0)
-    rb-fsevent (0.10.2)
-    rb-inotify (0.9.10)
-      ffi (>= 0.5.0, < 2)
-    rouge (1.11.1)
-    ruby_dep (1.5.0)
-    safe_yaml (1.0.4)
-    sass (3.5.5)
+    rb-fsevent (0.10.3)
+    rb-inotify (0.10.0)
+      ffi (~> 1.0)
+    rouge (2.2.1)
+    safe_yaml (1.0.5)
+    sass (3.7.4)
       sass-listen (~> 4.0.0)
     sass-listen (4.0.0)
       rb-fsevent (~> 0.9, >= 0.9.4)
@@ -85,7 +84,7 @@
 DEPENDENCIES
   activesupport (< 5.0.0.0)
   html-proofer
-  jekyll (= 3.2)
+  jekyll (= 3.6.3)
   jekyll-redirect-from
   jekyll-sass-converter
   jekyll_github_sample