Merge pull request #10076 from ibzib/java-default-region

[BEAM-8628] use mock GcsUtil in testDefaultGcpTempLocationDoesNotExist
diff --git a/.test-infra/jenkins/job_Dependency_Check.groovy b/.test-infra/jenkins/job_Dependency_Check.groovy
index ac66881..dddd2f7 100644
--- a/.test-infra/jenkins/job_Dependency_Check.groovy
+++ b/.test-infra/jenkins/job_Dependency_Check.groovy
@@ -38,7 +38,7 @@
   steps {
     gradle {
       rootBuildScriptDir(commonJobProperties.checkoutDir)
-      tasks(':runBeamDependencyCheck')
+      tasks('runBeamDependencyCheck')
       commonJobProperties.setGradleSwitches(delegate)
       switches('-Drevision=release')
     }
diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
index 3ea3643..557f45f 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
@@ -1905,6 +1905,7 @@
           mustRunAfter = [
             ':runners:flink:1.9:job-server-container:docker',
             ':runners:flink:1.9:job-server:shadowJar',
+            ':runners:spark:job-server:shadowJar',
             ':sdks:python:container:py2:docker',
             ':sdks:python:container:py35:docker',
             ':sdks:python:container:py36:docker',
@@ -1958,6 +1959,7 @@
         addPortableWordCountTask(true, "PortableRunner")
         addPortableWordCountTask(false, "FlinkRunner")
         addPortableWordCountTask(true, "FlinkRunner")
+        addPortableWordCountTask(false, "SparkRunner")
       }
     }
   }
diff --git a/model/fn-execution/src/main/proto/beam_fn_api.proto b/model/fn-execution/src/main/proto/beam_fn_api.proto
index ed2f013..0ddc48e 100644
--- a/model/fn-execution/src/main/proto/beam_fn_api.proto
+++ b/model/fn-execution/src/main/proto/beam_fn_api.proto
@@ -42,6 +42,7 @@
 import "endpoints.proto";
 import "google/protobuf/descriptor.proto";
 import "google/protobuf/timestamp.proto";
+import "google/protobuf/duration.proto";
 import "google/protobuf/wrappers.proto";
 import "metrics.proto";
 
@@ -203,13 +204,21 @@
 }
 
 // An Application should be scheduled for execution after a delay.
+// Either an absolute timestamp or a relative timestamp can represent a
+// scheduled execution time.
 message DelayedBundleApplication {
   // Recommended time at which the application should be scheduled to execute
   // by the runner. Times in the past may be scheduled to execute immediately.
+  // TODO(BEAM-8536): Migrate usage of absolute time to requested_time_delay.
   google.protobuf.Timestamp requested_execution_time = 1;
 
   // (Required) The application that should be scheduled.
   BundleApplication application = 2;
+
+  // Recommended time delay at which the application should be scheduled to
+  // execute by the runner. A time delay of 0 may be scheduled to execute
+  // immediately. The time delay should be expressed in microseconds.
+  google.protobuf.Duration requested_time_delay = 3;
 }
 
 // A request to process a given bundle.
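
A minimal sketch of how a runner might populate the new field during the BEAM-8536 migration, setting both the absolute and the relative form; it assumes the generated Java bindings in org.apache.beam.model.fnexecution.v1 and the standard protobuf Timestamps/Durations utilities, and the helper name is illustrative only:

    import com.google.protobuf.util.Durations;
    import com.google.protobuf.util.Timestamps;
    import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleApplication;
    import org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication;

    class DelayedApplications {
      // Asks the runner to resume the application in 30 seconds, filling both
      // requested_execution_time (absolute) and requested_time_delay (relative).
      static DelayedBundleApplication resumeIn30Seconds(BundleApplication application) {
        return DelayedBundleApplication.newBuilder()
            .setApplication(application)
            .setRequestedExecutionTime(
                Timestamps.fromMillis(System.currentTimeMillis() + 30_000L))
            .setRequestedTimeDelay(Durations.fromSeconds(30))
            .build();
      }
    }
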
diff --git a/model/job-management/src/main/proto/beam_job_api.proto b/model/job-management/src/main/proto/beam_job_api.proto
index e9f0eb9..d297d3b 100644
--- a/model/job-management/src/main/proto/beam_job_api.proto
+++ b/model/job-management/src/main/proto/beam_job_api.proto
@@ -213,17 +213,40 @@
 // without needing to pass through STARTING.
 message JobState {
   enum Enum {
+    // The job state reported by a runner cannot be interpreted by the SDK.
     UNSPECIFIED = 0;
+
+    // The job has not yet started.
     STOPPED = 1;
+
+    // The job is currently running.
     RUNNING = 2;
+
+    // The job has successfully completed. (terminal)
     DONE = 3;
+
+    // The job has failed. (terminal)
     FAILED = 4;
+
+    // The job has been explicitly cancelled. (terminal)
     CANCELLED = 5;
+
+    // The job has been updated. (terminal)
     UPDATED = 6;
+
+    // The job is draining its data. (optional)
     DRAINING = 7;
+
+    // The job has completed draining its data. (terminal)
     DRAINED = 8;
+
+    // The job is starting up.
     STARTING = 9;
+
+    // The job is cancelling. (optional)
     CANCELLING = 10;
+
+    // The job is in the process of being updated. (optional)
     UPDATING = 11;
   }
 }
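
The enum comments above distinguish terminal states from the rest. A hedged sketch of how client code might act on that distinction when polling a job, using the generated Java bindings (the helper is illustrative, not part of the Job API):

    import org.apache.beam.model.jobmanagement.v1.JobApi.JobState;

    class JobStates {
      // True for the states documented as terminal above; once reached,
      // the job state can no longer change.
      static boolean isTerminal(JobState.Enum state) {
        switch (state) {
          case DONE:
          case FAILED:
          case CANCELLED:
          case UPDATED:
          case DRAINED:
            return true;
          default:
            return false;
        }
      }
    }
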
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPCollectionFusers.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPCollectionFusers.java
index 1a6fee4..cecbee9 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPCollectionFusers.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPCollectionFusers.java
@@ -50,11 +50,20 @@
               PTransformTranslation.SPLITTABLE_PAIR_WITH_RESTRICTION_URN,
               GreedyPCollectionFusers::canFuseParDo)
           .put(
+              PTransformTranslation.SPLITTABLE_SPLIT_RESTRICTION_URN,
+              GreedyPCollectionFusers::canFuseParDo)
+          .put(
+              PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN,
+              GreedyPCollectionFusers::cannotFuse)
+          .put(
+              PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN,
+              GreedyPCollectionFusers::cannotFuse)
+          .put(
               PTransformTranslation.SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN,
               GreedyPCollectionFusers::canFuseParDo)
           .put(
               PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN,
-              GreedyPCollectionFusers::canFuseParDo)
+              GreedyPCollectionFusers::cannotFuse)
           .put(
               PTransformTranslation.COMBINE_PER_KEY_PRECOMBINE_TRANSFORM_URN,
               GreedyPCollectionFusers::canFuseCompatibleEnvironment)
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/ProtoOverrides.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/ProtoOverrides.java
index 5ea867e..cd28cbf 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/ProtoOverrides.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/ProtoOverrides.java
@@ -21,6 +21,7 @@
 
 import java.util.List;
 import java.util.Map;
+import javax.annotation.Nullable;
 import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
 import org.apache.beam.model.pipeline.v1.RunnerApi.ComponentsOrBuilder;
 import org.apache.beam.model.pipeline.v1.RunnerApi.MessageWithComponents;
@@ -32,10 +33,9 @@
 /**
  * A way to apply a Proto-based {@link PTransformOverride}.
  *
- * <p>This should generally be used to replace runner-executed transforms with runner-executed
- * composites and simpler runner-executed primitives. It is generically less powerful than the
- * native {@link org.apache.beam.sdk.Pipeline#replaceAll(List)} and more error-prone, so should only
- * be used for relatively simple replacements.
+ * <p>This should generally be used by runners to replace transforms within graphs. SDK construction
+ * code should rely on the more powerful and native {@link
+ * org.apache.beam.sdk.Pipeline#replaceAll(List)}.
  */
 @Experimental
 public class ProtoOverrides {
@@ -51,6 +51,10 @@
       if (pt.getValue().getSpec() != null && urn.equals(pt.getValue().getSpec().getUrn())) {
         MessageWithComponents updated =
             compositeBuilder.getReplacement(pt.getKey(), originalPipeline.getComponents());
+        if (updated == null) {
+          continue;
+        }
+
         checkArgument(
             updated.getPtransform().getOutputsMap().equals(pt.getValue().getOutputsMap()),
             "A %s must produce all of the outputs of the original %s",
@@ -66,8 +70,8 @@
   }
 
   /**
-   * Remove all subtransforms of the provided transform recursively.A {@link PTransform} can be the
-   * subtransform of only one enclosing transform.
+   * Remove all sub-transforms of the provided transform recursively. A {@link PTransform} can be
+   * the sub-transform of only one enclosing transform.
    */
   private static void removeSubtransforms(PTransform pt, Components.Builder target) {
     for (String subtransformId : pt.getSubtransformsList()) {
@@ -87,14 +91,16 @@
     /**
      * Returns the updated composite structure for the provided {@link PTransform}.
      *
-     * <p>The returned {@link MessageWithComponents} must contain a single {@link PTransform}. The
-     * result {@link Components} will be merged into the existing components, and the result {@link
-     * PTransform} will be set as a replacement of the original {@link PTransform}. Notably, this
-     * does not require that the {@code existingComponents} are present in the returned {@link
+     * <p>If the return is null, then no replacement is performed, otherwise the returned {@link
+     * MessageWithComponents} must contain a single {@link PTransform}. The result {@link
+     * Components} will be merged into the existing components, and the result {@link PTransform}
+     * will be set as a replacement of the original {@link PTransform}. Notably, this does not
+     * require that the {@code existingComponents} are present in the returned {@link
      * MessageWithComponents}.
      *
      * <p>Introduced components must not collide with any components in the existing components.
      */
+    @Nullable
     MessageWithComponents getReplacement(
         String transformId, ComponentsOrBuilder existingComponents);
   }
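
Since getReplacement is now @Nullable, a TransformReplacement can decline individual transforms. A minimal sketch of that contract, with imports from RunnerApi as in the file above (the URN filter is purely illustrative):

    // Returning null leaves the original transform untouched; updateTransform
    // simply moves on to the next candidate.
    TransformReplacement skipUnknown =
        (transformId, existingComponents) -> {
          PTransform original = existingComponents.getTransformsOrThrow(transformId);
          if (!"beam:transform:pardo:v1".equals(original.getSpec().getUrn())) {
            return null; // no replacement performed
          }
          return MessageWithComponents.newBuilder()
              .setPtransform(original.toBuilder().clearSubtransforms())
              .build();
        };
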
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/QueryablePipeline.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/QueryablePipeline.java
index 382f68a..4ed19da 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/QueryablePipeline.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/QueryablePipeline.java
@@ -28,10 +28,12 @@
 import static org.apache.beam.runners.core.construction.PTransformTranslation.MAP_WINDOWS_TRANSFORM_URN;
 import static org.apache.beam.runners.core.construction.PTransformTranslation.PAR_DO_TRANSFORM_URN;
 import static org.apache.beam.runners.core.construction.PTransformTranslation.READ_TRANSFORM_URN;
+import static org.apache.beam.runners.core.construction.PTransformTranslation.SPLITTABLE_PAIR_WITH_RESTRICTION_URN;
 import static org.apache.beam.runners.core.construction.PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN;
 import static org.apache.beam.runners.core.construction.PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN;
 import static org.apache.beam.runners.core.construction.PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN;
 import static org.apache.beam.runners.core.construction.PTransformTranslation.SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN;
+import static org.apache.beam.runners.core.construction.PTransformTranslation.SPLITTABLE_SPLIT_RESTRICTION_URN;
 import static org.apache.beam.runners.core.construction.PTransformTranslation.TEST_STREAM_TRANSFORM_URN;
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
 
@@ -174,6 +176,8 @@
           COMBINE_PER_KEY_PRECOMBINE_TRANSFORM_URN,
           COMBINE_PER_KEY_MERGE_ACCUMULATORS_TRANSFORM_URN,
           COMBINE_PER_KEY_EXTRACT_OUTPUTS_TRANSFORM_URN,
+          SPLITTABLE_PAIR_WITH_RESTRICTION_URN,
+          SPLITTABLE_SPLIT_RESTRICTION_URN,
           SPLITTABLE_PROCESS_KEYED_URN,
           SPLITTABLE_PROCESS_ELEMENTS_URN,
           SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN,
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/SplittableParDoExpander.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/SplittableParDoExpander.java
new file mode 100644
index 0000000..77f0211
--- /dev/null
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/SplittableParDoExpander.java
@@ -0,0 +1,273 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.core.construction.graph;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Map;
+import java.util.function.Predicate;
+import org.apache.beam.model.pipeline.v1.RunnerApi.Coder;
+import org.apache.beam.model.pipeline.v1.RunnerApi.ComponentsOrBuilder;
+import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
+import org.apache.beam.model.pipeline.v1.RunnerApi.MessageWithComponents;
+import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
+import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
+import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
+import org.apache.beam.runners.core.construction.ModelCoders;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.ParDoTranslation;
+import org.apache.beam.runners.core.construction.graph.ProtoOverrides.TransformReplacement;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
+
+/**
+ * A set of transform replacements for expanding a splittable ParDo into various sub-components.
+ *
+ * <p>Further details about the expansion can be found in the <a
+ * href="https://github.com/apache/beam/blob/cb15994d5228f729dda922419b08520c8be8804e/model/pipeline/src/main/proto/beam_runner_api.proto#L279">
+ * Beam runner API proto</a>.
+ */
+public class SplittableParDoExpander {
+
+  /**
+   * Returns a transform replacement which expands a splittable ParDo from:
+   *
+   * <pre>{@code
+   * sideInputA ---------\
+   * sideInputB ---------V
+   * mainInput ---> SplittableParDo --> outputA
+   *                                \-> outputB
+   * }</pre>
+   *
+   * into:
+   *
+   * <pre>{@code
+   * sideInputA ---------\---------------------\--------------------------\
+   * sideInputB ---------V---------------------V--------------------------V
+   * mainInput ---> PairWithRestriction --> SplitAndSize --> ProcessSizedElementsAndRestrictions --> outputA
+   *                                                                                             \-> outputB
+   * }</pre>
+   *
+   * <p>Specifically, this transform ensures that initial splitting is performed and that the sizing
+   * information is available to the runner if it chooses to inspect it.
+   */
+  public static TransformReplacement createSizedReplacement() {
+    return SizedReplacement.INSTANCE;
+  }
+
+  /** See {@link #createSizedReplacement()} for details. */
+  private static class SizedReplacement implements TransformReplacement {
+
+    private static final SizedReplacement INSTANCE = new SizedReplacement();
+
+    @Override
+    public MessageWithComponents getReplacement(
+        String transformId, ComponentsOrBuilder existingComponents) {
+      try {
+        MessageWithComponents.Builder rval = MessageWithComponents.newBuilder();
+
+        PTransform splittableParDo = existingComponents.getTransformsOrThrow(transformId);
+        ParDoPayload payload = ParDoPayload.parseFrom(splittableParDo.getSpec().getPayload());
+        // Only perform the expansion if this is a splittable DoFn.
+        if (payload.getRestrictionCoderId() == null || payload.getRestrictionCoderId().isEmpty()) {
+          return null;
+        }
+
+        String mainInputName = ParDoTranslation.getMainInputName(splittableParDo);
+        String mainInputPCollectionId = splittableParDo.getInputsOrThrow(mainInputName);
+        PCollection mainInputPCollection =
+            existingComponents.getPcollectionsOrThrow(mainInputPCollectionId);
+        Map<String, String> sideInputs =
+            Maps.filterKeys(
+                splittableParDo.getInputsMap(), input -> payload.containsSideInputs(input));
+
+        String pairWithRestrictionOutCoderId =
+            generateUniqueId(
+                mainInputPCollection.getCoderId() + "/PairWithRestriction",
+                existingComponents::containsCoders);
+        rval.getComponentsBuilder()
+            .putCoders(
+                pairWithRestrictionOutCoderId,
+                ModelCoders.kvCoder(
+                    mainInputPCollection.getCoderId(), payload.getRestrictionCoderId()));
+
+        String pairWithRestrictionOutId =
+            generateUniqueId(
+                mainInputPCollectionId + "/PairWithRestriction",
+                existingComponents::containsPcollections);
+        rval.getComponentsBuilder()
+            .putPcollections(
+                pairWithRestrictionOutId,
+                PCollection.newBuilder()
+                    .setCoderId(pairWithRestrictionOutCoderId)
+                    .setIsBounded(mainInputPCollection.getIsBounded())
+                    .setWindowingStrategyId(mainInputPCollection.getWindowingStrategyId())
+                    .setUniqueName(
+                        generateUniquePCollectonName(
+                            mainInputPCollection.getUniqueName() + "/PairWithRestriction",
+                            existingComponents))
+                    .build());
+
+        String splitAndSizeOutCoderId =
+            generateUniqueId(
+                mainInputPCollection.getCoderId() + "/SplitAndSize",
+                existingComponents::containsCoders);
+        rval.getComponentsBuilder()
+            .putCoders(
+                splitAndSizeOutCoderId,
+                ModelCoders.kvCoder(
+                    pairWithRestrictionOutCoderId, getOrAddDoubleCoder(existingComponents, rval)));
+
+        String splitAndSizeOutId =
+            generateUniqueId(
+                mainInputPCollectionId + "/SplitAndSize", existingComponents::containsPcollections);
+        rval.getComponentsBuilder()
+            .putPcollections(
+                splitAndSizeOutId,
+                PCollection.newBuilder()
+                    .setCoderId(splitAndSizeOutCoderId)
+                    .setIsBounded(mainInputPCollection.getIsBounded())
+                    .setWindowingStrategyId(mainInputPCollection.getWindowingStrategyId())
+                    .setUniqueName(
+                        generateUniquePCollectonName(
+                            mainInputPCollection.getUniqueName() + "/SplitAndSize",
+                            existingComponents))
+                    .build());
+
+        String pairWithRestrictionId =
+            generateUniqueId(
+                transformId + "/PairWithRestriction", existingComponents::containsTransforms);
+        {
+          PTransform.Builder pairWithRestriction = PTransform.newBuilder();
+          pairWithRestriction.putAllInputs(splittableParDo.getInputsMap());
+          pairWithRestriction.putOutputs("out", pairWithRestrictionOutId);
+          pairWithRestriction.setUniqueName(
+              generateUniquePCollectonName(
+                  splittableParDo.getUniqueName() + "/PairWithRestriction", existingComponents));
+          pairWithRestriction.setSpec(
+              FunctionSpec.newBuilder()
+                  .setUrn(PTransformTranslation.SPLITTABLE_PAIR_WITH_RESTRICTION_URN)
+                  .setPayload(splittableParDo.getSpec().getPayload()));
+          rval.getComponentsBuilder()
+              .putTransforms(pairWithRestrictionId, pairWithRestriction.build());
+        }
+
+        String splitAndSizeId =
+            generateUniqueId(transformId + "/SplitAndSize", existingComponents::containsTransforms);
+        {
+          PTransform.Builder splitAndSize = PTransform.newBuilder();
+          splitAndSize.putInputs(mainInputName, pairWithRestrictionOutId);
+          splitAndSize.putAllInputs(sideInputs);
+          splitAndSize.putOutputs("out", splitAndSizeOutId);
+          splitAndSize.setUniqueName(
+              generateUniquePCollectonName(
+                  splittableParDo.getUniqueName() + "/SplitAndSize", existingComponents));
+          splitAndSize.setSpec(
+              FunctionSpec.newBuilder()
+                  .setUrn(PTransformTranslation.SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN)
+                  .setPayload(splittableParDo.getSpec().getPayload()));
+          rval.getComponentsBuilder().putTransforms(splitAndSizeId, splitAndSize.build());
+        }
+
+        String processSizedElementsAndRestrictionsId =
+            generateUniqueId(
+                transformId + "/ProcessSizedElementsAndRestrictions",
+                existingComponents::containsTransforms);
+        {
+          PTransform.Builder processSizedElementsAndRestrictions = PTransform.newBuilder();
+          processSizedElementsAndRestrictions.putInputs(mainInputName, splitAndSizeOutId);
+          processSizedElementsAndRestrictions.putAllInputs(sideInputs);
+          processSizedElementsAndRestrictions.putAllOutputs(splittableParDo.getOutputsMap());
+          processSizedElementsAndRestrictions.setUniqueName(
+              generateUniquePCollectonName(
+                  splittableParDo.getUniqueName() + "/ProcessSizedElementsAndRestrictions",
+                  existingComponents));
+          processSizedElementsAndRestrictions.setSpec(
+              FunctionSpec.newBuilder()
+                  .setUrn(
+                      PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN)
+                  .setPayload(splittableParDo.getSpec().getPayload()));
+          rval.getComponentsBuilder()
+              .putTransforms(
+                  processSizedElementsAndRestrictionsId,
+                  processSizedElementsAndRestrictions.build());
+        }
+
+        PTransform.Builder newCompositeRoot =
+            splittableParDo
+                .toBuilder()
+                // Clear the original splittable ParDo spec and add all the new transforms as
+                // children.
+                .clearSpec()
+                .addAllSubtransforms(
+                    Arrays.asList(
+                        pairWithRestrictionId,
+                        splitAndSizeId,
+                        processSizedElementsAndRestrictionsId));
+        rval.setPtransform(newCompositeRoot);
+
+        return rval.build();
+      } catch (IOException e) {
+        throw new RuntimeException("Unable to perform expansion for transform " + transformId, e);
+      }
+    }
+  }
+
+  private static String getOrAddDoubleCoder(
+      ComponentsOrBuilder existingComponents, MessageWithComponents.Builder out) {
+    for (Map.Entry<String, Coder> coder : existingComponents.getCodersMap().entrySet()) {
+      if (ModelCoders.DOUBLE_CODER_URN.equals(coder.getValue().getSpec().getUrn())) {
+        return coder.getKey();
+      }
+    }
+    String doubleCoderId = generateUniqueId("DoubleCoder", existingComponents::containsCoders);
+    out.getComponentsBuilder()
+        .putCoders(
+            doubleCoderId,
+            Coder.newBuilder()
+                .setSpec(FunctionSpec.newBuilder().setUrn(ModelCoders.DOUBLE_CODER_URN))
+                .build());
+    return doubleCoderId;
+  }
+
+  /**
+   * Returns a PCollection name, based on the supplied prefix, that does not exist in {@code
+   * existingComponents}.
+   */
+  private static String generateUniquePCollectonName(
+      String prefix, ComponentsOrBuilder existingComponents) {
+    return generateUniqueId(
+        prefix,
+        input -> {
+          for (PCollection pc : existingComponents.getPcollectionsMap().values()) {
+            if (input.equals(pc.getUniqueName())) {
+              return true;
+            }
+          }
+          return false;
+        });
+  }
+
+  /** Generates a unique id given a prefix and a predicate that tests whether an id is already in use. */
+  private static String generateUniqueId(String prefix, Predicate<String> isExistingId) {
+    int i = 0;
+    while (isExistingId.test(prefix + i)) {
+      i += 1;
+    }
+    return prefix + i;
+  }
+}
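
As used by FlinkPipelineRunner later in this change, the expansion is applied to the pipeline proto through ProtoOverrides before fusion; a usage sketch assuming a RunnerApi.Pipeline is in hand:

    // Expand each splittable ParDo into the PairWithRestriction ->
    // SplitAndSize -> ProcessSizedElementsAndRestrictions composite so the
    // runner can observe sizing information before fusing the graph.
    RunnerApi.Pipeline expanded =
        ProtoOverrides.updateTransform(
            PTransformTranslation.PAR_DO_TRANSFORM_URN,
            pipelineProto,
            SplittableParDoExpander.createSizedReplacement());
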
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PipelineTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PipelineTranslationTest.java
index e149e66..7a8a51b 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PipelineTranslationTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PipelineTranslationTest.java
@@ -96,7 +96,7 @@
             Window.<Long>into(FixedWindows.of(Duration.standardMinutes(7)))
                 .triggering(
                     AfterWatermark.pastEndOfWindow()
-                        .withEarlyFirings(AfterPane.elementCountAtLeast(19)))
+                        .withLateFirings(AfterPane.elementCountAtLeast(19)))
                 .accumulatingFiredPanes()
                 .withAllowedLateness(Duration.standardMinutes(3L)));
     final WindowingStrategy<?, ?> windowedStrategy = windowed.getWindowingStrategy();
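
For reference, AfterWatermark triggers support early and late firings together; a sketch built from the same SDK classes the test uses:

    // Fire speculatively every 19 elements before the watermark passes the end
    // of the window, and once per late element within the allowed lateness.
    Window.<Long>into(FixedWindows.of(Duration.standardMinutes(7)))
        .triggering(
            AfterWatermark.pastEndOfWindow()
                .withEarlyFirings(AfterPane.elementCountAtLeast(19))
                .withLateFirings(AfterPane.elementCountAtLeast(1)))
        .accumulatingFiredPanes()
        .withAllowedLateness(Duration.standardMinutes(3L));
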
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/SplittableParDoExpanderTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/SplittableParDoExpanderTest.java
new file mode 100644
index 0000000..5a8d125
--- /dev/null
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/SplittableParDoExpanderTest.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.core.construction.graph;
+
+import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.resume;
+import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.stop;
+import static org.junit.Assert.assertEquals;
+
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.PipelineTranslation;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.io.range.OffsetRange;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.DoFn.UnboundedPerElement;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class SplittableParDoExpanderTest {
+
+  @UnboundedPerElement
+  static class PairStringWithIndexToLengthBase extends DoFn<String, KV<String, Integer>> {
+    @ProcessElement
+    public ProcessContinuation process(
+        ProcessContext c, RestrictionTracker<OffsetRange, Long> tracker) {
+      for (long i = tracker.currentRestriction().getFrom(), numIterations = 0;
+          tracker.tryClaim(i);
+          ++i, ++numIterations) {
+        c.output(KV.of(c.element(), (int) i));
+        if (numIterations % 3 == 0) {
+          return resume();
+        }
+      }
+      return stop();
+    }
+
+    @GetInitialRestriction
+    public OffsetRange getInitialRange(String element) {
+      return new OffsetRange(0, element.length());
+    }
+
+    @SplitRestriction
+    public void splitRange(
+        String element, OffsetRange range, OutputReceiver<OffsetRange> receiver) {
+      receiver.output(new OffsetRange(range.getFrom(), (range.getFrom() + range.getTo()) / 2));
+      receiver.output(new OffsetRange((range.getFrom() + range.getTo()) / 2, range.getTo()));
+    }
+  }
+
+  @Test
+  public void testSizedReplacement() {
+    Pipeline p = Pipeline.create();
+    p.apply(Create.of("1", "2", "3"))
+        .apply("TestSDF", ParDo.of(new PairStringWithIndexToLengthBase()));
+
+    RunnerApi.Pipeline proto = PipelineTranslation.toProto(p);
+    String transformName =
+        Iterables.getOnlyElement(
+            Maps.filterValues(
+                    proto.getComponents().getTransformsMap(),
+                    (RunnerApi.PTransform transform) ->
+                        transform
+                            .getUniqueName()
+                            .contains(PairStringWithIndexToLengthBase.class.getSimpleName()))
+                .keySet());
+
+    RunnerApi.Pipeline updatedProto =
+        ProtoOverrides.updateTransform(
+            PTransformTranslation.PAR_DO_TRANSFORM_URN,
+            proto,
+            SplittableParDoExpander.createSizedReplacement());
+    RunnerApi.PTransform newComposite =
+        updatedProto.getComponents().getTransformsOrThrow(transformName);
+    assertEquals(FunctionSpec.getDefaultInstance(), newComposite.getSpec());
+    assertEquals(3, newComposite.getSubtransformsCount());
+    assertEquals(
+        PTransformTranslation.SPLITTABLE_PAIR_WITH_RESTRICTION_URN,
+        updatedProto
+            .getComponents()
+            .getTransformsOrThrow(newComposite.getSubtransforms(0))
+            .getSpec()
+            .getUrn());
+    assertEquals(
+        PTransformTranslation.SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN,
+        updatedProto
+            .getComponents()
+            .getTransformsOrThrow(newComposite.getSubtransforms(1))
+            .getSpec()
+            .getUrn());
+    assertEquals(
+        PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN,
+        updatedProto
+            .getComponents()
+            .getTransformsOrThrow(newComposite.getSubtransforms(2))
+            .getSpec()
+            .getUrn());
+  }
+}
diff --git a/runners/flink/flink_runner.gradle b/runners/flink/flink_runner.gradle
index 3254a85..6281b94 100644
--- a/runners/flink/flink_runner.gradle
+++ b/runners/flink/flink_runner.gradle
@@ -200,6 +200,7 @@
       excludeCategories 'org.apache.beam.sdk.testing.UsesCommittedMetrics'
       if (config.streaming) {
         excludeCategories 'org.apache.beam.sdk.testing.UsesImpulse'
+        excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithMultipleStages'  // BEAM-8598
         excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithProcessingTime'
       } else {
         excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo'
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java
index 212c75e..33d2c76 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java
@@ -26,10 +26,13 @@
 import javax.annotation.Nullable;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
 import org.apache.beam.runners.core.construction.PipelineOptionsTranslation;
 import org.apache.beam.runners.core.construction.graph.ExecutableStage;
 import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
 import org.apache.beam.runners.core.construction.graph.PipelineTrimmer;
+import org.apache.beam.runners.core.construction.graph.ProtoOverrides;
+import org.apache.beam.runners.core.construction.graph.SplittableParDoExpander;
 import org.apache.beam.runners.core.metrics.MetricsPusher;
 import org.apache.beam.runners.fnexecution.jobsubmission.PortablePipelineJarUtils;
 import org.apache.beam.runners.fnexecution.jobsubmission.PortablePipelineResult;
@@ -87,8 +90,16 @@
           throws Exception {
     LOG.info("Translating pipeline to Flink program.");
 
+    // Expand any splittable ParDos within the graph to enable sizing and splitting of bundles.
+    Pipeline pipelineWithSdfExpanded =
+        ProtoOverrides.updateTransform(
+            PTransformTranslation.PAR_DO_TRANSFORM_URN,
+            pipeline,
+            SplittableParDoExpander.createSizedReplacement());
+
     // Don't let the fuser fuse any subcomponents of native transforms.
-    Pipeline trimmedPipeline = PipelineTrimmer.trim(pipeline, translator.knownUrns());
+    Pipeline trimmedPipeline =
+        PipelineTrimmer.trim(pipelineWithSdfExpanded, translator.knownUrns());
 
     // Fused pipeline proto.
     // TODO: Consider supporting partially-fused graphs.
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
index 915acfb..eed52e9 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
@@ -94,6 +94,8 @@
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.streaming.api.operators.InternalTimer;
 import org.apache.flink.streaming.api.watermark.Watermark;
@@ -246,7 +248,8 @@
                   () -> UUID.randomUUID().toString(),
                   keyedStateInternals,
                   getKeyedStateBackend(),
-                  stateBackendLock));
+                  stateBackendLock,
+                  keyCoder));
     } else {
       userStateRequestHandler = StateRequestHandler.unsupported();
     }
@@ -263,7 +266,12 @@
 
     private final StateInternals stateInternals;
     private final KeyedStateBackend<ByteBuffer> keyedStateBackend;
+    /** Lock to hold whenever accessing the state backend. */
     private final Lock stateBackendLock;
+    /** For debugging: The key coder used by the Runner. */
+    @Nullable private final Coder runnerKeyCoder;
+    /** For debugging: Same as keyedStateBackend but cast, to access key group meta info. */
+    @Nullable private final AbstractKeyedStateBackend<ByteBuffer> keyStateBackendWithKeyGroupInfo;
     /** Holds the valid cache token for user state for this operator. */
     private final ByteString cacheToken;
 
@@ -271,10 +279,20 @@
         IdGenerator cacheTokenGenerator,
         StateInternals stateInternals,
         KeyedStateBackend<ByteBuffer> keyedStateBackend,
-        Lock stateBackendLock) {
+        Lock stateBackendLock,
+        @Nullable Coder runnerKeyCoder) {
       this.stateInternals = stateInternals;
       this.keyedStateBackend = keyedStateBackend;
       this.stateBackendLock = stateBackendLock;
+      if (keyedStateBackend instanceof AbstractKeyedStateBackend) {
+        // This will always succeed, unless a custom state backend is used which does not extend
+        // AbstractKeyedStateBackend. This is unlikely, but we should still consider this case.
+        this.keyStateBackendWithKeyGroupInfo =
+            (AbstractKeyedStateBackend<ByteBuffer>) keyedStateBackend;
+      } else {
+        this.keyStateBackendWithKeyGroupInfo = null;
+      }
+      this.runnerKeyCoder = runnerKeyCoder;
       this.cacheToken = ByteString.copyFrom(cacheTokenGenerator.getId().getBytes(Charsets.UTF_8));
     }
 
@@ -368,6 +386,19 @@
           // Key for state request is shipped encoded with NESTED context.
           ByteBuffer encodedKey = FlinkKeyUtils.fromEncodedKey(key);
           keyedStateBackend.setCurrentKey(encodedKey);
+          if (keyStateBackendWithKeyGroupInfo != null) {
+            int currentKeyGroupIndex = keyStateBackendWithKeyGroupInfo.getCurrentKeyGroupIndex();
+            KeyGroupRange keyGroupRange = keyStateBackendWithKeyGroupInfo.getKeyGroupRange();
+            Preconditions.checkState(
+                keyGroupRange.contains(currentKeyGroupIndex),
+                "The current key '%s' with key group index '%s' does not belong to the key group range '%s'. Runner keyCoder: %s. Ptransformid: %s Userstateid: %s",
+                Arrays.toString(key.toByteArray()),
+                currentKeyGroupIndex,
+                keyGroupRange,
+                runnerKeyCoder,
+                pTransformId,
+                userStateId);
+          }
         }
       };
     }
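
The new precondition verifies that the key sent by the SDK hashes into a key group owned by this operator. A hedged sketch of the assignment it cross-checks, using Flink's real KeyGroupRangeAssignment utility (the surrounding variables are illustrative):

    import org.apache.flink.runtime.state.KeyGroupRangeAssignment;

    // Flink derives the key group by hashing the key modulo maxParallelism;
    // the checkState above fails when this index falls outside the operator's
    // assigned KeyGroupRange, e.g. because the runner and the SDK harness
    // encode keys differently.
    int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(encodedKey, maxParallelism);
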
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
index 0d7c99f..0cbf6b7 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
@@ -659,7 +659,7 @@
       ExecutableStageDoFnOperator.BagUserStateFactory<ByteString, Integer, GlobalWindow>
           bagUserStateFactory =
               new ExecutableStageDoFnOperator.BagUserStateFactory<>(
-                  cacheTokenGenerator, test, stateBackend, NoopLock.get());
+                  cacheTokenGenerator, test, stateBackend, NoopLock.get(), null);
 
       ByteString key1 = ByteString.copyFrom("key1", Charsets.UTF_8);
       ByteString key2 = ByteString.copyFrom("key2", Charsets.UTF_8);
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
index 3206870..dd12c22 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
@@ -1453,7 +1453,8 @@
   private static class ImpulseTranslator implements TransformTranslator<Impulse> {
     @Override
     public void translate(Impulse transform, TranslationContext context) {
-      if (context.getPipelineOptions().isStreaming() && !context.isFnApi()) {
+      if (context.getPipelineOptions().isStreaming()
+          && (!context.isFnApi() || !context.isStreamingEngine())) {
         StepTranslationContext stepContext = context.addStep(transform, "ParallelRead");
         stepContext.addInput(PropertyNames.FORMAT, "pubsub");
         stepContext.addInput(PropertyNames.PUBSUB_SUBSCRIPTION, "_starting_signal/");
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java
index 5cfcb6e..71d49ac 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java
@@ -25,6 +25,7 @@
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.annotations.Internal;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
 import org.apache.beam.sdk.runners.AppliedPTransform;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.values.PCollection;
@@ -52,6 +53,13 @@
       return experiments != null && experiments.contains("beam_fn_api");
     }
 
+    default boolean isStreamingEngine() {
+      List<String> experiments = getPipelineOptions().getExperiments();
+      return experiments != null
+          && experiments.contains(GcpOptions.STREAMING_ENGINE_EXPERIMENT)
+          && experiments.contains(GcpOptions.WINDMILL_SERVICE_EXPERIMENT);
+    }
+
     /** Returns the configured pipeline options. */
     DataflowPipelineOptions getPipelineOptions();
 
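
isStreamingEngine() requires both experiments to be present. A sketch of enabling them when constructing options; the literal flag values are assumptions here, with the canonical constants living on GcpOptions:

    DataflowPipelineOptions options =
        PipelineOptionsFactory.as(DataflowPipelineOptions.class);
    // Prefer GcpOptions.STREAMING_ENGINE_EXPERIMENT and
    // GcpOptions.WINDMILL_SERVICE_EXPERIMENT over string literals.
    options.setExperiments(
        Arrays.asList("enable_streaming_engine", "enable_windmill_service"));
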
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/FnApiWindowMappingFn.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/FnApiWindowMappingFn.java
index 7e298e7..dcf15f4 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/FnApiWindowMappingFn.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/FnApiWindowMappingFn.java
@@ -56,7 +56,6 @@
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.WindowingStrategy;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.Cache;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheBuilder;
 import org.slf4j.Logger;
@@ -253,7 +252,7 @@
       }
 
       // Check to see if processing the request failed.
-      throwIfFailure(processResponse);
+      MoreFutures.get(processResponse);
 
       waitForInboundTermination.awaitCompletion();
       WindowedValue<KV<byte[], TargetWindowT>> sideInputWindow = outputValue.poll();
@@ -300,22 +299,10 @@
                               processBundleDescriptor.toBuilder().setId(descriptorId).build())
                           .build())
                   .build());
-      throwIfFailure(response);
+      // Check if the bundle descriptor is registered successfully.
+      MoreFutures.get(response);
       processBundleDescriptorId = descriptorId;
     }
     return processBundleDescriptorId;
   }
-
-  private static InstructionResponse throwIfFailure(
-      CompletionStage<InstructionResponse> responseFuture)
-      throws ExecutionException, InterruptedException {
-    InstructionResponse response = MoreFutures.get(responseFuture);
-    if (!Strings.isNullOrEmpty(response.getError())) {
-      throw new IllegalStateException(
-          String.format(
-              "Client failed to process %s with error [%s].",
-              response.getInstructionId(), response.getError()));
-    }
-    return response;
-  }
 }
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperation.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperation.java
index 73bbf4b..bf42c4d 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperation.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperation.java
@@ -297,7 +297,6 @@
                 .setRegister(registerRequest)
                 .build();
         registerFuture = instructionRequestHandler.handle(request);
-        getRegisterResponse(registerFuture);
       }
 
       checkState(
@@ -315,7 +314,10 @@
       deregisterStateHandler =
           beamFnStateDelegator.registerForProcessBundleInstructionId(
               getProcessBundleInstructionId(), this::delegateByStateKeyType);
-      processBundleResponse = instructionRequestHandler.handle(processBundleRequest);
+      processBundleResponse =
+          getRegisterResponse(registerFuture)
+              .thenCompose(
+                  registerResponse -> instructionRequestHandler.handle(processBundleRequest));
     }
   }
 
@@ -368,12 +370,8 @@
    * elements consumed from the upstream read operation.
    *
    * <p>May be called at any time, including before start() and after finish().
-   *
-   * @throws InterruptedException
-   * @throws ExecutionException
    */
-  public CompletionStage<BeamFnApi.ProcessBundleProgressResponse> getProcessBundleProgress()
-      throws InterruptedException, ExecutionException {
+  public CompletionStage<BeamFnApi.ProcessBundleProgressResponse> getProcessBundleProgress() {
     // processBundleId may be reset if this bundle finishes asynchronously.
     String processBundleId = this.processBundleId;
 
@@ -391,13 +389,7 @@
 
     return instructionRequestHandler
         .handle(processBundleRequest)
-        .thenApply(
-            response -> {
-              if (!response.getError().isEmpty()) {
-                throw new IllegalStateException(response.getError());
-              }
-              return response.getProcessBundleProgress();
-            });
+        .thenApply(InstructionResponse::getProcessBundleProgress);
   }
 
   /** Returns the final metrics returned by the SDK harness when it completes the bundle. */
@@ -634,53 +626,36 @@
     return true;
   }
 
-  private static CompletionStage<BeamFnApi.InstructionResponse> throwIfFailure(
+  private static CompletionStage<BeamFnApi.ProcessBundleResponse> getProcessBundleResponse(
       CompletionStage<InstructionResponse> responseFuture) {
     return responseFuture.thenApply(
         response -> {
-          if (!response.getError().isEmpty()) {
-            throw new IllegalStateException(
-                String.format(
-                    "Client failed to process %s with error [%s].",
-                    response.getInstructionId(), response.getError()));
+          switch (response.getResponseCase()) {
+            case PROCESS_BUNDLE:
+              return response.getProcessBundle();
+            default:
+              throw new IllegalStateException(
+                  String.format(
+                      "SDK harness returned wrong kind of response to ProcessBundleRequest: %s",
+                      TextFormat.printToString(response)));
           }
-          return response;
         });
   }
 
-  private static CompletionStage<BeamFnApi.ProcessBundleResponse> getProcessBundleResponse(
-      CompletionStage<InstructionResponse> responseFuture) {
-    return throwIfFailure(responseFuture)
-        .thenApply(
-            response -> {
-              switch (response.getResponseCase()) {
-                case PROCESS_BUNDLE:
-                  return response.getProcessBundle();
-                default:
-                  throw new IllegalStateException(
-                      String.format(
-                          "SDK harness returned wrong kind of response to ProcessBundleRequest: %s",
-                          TextFormat.printToString(response)));
-              }
-            });
-  }
-
   private static CompletionStage<BeamFnApi.RegisterResponse> getRegisterResponse(
-      CompletionStage<InstructionResponse> responseFuture)
-      throws ExecutionException, InterruptedException {
-    return throwIfFailure(responseFuture)
-        .thenApply(
-            response -> {
-              switch (response.getResponseCase()) {
-                case REGISTER:
-                  return response.getRegister();
-                default:
-                  throw new IllegalStateException(
-                      String.format(
-                          "SDK harness returned wrong kind of response to RegisterRequest: %s",
-                          TextFormat.printToString(response)));
-              }
-            });
+      CompletionStage<InstructionResponse> responseFuture) {
+    return responseFuture.thenApply(
+        response -> {
+          switch (response.getResponseCase()) {
+            case REGISTER:
+              return response.getRegister();
+            default:
+              throw new IllegalStateException(
+                  String.format(
+                      "SDK harness returned wrong kind of response to RegisterRequest: %s",
+                      TextFormat.printToString(response)));
+          }
+        });
   }
 
   private static void cancelIfNotNull(CompletionStage<?> future) {
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/control/BeamFnMapTaskExecutorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/control/BeamFnMapTaskExecutorTest.java
index 89d4595..aff3f86 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/control/BeamFnMapTaskExecutorTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/control/BeamFnMapTaskExecutorTest.java
@@ -141,9 +141,7 @@
                 return MoreFutures.supplyAsync(
                     () -> {
                       processBundleLatch.await();
-                      return responseFor(request)
-                          .setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance())
-                          .build();
+                      return responseFor(request).build();
                     });
               case PROCESS_BUNDLE_PROGRESS:
                 progressSentLatch.countDown();
@@ -238,9 +236,7 @@
                 return MoreFutures.supplyAsync(
                     () -> {
                       processBundleLatch.await();
-                      return responseFor(request)
-                          .setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance())
-                          .build();
+                      return responseFor(request).build();
                     });
               case PROCESS_BUNDLE_PROGRESS:
                 progressSentTwiceLatch.countDown();
@@ -623,6 +619,20 @@
   }
 
   private BeamFnApi.InstructionResponse.Builder responseFor(BeamFnApi.InstructionRequest request) {
-    return BeamFnApi.InstructionResponse.newBuilder().setInstructionId(request.getInstructionId());
+    BeamFnApi.InstructionResponse.Builder response =
+        BeamFnApi.InstructionResponse.newBuilder().setInstructionId(request.getInstructionId());
+    if (request.hasRegister()) {
+      response.setRegister(BeamFnApi.RegisterResponse.getDefaultInstance());
+    } else if (request.hasProcessBundle()) {
+      response.setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance());
+    } else if (request.hasFinalizeBundle()) {
+      response.setFinalizeBundle(BeamFnApi.FinalizeBundleResponse.getDefaultInstance());
+    } else if (request.hasProcessBundleProgress()) {
+      response.setProcessBundleProgress(
+          BeamFnApi.ProcessBundleProgressResponse.getDefaultInstance());
+    } else if (request.hasProcessBundleSplit()) {
+      response.setProcessBundleSplit(BeamFnApi.ProcessBundleSplitResponse.getDefaultInstance());
+    }
+    return response;
   }
 }
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperationTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperationTest.java
index 321a236..eb3d21d 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperationTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperationTest.java
@@ -192,15 +192,8 @@
                 requests.add(request);
                 switch (request.getRequestCase()) {
                   case REGISTER:
-                    return CompletableFuture.completedFuture(
-                        responseFor(request)
-                            .setRegister(BeamFnApi.RegisterResponse.getDefaultInstance())
-                            .build());
                   case PROCESS_BUNDLE:
-                    return CompletableFuture.completedFuture(
-                        responseFor(request)
-                            .setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance())
-                            .build());
+                    return CompletableFuture.completedFuture(responseFor(request).build());
                   default:
                     // block forever on other requests
                     return new CompletableFuture<>();
@@ -277,9 +270,7 @@
                 return MoreFutures.supplyAsync(
                     () -> {
                       processBundleLatch.await();
-                      return responseFor(request)
-                          .setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance())
-                          .build();
+                      return responseFor(request).build();
                     });
               case PROCESS_BUNDLE_PROGRESS:
                 return CompletableFuture.completedFuture(
@@ -459,10 +450,7 @@
                 requests.add(request);
                 switch (request.getRequestCase()) {
                   case REGISTER:
-                    return CompletableFuture.completedFuture(
-                        InstructionResponse.newBuilder()
-                            .setInstructionId(request.getInstructionId())
-                            .build());
+                    return CompletableFuture.completedFuture(responseFor(request).build());
                   case PROCESS_BUNDLE:
                     CompletableFuture<InstructionResponse> responseFuture =
                         new CompletableFuture<>();
@@ -470,12 +458,7 @@
                         () -> {
                           // Purposefully sleep simulating SDK harness doing work
                           Thread.sleep(100);
-                          responseFuture.complete(
-                              InstructionResponse.newBuilder()
-                                  .setInstructionId(request.getInstructionId())
-                                  .setProcessBundle(
-                                      BeamFnApi.ProcessBundleResponse.getDefaultInstance())
-                                  .build());
+                          responseFuture.complete(responseFor(request).build());
                           completeFuture(request, responseFuture);
                           return null;
                         });
@@ -591,9 +574,7 @@
                           MoreFutures.get(stateHandler.handle(clear));
                       assertNotNull(clearResponse);
 
-                      return responseFor(request)
-                          .setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance())
-                          .build();
+                      return responseFor(request).build();
                     });
               default:
                 // block forever
@@ -685,9 +666,7 @@
                           encodeAndConcat(Arrays.asList("X", "Y", "Z"), StringUtf8Coder.of()),
                           getResponse.getGet().getData());
 
-                      return responseFor(request)
-                          .setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance())
-                          .build();
+                      return responseFor(request).build();
                     });
               default:
                 // block forever on other request types
@@ -855,15 +834,26 @@
   }
 
   private InstructionResponse.Builder responseFor(BeamFnApi.InstructionRequest request) {
-    return BeamFnApi.InstructionResponse.newBuilder().setInstructionId(request.getInstructionId());
+    BeamFnApi.InstructionResponse.Builder response =
+        BeamFnApi.InstructionResponse.newBuilder().setInstructionId(request.getInstructionId());
+    if (request.hasRegister()) {
+      response.setRegister(BeamFnApi.RegisterResponse.getDefaultInstance());
+    } else if (request.hasProcessBundle()) {
+      response.setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance());
+    } else if (request.hasFinalizeBundle()) {
+      response.setFinalizeBundle(BeamFnApi.FinalizeBundleResponse.getDefaultInstance());
+    } else if (request.hasProcessBundleProgress()) {
+      response.setProcessBundleProgress(
+          BeamFnApi.ProcessBundleProgressResponse.getDefaultInstance());
+    } else if (request.hasProcessBundleSplit()) {
+      response.setProcessBundleSplit(BeamFnApi.ProcessBundleSplitResponse.getDefaultInstance());
+    }
+    return response;
   }
 
   private void completeFuture(
       BeamFnApi.InstructionRequest request, CompletableFuture<InstructionResponse> response) {
-    response.complete(
-        BeamFnApi.InstructionResponse.newBuilder()
-            .setInstructionId(request.getInstructionId())
-            .build());
+    response.complete(responseFor(request).build());
   }
 
   @Test
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleCheckpointHandler.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleCheckpointHandler.java
new file mode 100644
index 0000000..1e5fa53
--- /dev/null
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleCheckpointHandler.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.control;
+
+import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+
+/**
+ * A handler which is invoked when the SDK returns {@link BeamFnApi.DelayedBundleApplication}s as
+ * part of the bundle completion.
+ *
+ * <p>These bundle applications must be resumed, or data loss will occur.
+ *
+ * <p>See <a href="https://s.apache.org/beam-breaking-fusion">breaking the fusion barrier</a> for
+ * further details.
+ */
+public interface BundleCheckpointHandler {
+  void onCheckpoint(BeamFnApi.ProcessBundleResponse response);
+}
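
Note: BundleCheckpointHandler has a single method, so a lambda suffices. A minimal
runner-side sketch, where `residualQueue` is a hypothetical structure the runner
drains to re-schedule work, not a Beam API:

    // Sketch: stash each residual root so the runner can re-schedule it later.
    BundleCheckpointHandler handler =
        response -> {
          for (BeamFnApi.DelayedBundleApplication residual : response.getResidualRootsList()) {
            residualQueue.add(residual);
          }
        };
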
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleFinalizationHandler.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleFinalizationHandler.java
new file mode 100644
index 0000000..849663b
--- /dev/null
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleFinalizationHandler.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.control;
+
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
+
+/**
+ * A handler invoked on the runner when the SDK requests finalization of a bundle.
+ *
+ * <p>The runner is responsible for finalizing the bundle when all output from the bundle has been
+ * durably persisted.
+ *
+ * <p>See <a href="https://s.apache.org/beam-finalizing-bundles">finalizing bundles</a> for further
+ * details.
+ */
+public interface BundleFinalizationHandler {
+  void requestsFinalization(ProcessBundleResponse response);
+}
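
Note: a hedged sketch of an implementation; `persistOutputs()` and the follow-up
control-plane call are hypothetical placeholders for runner-specific logic:

    // Sketch: finalize only once the bundle's output is durably persisted.
    BundleFinalizationHandler handler =
        response -> {
          persistOutputs(); // hypothetical: flush/commit this bundle's output
          // ... then issue a BeamFnApi.FinalizeBundleRequest for this bundle.
        };
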
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/InstructionRequestHandler.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/InstructionRequestHandler.java
index b655732..8a9dc75 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/InstructionRequestHandler.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/InstructionRequestHandler.java
@@ -20,7 +20,10 @@
 import java.util.concurrent.CompletionStage;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi;
 
-/** Interface for any function that can handle a Fn API {@link BeamFnApi.InstructionRequest}. */
+/**
+ * Interface for any function that can handle a Fn API {@link BeamFnApi.InstructionRequest}. Any
+ * error responses will be converted to exceptionally completed futures.
+ */
 public interface InstructionRequestHandler extends AutoCloseable {
   CompletionStage<BeamFnApi.InstructionResponse> handle(BeamFnApi.InstructionRequest request);
 }
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
index b324cb1..a2129c4 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
@@ -57,6 +57,7 @@
 import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.InvalidProtocolBufferException;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableTable;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
@@ -135,6 +136,10 @@
         forTimerSpecs(
             dataEndpoint, stage, components, inputDestinationsBuilder, remoteOutputCodersBuilder);
 
+    if (bagUserStateSpecs.size() > 0 || timerSpecs.size() > 0) {
+      lengthPrefixKeyCoder(stage.getInputPCollection().getId(), components);
+    }
+
     // Copy data from components to ProcessBundleDescriptor.
     ProcessBundleDescriptor.Builder bundleDescriptorBuilder =
         ProcessBundleDescriptor.newBuilder().setId(id);
@@ -158,6 +163,29 @@
         timerSpecs);
   }
 
+  /**
+   * Patches the input coder of a stateful executable stage to ensure that the byte representation
+   * of the key used to partition the input element at the runner matches the key byte
+   * representation received for state requests and timers from the SDK harness. Stateful
+   * transforms always have a KvCoder as input.
+   */
+  private static void lengthPrefixKeyCoder(
+      String inputColId, Components.Builder componentsBuilder) {
+    RunnerApi.PCollection pcollection = componentsBuilder.getPcollectionsOrThrow(inputColId);
+    RunnerApi.Coder kvCoder = componentsBuilder.getCodersOrThrow(pcollection.getCoderId());
+    Preconditions.checkState(
+        ModelCoders.KV_CODER_URN.equals(kvCoder.getSpec().getUrn()),
+        "Stateful executable stages must use a KV coder, but is: %s",
+        kvCoder.getSpec().getUrn());
+    String keyCoderId = ModelCoders.getKvCoderComponents(kvCoder).keyCoderId();
+    // Retain the original coder, but wrap in LengthPrefixCoder
+    String newKeyCoderId =
+        LengthPrefixUnknownCoders.addLengthPrefixedCoder(keyCoderId, componentsBuilder, false);
+    // Replace old key coder with LengthPrefixCoder<old_key_coder>
+    kvCoder = kvCoder.toBuilder().setComponentCoderIds(0, newKeyCoderId).build();
+    componentsBuilder.putCoders(pcollection.getCoderId(), kvCoder);
+  }
+
   private static Map<String, Coder<WindowedValue<?>>> addStageOutputs(
       ApiServiceDescriptor dataEndpoint,
       Collection<PCollectionNode> outputPCollections,
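
Note: the rewrite keeps the original key coder and redirects the KvCoder's key
component to a new LengthPrefixCoder entry. A sketch of the effect on the coder
map (ids illustrative), plus the equivalent SDK-side wrapping:

    // before: coders[kvCoderId]    = KvCoder(keyCoderId, valueCoderId)
    // after:  coders[kvCoderId]    = KvCoder(lpKeyCoderId, valueCoderId)
    //         coders[lpKeyCoderId] = LengthPrefixCoder(keyCoderId)
    Coder<KV<Void, String>> wrapped =
        KvCoder.of(LengthPrefixCoder.of(VoidCoder.of()), StringUtf8Coder.of());
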
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java
index b3f1c2e..86232dd 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java
@@ -95,6 +95,9 @@
      *   // send all main input elements ...
      * }
      * }</pre>
+     *
+     * <p>An exception will be thrown during {@link #close()} if the bundle requests finalization
+     * or attempts to checkpoint by returning a {@link BeamFnApi.DelayedBundleApplication}.
      */
     public ActiveBundle newBundle(
         Map<String, RemoteOutputReceiver<?>> outputReceivers,
@@ -122,6 +125,47 @@
      * try (ActiveBundle bundle = SdkHarnessClient.newBundle(...)) {
      *   FnDataReceiver<InputT> inputReceiver =
      *       (FnDataReceiver) bundle.getInputReceivers().get(mainPCollectionId);
+     *   // send all main input elements ...
+     * }
+     * }</pre>
+     *
+     * <p>An exception will be thrown during {@link #close()} if the bundle requests finalization
+     * or attempts to checkpoint by returning a {@link BeamFnApi.DelayedBundleApplication}.
+     */
+    public ActiveBundle newBundle(
+        Map<String, RemoteOutputReceiver<?>> outputReceivers,
+        StateRequestHandler stateRequestHandler,
+        BundleProgressHandler progressHandler) {
+      return newBundle(
+          outputReceivers,
+          stateRequestHandler,
+          progressHandler,
+          response -> {
+            throw new UnsupportedOperationException(
+                String.format(
+                    "The %s does not have a registered bundle checkpoint handler.",
+                    ActiveBundle.class.getSimpleName()));
+          },
+          response -> {
+            throw new UnsupportedOperationException(
+                String.format(
+                    "The %s does not have a registered bundle finalization handler.",
+                    ActiveBundle.class.getSimpleName()));
+          });
+    }
+
+    /**
+     * Start a new bundle for the given {@link BeamFnApi.ProcessBundleDescriptor} identifier.
+     *
+     * <p>The input channels for the returned {@link ActiveBundle} are derived from the instructions
+     * in the {@link BeamFnApi.ProcessBundleDescriptor}.
+     *
+     * <p>NOTE: It is important to {@link #close()} each bundle after all elements are emitted.
+     *
+     * <pre>{@code
+     * try (ActiveBundle bundle = SdkHarnessClient.newBundle(...)) {
+     *   FnDataReceiver<InputT> inputReceiver =
+     *       (FnDataReceiver) bundle.getInputReceivers().get(mainPCollectionId);
      *   // send all elements ...
      * }
      * }</pre>
@@ -129,18 +173,22 @@
     public ActiveBundle newBundle(
         Map<String, RemoteOutputReceiver<?>> outputReceivers,
         StateRequestHandler stateRequestHandler,
-        BundleProgressHandler progressHandler) {
+        BundleProgressHandler progressHandler,
+        BundleCheckpointHandler checkpointHandler,
+        BundleFinalizationHandler finalizationHandler) {
       String bundleId = idGenerator.getId();
 
       final CompletionStage<BeamFnApi.InstructionResponse> genericResponse =
-          fnApiControlClient.handle(
-              BeamFnApi.InstructionRequest.newBuilder()
-                  .setInstructionId(bundleId)
-                  .setProcessBundle(
-                      BeamFnApi.ProcessBundleRequest.newBuilder()
-                          .setProcessBundleDescriptorId(processBundleDescriptor.getId())
-                          .addAllCacheTokens(stateRequestHandler.getCacheTokens()))
-                  .build());
+          registrationFuture.thenCompose(
+              registration ->
+                  fnApiControlClient.handle(
+                      BeamFnApi.InstructionRequest.newBuilder()
+                          .setInstructionId(bundleId)
+                          .setProcessBundle(
+                              BeamFnApi.ProcessBundleRequest.newBuilder()
+                                  .setProcessBundleDescriptorId(processBundleDescriptor.getId())
+                                  .addAllCacheTokens(stateRequestHandler.getCacheTokens()))
+                          .build()));
       LOG.debug(
           "Sent {} with ID {} for {} with ID {}",
           ProcessBundleRequest.class.getSimpleName(),
@@ -173,7 +221,9 @@
           dataReceiversBuilder.build(),
           outputClients,
           stateDelegator.registerForProcessBundleInstructionId(bundleId, stateRequestHandler),
-          progressHandler);
+          progressHandler,
+          checkpointHandler,
+          finalizationHandler);
     }
 
     private <OutputT> InboundDataClient attachReceiver(
@@ -191,6 +241,8 @@
     private final Map<String, InboundDataClient> outputClients;
     private final StateDelegator.Registration stateRegistration;
     private final BundleProgressHandler progressHandler;
+    private final BundleCheckpointHandler checkpointHandler;
+    private final BundleFinalizationHandler finalizationHandler;
 
     private ActiveBundle(
         String bundleId,
@@ -198,13 +250,17 @@
         Map<String, CloseableFnDataReceiver> inputReceivers,
         Map<String, InboundDataClient> outputClients,
         StateDelegator.Registration stateRegistration,
-        BundleProgressHandler progressHandler) {
+        BundleProgressHandler progressHandler,
+        BundleCheckpointHandler checkpointHandler,
+        BundleFinalizationHandler finalizationHandler) {
       this.bundleId = bundleId;
       this.response = response;
       this.inputReceivers = inputReceivers;
       this.outputClients = outputClients;
       this.stateRegistration = stateRegistration;
       this.progressHandler = progressHandler;
+      this.checkpointHandler = checkpointHandler;
+      this.finalizationHandler = finalizationHandler;
     }
 
     /** Returns an id used to represent this bundle. */
@@ -254,13 +310,15 @@
           BeamFnApi.ProcessBundleResponse completedResponse = MoreFutures.get(response);
           progressHandler.onCompleted(completedResponse);
           if (completedResponse.getResidualRootsCount() > 0) {
-            throw new IllegalStateException(
-                "TODO: [BEAM-2939] residual roots in process bundle response not yet supported.");
+            checkpointHandler.onCheckpoint(completedResponse);
+          }
+          if (completedResponse.getRequiresFinalization()) {
+            finalizationHandler.requestsFinalization(completedResponse);
           }
         } else {
           // TODO: [BEAM-3962] Handle aborting the bundle being processed.
           throw new IllegalStateException(
-              "Processing bundle failed, " + "TODO: [BEAM-3962] abort bundle.");
+              "Processing bundle failed, TODO: [BEAM-3962] abort bundle.");
         }
       } catch (Exception e) {
         if (exception == null) {
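
Note: a hedged usage sketch of the five-argument newBundle; the two trailing
lambdas and their helper methods are hypothetical placeholders for
runner-specific logic:

    try (ActiveBundle bundle =
        processor.newBundle(
            outputReceivers,
            stateRequestHandler,
            progressHandler,
            response -> rescheduleResiduals(response), // BundleCheckpointHandler
            response -> finalizeWhenPersisted(response))) { // BundleFinalizationHandler
      // send all main input elements via bundle.getInputReceivers() ...
    }
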
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java
new file mode 100644
index 0000000..b558b00
--- /dev/null
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.control;
+
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import java.io.Serializable;
+import java.util.Map;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+import org.apache.beam.model.pipeline.v1.Endpoints;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.runners.core.construction.CoderTranslation;
+import org.apache.beam.runners.core.construction.ModelCoderRegistrar;
+import org.apache.beam.runners.core.construction.ModelCoders;
+import org.apache.beam.runners.core.construction.PipelineTranslation;
+import org.apache.beam.runners.core.construction.graph.ExecutableStage;
+import org.apache.beam.runners.core.construction.graph.FusedPipeline;
+import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VoidCoder;
+import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.Impulse;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Optional;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.junit.Test;
+
+/** Tests for {@link ProcessBundleDescriptors}. */
+public class ProcessBundleDescriptorsTest implements Serializable {
+
+  /**
+   * Tests that a stateful executable stage wraps the key coder of a stateful transform in a
+   * LengthPrefixCoder.
+   */
+  @Test
+  public void testWrapKeyCoderOfStatefulExecutableStageInLengthPrefixCoder() throws Exception {
+    // Create a stateful stage with a non-standard key coder.
+    Pipeline p = Pipeline.create();
+    Coder<Void> keyCoder = VoidCoder.of();
+    assertThat(ModelCoderRegistrar.isKnownCoder(keyCoder), is(false));
+    p.apply("impulse", Impulse.create())
+        .apply(
+            "create",
+            ParDo.of(
+                new DoFn<byte[], KV<Void, String>>() {
+                  @ProcessElement
+                  public void process(ProcessContext ctxt) {}
+                }))
+        .setCoder(KvCoder.of(keyCoder, StringUtf8Coder.of()))
+        .apply(
+            "userState",
+            ParDo.of(
+                new DoFn<KV<Void, String>, KV<Void, String>>() {
+
+                  @StateId("stateId")
+                  private final StateSpec<BagState<String>> bufferState =
+                      StateSpecs.bag(StringUtf8Coder.of());
+
+                  @ProcessElement
+                  public void processElement(
+                      @Element KV<Void, String> element,
+                      @StateId("stateId") BagState<String> state,
+                      OutputReceiver<KV<Void, String>> r) {
+                    for (String value : state.read()) {
+                      r.output(KV.of(element.getKey(), value));
+                    }
+                    state.add(element.getValue());
+                  }
+                }))
+        // Force the output to be materialized
+        .apply("gbk", GroupByKey.create());
+
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+    FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
+    Optional<ExecutableStage> optionalStage =
+        Iterables.tryFind(
+            fused.getFusedStages(),
+            (ExecutableStage stage) ->
+                stage.getUserStates().stream()
+                    .anyMatch(spec -> spec.localName().equals("stateId")));
+    checkState(optionalStage.isPresent(), "Expected a stage with user state.");
+    ExecutableStage stage = optionalStage.get();
+
+    ProcessBundleDescriptors.ExecutableProcessBundleDescriptor descriptor =
+        ProcessBundleDescriptors.fromExecutableStage(
+            "test_stage", stage, Endpoints.ApiServiceDescriptor.getDefaultInstance());
+
+    BeamFnApi.ProcessBundleDescriptor pbDescriptor = descriptor.getProcessBundleDescriptor();
+    String inputColId = stage.getInputPCollection().getId();
+    String inputCoderId = pbDescriptor.getPcollectionsMap().get(inputColId).getCoderId();
+
+    Map<String, RunnerApi.Coder> codersMap = pbDescriptor.getCodersMap();
+    RunnerApi.Coder coder = codersMap.get(inputCoderId);
+    String keyCoderId = ModelCoders.getKvCoderComponents(coder).keyCoderId();
+
+    assertThat(
+        codersMap.get(keyCoderId).getSpec().getUrn(), is(ModelCoders.LENGTH_PREFIX_CODER_URN));
+
+    RunnerApi.Coder originalCoder = stage.getComponents().getCodersMap().get(inputCoderId);
+    String originalKeyCoderId = ModelCoders.getKvCoderComponents(originalCoder).keyCoderId();
+    assertThat(
+        stage.getComponents().getCodersMap().get(originalKeyCoderId).getSpec().getUrn(),
+        is(CoderTranslation.JAVA_SERIALIZED_CODER_URN));
+  }
+}
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
index 0925767..5d4c8f0 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
@@ -18,10 +18,10 @@
 package org.apache.beam.runners.fnexecution.control;
 
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
+import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.allOf;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.equalTo;
-import static org.junit.Assert.assertThat;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
@@ -946,8 +946,6 @@
   @Test
   public void testExecutionWithTimer() throws Exception {
     Pipeline p = Pipeline.create();
-    final String timerId = "foo";
-    final String timerId2 = "foo2";
 
     p.apply("impulse", Impulse.create())
         .apply(
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientTest.java
index 5bc0d1c..089c8d1 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientTest.java
@@ -31,6 +31,7 @@
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.verifyZeroInteractions;
 import static org.mockito.Mockito.when;
 
 import java.util.ArrayList;
@@ -41,6 +42,7 @@
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
@@ -159,9 +161,8 @@
 
   @Test
   public void testRegisterCachesBundleProcessors() throws Exception {
-    CompletableFuture<InstructionResponse> registerResponseFuture = new CompletableFuture<>();
     when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
-        .thenReturn(registerResponseFuture);
+        .thenReturn(createRegisterResponse());
 
     ProcessBundleDescriptor descriptor1 =
         ProcessBundleDescriptor.newBuilder().setId("descriptor1").build();
@@ -187,9 +188,8 @@
 
   @Test
   public void testRegisterWithStateRequiresStateDelegator() throws Exception {
-    CompletableFuture<InstructionResponse> registerResponseFuture = new CompletableFuture<>();
     when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
-        .thenReturn(registerResponseFuture);
+        .thenReturn(createRegisterResponse());
 
     ProcessBundleDescriptor descriptor =
         ProcessBundleDescriptor.newBuilder()
@@ -214,7 +214,7 @@
   public void testNewBundleNoDataDoesNotCrash() throws Exception {
     CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
     when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
-        .thenReturn(new CompletableFuture<>())
+        .thenReturn(createRegisterResponse())
         .thenReturn(processBundleResponseFuture);
 
     FullWindowedValueCoder<String> coder =
@@ -290,7 +290,7 @@
 
     CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
     when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
-        .thenReturn(new CompletableFuture<>())
+        .thenReturn(createRegisterResponse())
         .thenReturn(processBundleResponseFuture);
 
     FullWindowedValueCoder<String> coder =
@@ -343,7 +343,7 @@
 
     CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
     when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
-        .thenReturn(new CompletableFuture<>())
+        .thenReturn(createRegisterResponse())
         .thenReturn(processBundleResponseFuture);
 
     FullWindowedValueCoder<String> coder =
@@ -389,7 +389,7 @@
 
     CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
     when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
-        .thenReturn(new CompletableFuture<>())
+        .thenReturn(createRegisterResponse())
         .thenReturn(processBundleResponseFuture);
 
     FullWindowedValueCoder<String> coder =
@@ -439,7 +439,7 @@
 
     CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
     when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
-        .thenReturn(new CompletableFuture<>())
+        .thenReturn(createRegisterResponse())
         .thenReturn(processBundleResponseFuture);
 
     FullWindowedValueCoder<String> coder =
@@ -483,7 +483,7 @@
 
     CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
     when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
-        .thenReturn(new CompletableFuture<>())
+        .thenReturn(createRegisterResponse())
         .thenReturn(processBundleResponseFuture);
 
     FullWindowedValueCoder<String> coder =
@@ -540,7 +540,7 @@
 
     CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
     when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
-        .thenReturn(new CompletableFuture<>())
+        .thenReturn(createRegisterResponse())
         .thenReturn(processBundleResponseFuture);
 
     FullWindowedValueCoder<String> coder =
@@ -582,10 +582,9 @@
   }
 
   @Test
-  public void verifyCacheTokensAreUsedInNewBundleRequest() {
-    CompletableFuture<InstructionResponse> registerResponseFuture = new CompletableFuture<>();
+  public void verifyCacheTokensAreUsedInNewBundleRequest() throws InterruptedException {
     when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
-        .thenReturn(registerResponseFuture);
+        .thenReturn(createRegisterResponse());
 
     ProcessBundleDescriptor descriptor1 =
         ProcessBundleDescriptor.newBuilder().setId("descriptor1").build();
@@ -626,6 +625,117 @@
     assertThat(requests.get(1).getProcessBundle().getCacheTokensList(), is(cacheTokens));
   }
 
+  @Test
+  public void testBundleCheckpointCallback() throws Exception {
+    InboundDataClient mockOutputReceiver = mock(InboundDataClient.class);
+    CloseableFnDataReceiver mockInputSender = mock(CloseableFnDataReceiver.class);
+
+    CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
+    when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
+        .thenReturn(createRegisterResponse())
+        .thenReturn(processBundleResponseFuture);
+
+    FullWindowedValueCoder<String> coder =
+        FullWindowedValueCoder.of(StringUtf8Coder.of(), Coder.INSTANCE);
+    BundleProcessor processor =
+        sdkHarnessClient.getProcessor(
+            descriptor,
+            Collections.singletonMap(
+                "inputPC",
+                RemoteInputDestination.of(
+                    (FullWindowedValueCoder) coder, SDK_GRPC_READ_TRANSFORM)));
+    when(dataService.receive(any(), any(), any())).thenReturn(mockOutputReceiver);
+    when(dataService.send(any(), eq(coder))).thenReturn(mockInputSender);
+
+    RemoteOutputReceiver mockRemoteOutputReceiver = mock(RemoteOutputReceiver.class);
+    BundleProgressHandler mockProgressHandler = mock(BundleProgressHandler.class);
+    BundleCheckpointHandler mockCheckpointHandler = mock(BundleCheckpointHandler.class);
+    BundleFinalizationHandler mockFinalizationHandler = mock(BundleFinalizationHandler.class);
+
+    ProcessBundleResponse response =
+        ProcessBundleResponse.newBuilder()
+            .addResidualRoots(DelayedBundleApplication.getDefaultInstance())
+            .build();
+
+    try (ActiveBundle activeBundle =
+        processor.newBundle(
+            ImmutableMap.of(SDK_GRPC_WRITE_TRANSFORM, mockRemoteOutputReceiver),
+            (request) -> {
+              throw new UnsupportedOperationException();
+            },
+            mockProgressHandler,
+            mockCheckpointHandler,
+            mockFinalizationHandler)) {
+      processBundleResponseFuture.complete(
+          InstructionResponse.newBuilder().setProcessBundle(response).build());
+    }
+
+    verify(mockProgressHandler).onCompleted(response);
+    verify(mockCheckpointHandler).onCheckpoint(response);
+    verifyZeroInteractions(mockFinalizationHandler);
+  }
+
+  @Test
+  public void testBundleFinalizationCallback() throws Exception {
+    InboundDataClient mockOutputReceiver = mock(InboundDataClient.class);
+    CloseableFnDataReceiver mockInputSender = mock(CloseableFnDataReceiver.class);
+
+    CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
+    when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
+        .thenReturn(createRegisterResponse())
+        .thenReturn(processBundleResponseFuture);
+
+    FullWindowedValueCoder<String> coder =
+        FullWindowedValueCoder.of(StringUtf8Coder.of(), Coder.INSTANCE);
+    BundleProcessor processor =
+        sdkHarnessClient.getProcessor(
+            descriptor,
+            Collections.singletonMap(
+                "inputPC",
+                RemoteInputDestination.of(
+                    (FullWindowedValueCoder) coder, SDK_GRPC_READ_TRANSFORM)));
+    when(dataService.receive(any(), any(), any())).thenReturn(mockOutputReceiver);
+    when(dataService.send(any(), eq(coder))).thenReturn(mockInputSender);
+
+    RemoteOutputReceiver mockRemoteOutputReceiver = mock(RemoteOutputReceiver.class);
+    BundleProgressHandler mockProgressHandler = mock(BundleProgressHandler.class);
+    BundleCheckpointHandler mockCheckpointHandler = mock(BundleCheckpointHandler.class);
+    BundleFinalizationHandler mockFinalizationHandler = mock(BundleFinalizationHandler.class);
+
+    ProcessBundleResponse response =
+        ProcessBundleResponse.newBuilder().setRequiresFinalization(true).build();
+
+    try (ActiveBundle activeBundle =
+        processor.newBundle(
+            ImmutableMap.of(SDK_GRPC_WRITE_TRANSFORM, mockRemoteOutputReceiver),
+            (request) -> {
+              throw new UnsupportedOperationException();
+            },
+            mockProgressHandler,
+            mockCheckpointHandler,
+            mockFinalizationHandler)) {
+      processBundleResponseFuture.complete(
+          InstructionResponse.newBuilder().setProcessBundle(response).build());
+    }
+
+    verify(mockProgressHandler).onCompleted(response);
+    verify(mockFinalizationHandler).requestsFinalization(response);
+    verifyZeroInteractions(mockCheckpointHandler);
+  }
+
+  private CompletableFuture<InstructionResponse> createRegisterResponse() {
+    return CompletableFuture.completedFuture(
+        InstructionResponse.newBuilder()
+            .setRegister(BeamFnApi.RegisterResponse.getDefaultInstance())
+            .build());
+  }
+
   private static class TestFn extends DoFn<String, String> {
     @ProcessElement
     public void processElement(ProcessContext context) {
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineRunner.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineRunner.java
index 85bd576..4733ef5 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineRunner.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineRunner.java
@@ -19,7 +19,10 @@
 
 import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
 import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
+import org.apache.beam.runners.core.construction.graph.ProtoOverrides;
+import org.apache.beam.runners.core.construction.graph.SplittableParDoExpander;
 import org.apache.beam.runners.core.construction.renderer.PipelineDotRenderer;
 import org.apache.beam.runners.fnexecution.jobsubmission.PortablePipelineResult;
 import org.apache.beam.runners.fnexecution.jobsubmission.PortablePipelineRunner;
@@ -36,8 +39,16 @@
 
   @Override
   public PortablePipelineResult run(final Pipeline pipeline, JobInfo jobInfo) {
+    // Expand any splittable DoFns within the graph to enable sizing and splitting of bundles.
+    Pipeline pipelineWithSdfExpanded =
+        ProtoOverrides.updateTransform(
+            PTransformTranslation.PAR_DO_TRANSFORM_URN,
+            pipeline,
+            SplittableParDoExpander.createSizedReplacement());
+
     // Fused pipeline proto.
-    final RunnerApi.Pipeline fusedPipeline = GreedyPipelineFuser.fuse(pipeline).toPipeline();
+    final RunnerApi.Pipeline fusedPipeline =
+        GreedyPipelineFuser.fuse(pipelineWithSdfExpanded).toPipeline();
     LOG.info("Portable pipeline to run:");
     LOG.info(PipelineDotRenderer.toDotString(fusedPipeline));
     // the pipeline option coming from sdk will set the sdk specific runner which will break
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java
index 725df75..1d3f92d 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java
@@ -25,9 +25,12 @@
 import java.util.concurrent.Future;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
 import org.apache.beam.runners.core.construction.graph.ExecutableStage;
 import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
 import org.apache.beam.runners.core.construction.graph.PipelineTrimmer;
+import org.apache.beam.runners.core.construction.graph.ProtoOverrides;
+import org.apache.beam.runners.core.construction.graph.SplittableParDoExpander;
 import org.apache.beam.runners.core.metrics.MetricsPusher;
 import org.apache.beam.runners.fnexecution.jobsubmission.PortablePipelineResult;
 import org.apache.beam.runners.fnexecution.jobsubmission.PortablePipelineRunner;
@@ -58,8 +61,16 @@
   public PortablePipelineResult run(RunnerApi.Pipeline pipeline, JobInfo jobInfo) {
     SparkBatchPortablePipelineTranslator translator = new SparkBatchPortablePipelineTranslator();
 
+    // Expand any splittable DoFns within the graph to enable sizing and splitting of bundles.
+    Pipeline pipelineWithSdfExpanded =
+        ProtoOverrides.updateTransform(
+            PTransformTranslation.PAR_DO_TRANSFORM_URN,
+            pipeline,
+            SplittableParDoExpander.createSizedReplacement());
+
     // Don't let the fuser fuse any subcomponents of native transforms.
-    Pipeline trimmedPipeline = PipelineTrimmer.trim(pipeline, translator.knownUrns());
+    Pipeline trimmedPipeline =
+        PipelineTrimmer.trim(pipelineWithSdfExpanded, translator.knownUrns());
 
     // Fused pipeline proto.
     // TODO: Consider supporting partially-fused graphs.
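
Note: the Samza and Spark portable runners now share the same preprocessing;
roughly, the order of operations is:

    // 1. ProtoOverrides.updateTransform(PAR_DO_TRANSFORM_URN, ...,
    //      SplittableParDoExpander.createSizedReplacement()) rewrites each
    //      splittable ParDo into pair-with-restriction, split-and-size, and
    //      process-sized-elements stages.
    // 2. PipelineTrimmer.trim(...) (Spark only) keeps native transforms opaque to the fuser.
    // 3. GreedyPipelineFuser.fuse(...) fuses the expanded graph into executable stages.
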
diff --git a/sdks/go/pkg/beam/pardo.go b/sdks/go/pkg/beam/pardo.go
index 41283f7..21e515a 100644
--- a/sdks/go/pkg/beam/pardo.go
+++ b/sdks/go/pkg/beam/pardo.go
@@ -123,7 +123,7 @@
 //    words := beam.ParDo(s, &Foo{...}, ...)
 //    lengths := beam.ParDo(s, func (word string) int) {
 //          return len(word)
-//    }, works)
+//    }, words)
 //
 //
 // Each output element has the same timestamp and is in the same windows as its
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSource.java
index 5e2d0bf..df9a5f4 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSource.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSource.java
@@ -87,6 +87,8 @@
   @VisibleForTesting
   static class TextBasedReader extends FileBasedReader<String> {
     private static final int READ_BUFFER_SIZE = 8192;
+    private static final ByteString UTF8_BOM =
+        ByteString.copyFrom(new byte[] {(byte) 0xEF, (byte) 0xBB, (byte) 0xBF});
     private final ByteBuffer readBuffer = ByteBuffer.allocate(READ_BUFFER_SIZE);
     private ByteString buffer;
     private int startOfDelimiterInBuffer;
@@ -251,6 +253,10 @@
      */
     private void decodeCurrentElement() throws IOException {
       ByteString dataToDecode = buffer.substring(0, startOfDelimiterInBuffer);
+      // Remove the UTF-8 Byte Order Mark (BOM), if present, from the start of the file.
+      if (startOfRecord == 0 && dataToDecode.startsWith(UTF8_BOM)) {
+        dataToDecode = dataToDecode.substring(UTF8_BOM.size());
+      }
       currentValue = dataToDecode.toStringUtf8();
       elementIsPresent = true;
       buffer = buffer.substring(endOfDelimiterInBuffer);
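
Note: a quick sanity check of the UTF8_BOM constant above; the BOM code point
U+FEFF encodes to exactly the three bytes being compared:

    byte[] bom = "\uFEFF".getBytes(java.nio.charset.StandardCharsets.UTF_8);
    // bom == {(byte) 0xEF, (byte) 0xBB, (byte) 0xBF}, i.e. UTF8_BOM
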
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesTestStreamWithMultipleStages.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesTestStreamWithMultipleStages.java
new file mode 100644
index 0000000..55999ce
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesTestStreamWithMultipleStages.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.testing;
+
+/**
+ * Subcategory for {@link UsesTestStream} tests which use {@link TestStream} across multiple
+ * stages. Some runners do not properly support quiescence in the way that {@link TestStream}
+ * demands.
+ */
+public interface UsesTestStreamWithMultipleStages extends UsesTestStream {}
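
Note: a likely usage sketch, following the JUnit category convention of the
other UsesTestStream subcategories (test name and body are illustrative):

    @Test
    @Category({ValidatesRunner.class, UsesTestStreamWithMultipleStages.class})
    public void testTestStreamAcrossTwoStages() {
      // build a pipeline that routes a TestStream through more than one fused stage ...
    }
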
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java
index 78efbf7..b53ca13 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java
@@ -22,9 +22,12 @@
 import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.transforms.windowing.AfterWatermark.AfterWatermarkEarlyAndLate;
+import org.apache.beam.sdk.transforms.windowing.AfterWatermark.FromEndOfWindow;
 import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
 import org.apache.beam.sdk.transforms.windowing.InvalidWindows;
+import org.apache.beam.sdk.transforms.windowing.Never.NeverTrigger;
 import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.transforms.windowing.WindowFn;
@@ -151,9 +154,8 @@
         && windowingStrategy.getTrigger() instanceof DefaultTrigger
         && input.isBounded() != IsBounded.BOUNDED) {
       throw new IllegalStateException(
-          "GroupByKey cannot be applied to non-bounded PCollection in "
-              + "the GlobalWindow without a trigger. Use a Window.into or Window.triggering transform "
-              + "prior to GroupByKey.");
+          "GroupByKey cannot be applied to non-bounded PCollection in the GlobalWindow without a"
+              + " trigger. Use a Window.into or Window.triggering transform prior to GroupByKey.");
     }
 
     // Validate the window merge function.
@@ -162,6 +164,45 @@
       throw new IllegalStateException(
           "GroupByKey must have a valid Window merge function.  " + "Invalid because: " + cause);
     }
+
+    // Validate that the trigger does not finish before garbage collection time
+    if (!triggerIsSafe(windowingStrategy)) {
+      throw new IllegalArgumentException(
+          String.format(
+              "Unsafe trigger may lose data, see"
+                  + " https://s.apache.org/finishing-triggers-drop-data: %s",
+              windowingStrategy.getTrigger()));
+    }
+  }
+
+  // Note that the Never trigger finishes *at* GC time, so it is OK, and
+  // AfterWatermark.fromEndOfWindow() finishes at end-of-window time, so it
+  // is OK if there is no allowed lateness.
+  private static boolean triggerIsSafe(WindowingStrategy<?, ?> windowingStrategy) {
+    if (!windowingStrategy.getTrigger().mayFinish()) {
+      return true;
+    }
+
+    if (windowingStrategy.getTrigger() instanceof NeverTrigger) {
+      return true;
+    }
+
+    if (windowingStrategy.getTrigger() instanceof FromEndOfWindow
+        && windowingStrategy.getAllowedLateness().getMillis() == 0) {
+      return true;
+    }
+
+    if (windowingStrategy.getTrigger() instanceof AfterWatermarkEarlyAndLate
+        && windowingStrategy.getAllowedLateness().getMillis() == 0) {
+      return true;
+    }
+
+    if (windowingStrategy.getTrigger() instanceof AfterWatermarkEarlyAndLate
+        && ((AfterWatermarkEarlyAndLate) windowingStrategy.getTrigger()).getLateTrigger() != null) {
+      return true;
+    }
+
+    return false;
   }
 
   public WindowingStrategy<?, ?> updateWindowingStrategy(WindowingStrategy<?, ?> inputStrategy) {
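
Note: a sketch of the new validation's effect, assuming a
PCollection<KV<String, Long>> named `input` (window size illustrative):

    // Rejected: a bare once-trigger finishes after its first firing, so later
    // elements in the window would be silently dropped.
    input
        .apply(
            Window.<KV<String, Long>>into(FixedWindows.of(Duration.standardMinutes(5)))
                .triggering(AfterPane.elementCountAtLeast(1)) // mayFinish() == true
                .withAllowedLateness(Duration.ZERO)
                .discardingFiredPanes())
        .apply(GroupByKey.create()); // throws IllegalArgumentException: "Unsafe trigger ..."

    // Accepted: Repeatedly.forever(...) never finishes.
    input
        .apply(
            Window.<KV<String, Long>>into(FixedWindows.of(Duration.standardMinutes(5)))
                .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1)))
                .withAllowedLateness(Duration.ZERO)
                .discardingFiredPanes())
        .apply(GroupByKey.create());
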
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/AfterEach.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/AfterEach.java
index 2ce2fdf..eb15888 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/AfterEach.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/AfterEach.java
@@ -68,6 +68,11 @@
   }
 
   @Override
+  public boolean mayFinish() {
+    return subTriggers.stream().allMatch(trigger -> trigger.mayFinish());
+  }
+
+  @Override
   protected Trigger getContinuationTrigger(List<Trigger> continuationTriggers) {
     return Repeatedly.forever(new AfterFirst(continuationTriggers));
   }
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/AfterWatermark.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/AfterWatermark.java
index 2be41de..339b13b 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/AfterWatermark.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/AfterWatermark.java
@@ -119,6 +119,12 @@
       return window.maxTimestamp();
     }
 
+    /** @return true if there is no late firing set up, otherwise false */
+    @Override
+    public boolean mayFinish() {
+      return lateTrigger == null;
+    }
+
     @Override
     public String toString() {
       StringBuilder builder = new StringBuilder(TO_STRING);
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/DefaultTrigger.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/DefaultTrigger.java
index 39d5d13..e2aff9f 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/DefaultTrigger.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/DefaultTrigger.java
@@ -44,6 +44,12 @@
     return window.maxTimestamp();
   }
 
+  /** @return false; the default trigger never finishes */
+  @Override
+  public boolean mayFinish() {
+    return false;
+  }
+
   @Override
   public boolean isCompatible(Trigger other) {
     // Semantically, all default triggers are identical
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/OrFinallyTrigger.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/OrFinallyTrigger.java
index a8f6659..d2ea3f1 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/OrFinallyTrigger.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/OrFinallyTrigger.java
@@ -60,6 +60,11 @@
   }
 
   @Override
+  public boolean mayFinish() {
+    return subTriggers.get(ACTUAL).mayFinish() || subTriggers.get(UNTIL).mayFinish();
+  }
+
+  @Override
   protected Trigger getContinuationTrigger(List<Trigger> continuationTriggers) {
     // Use OrFinallyTrigger instead of AfterFirst because the continuation of ACTUAL
     // may not be a OnceTrigger.
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Repeatedly.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Repeatedly.java
index be4dd53..9c54d75 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Repeatedly.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Repeatedly.java
@@ -69,6 +69,11 @@
   }
 
   @Override
+  public boolean mayFinish() {
+    return false;
+  }
+
+  @Override
   protected Trigger getContinuationTrigger(List<Trigger> continuationTriggers) {
     return new Repeatedly(continuationTriggers.get(REPEATED));
   }
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/ReshuffleTrigger.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/ReshuffleTrigger.java
index ceb7011..63103e6 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/ReshuffleTrigger.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/ReshuffleTrigger.java
@@ -52,6 +52,11 @@
   }
 
   @Override
+  public boolean mayFinish() {
+    return false;
+  }
+
+  @Override
   public String toString() {
     return "ReshuffleTrigger()";
   }
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Trigger.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Trigger.java
index ffddebd..639ba6c 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Trigger.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Trigger.java
@@ -137,6 +137,15 @@
   /**
    * <b><i>For internal use only; no backwards-compatibility guarantees.</i></b>
    *
+   * <p>Indicates whether this trigger may "finish". A top-level trigger that finishes can cause
+   * data loss, so it is rejected by GroupByKey validation.
+   */
+  @Internal
+  public abstract boolean mayFinish();
+
+  /**
+   * <b><i>For internal use only; no backwards-compatibility guarantees.</i></b>
+   *
    * <p>Returns whether this performs the same triggering as the given {@link Trigger}.
    */
   @Internal
@@ -230,6 +239,11 @@
     }
 
     @Override
+    public final boolean mayFinish() {
+      return true;
+    }
+
+    @Override
     public final OnceTrigger getContinuationTrigger() {
       Trigger continuation = super.getContinuationTrigger();
       if (!(continuation instanceof OnceTrigger)) {
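
Note: how mayFinish() composes across a trigger tree, per the overrides added
in this change:

    // Repeatedly.forever(t).mayFinish()    -> false (never finishes)
    // AfterEach.inOrder(a, b).mayFinish()  -> a.mayFinish() && b.mayFinish()
    // t.orFinally(until).mayFinish()       -> t.mayFinish() || until.mayFinish()
    // any OnceTrigger.mayFinish()          -> true (fires once, then finishes)
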
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/package-info.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/package-info.java
index b4772f3..8a73dd0 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/package-info.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/package-info.java
@@ -16,5 +16,9 @@
  * limitations under the License.
  */
 
-/** Defines utilities that can be used by Beam runners. */
+/**
+ * <b>For internal use only; no backwards compatibility guarantees.</b>
+ *
+ * <p>Defines utilities that can be used by Beam runners.
+ */
 package org.apache.beam.sdk.util;
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOWriteTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOWriteTest.java
index e358ff9..2bce31b 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOWriteTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOWriteTest.java
@@ -639,6 +639,10 @@
   @Test
   @Category(NeedsRunner.class)
   public void testWindowedWritesWithOnceTrigger() throws Throwable {
+    p.enableAbandonedNodeEnforcement(false);
+    expectedException.expect(IllegalArgumentException.class);
+    expectedException.expectMessage("Unsafe trigger");
+
     // Tests for https://issues.apache.org/jira/browse/BEAM-3169
     PCollection<String> data =
         p.apply(Create.of("0", "1", "2"))
@@ -660,17 +664,6 @@
                     .<Void>withOutputFilenames())
             .getPerDestinationOutputFilenames()
             .apply(Values.create());
-
-    PAssert.that(
-            filenames
-                .apply(FileIO.matchAll())
-                .apply(FileIO.readMatches())
-                .apply(TextIO.readFiles()))
-        .containsInAnyOrder("0", "1", "2");
-
-    PAssert.that(filenames.apply(TextIO.readAll())).containsInAnyOrder("0", "1", "2");
-
-    p.run();
   }
 
   @Test
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextSourceTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextSourceTest.java
new file mode 100644
index 0000000..36a3f68
--- /dev/null
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextSourceTest.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io;
+
+import java.io.BufferedWriter;
+import java.io.IOException;
+import java.nio.charset.Charset;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.io.FileIO.ReadableFile;
+import org.apache.beam.sdk.io.fs.EmptyMatchTreatment;
+import org.apache.beam.sdk.options.ValueProvider;
+import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.values.PCollection;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link TextSource}. */
+@RunWith(JUnit4.class)
+public class TextSourceTest {
+  @Rule public transient TestPipeline pipeline = TestPipeline.create();
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testRemoveUtf8BOM() throws Exception {
+    Path p1 = createTestFile("test_txt_ascii", Charset.forName("US-ASCII"), "1,p1", "2,p1");
+    Path p2 =
+        createTestFile(
+            "test_txt_utf8_no_bom",
+            Charset.forName("UTF-8"),
+            "1,p2-Japanese:テスト",
+            "2,p2-Japanese:テスト");
+    Path p3 =
+        createTestFile(
+            "test_txt_utf8_bom",
+            Charset.forName("UTF-8"),
+            "\uFEFF1,p3-テストBOM",
+            "\uFEFF2,p3-テストBOM");
+    PCollection<String> contents =
+        pipeline
+            .apply("Create", Create.of(p1.toString(), p2.toString(), p3.toString()))
+            .setCoder(StringUtf8Coder.of())
+            // PCollection<String>
+            .apply("Read file", new TextFileReadTransform());
+    // PCollection<KV<String, String>>: tableName, line
+
+    // Validate that the BOM bytes (\uFEFF) are stripped from the first line of the file only;
+    // a BOM-like sequence on a subsequent line is preserved as ordinary content.
+    PAssert.that(contents)
+        .containsInAnyOrder(
+            "1,p1",
+            "2,p1",
+            "1,p2-Japanese:テスト",
+            "2,p2-Japanese:テスト",
+            "1,p3-テストBOM",
+            "\uFEFF2,p3-テストBOM");
+
+    pipeline.run();
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testPreserveNonBOMBytes() throws Exception {
+    // Contains \uFEFE, which is not the UTF-8 BOM (\uFEFF), so it must be preserved.
+    Path p1 =
+        createTestFile(
+            "test_txt_utf_bom", Charset.forName("UTF-8"), "\uFEFE1,p1テスト", "\uFEFE2,p1テスト");
+    PCollection<String> contents =
+        pipeline
+            .apply("Create", Create.of(p1.toString()))
+            .setCoder(StringUtf8Coder.of())
+            // PCollection<String>
+            .apply("Read file", new TextFileReadTransform());
+
+    PAssert.that(contents).containsInAnyOrder("\uFEFE1,p1テスト", "\uFEFE2,p1テスト");
+
+    pipeline.run();
+  }
+
+  private static class FileReadDoFn extends DoFn<ReadableFile, String> {
+
+    @ProcessElement
+    public void processElement(ProcessContext c) {
+      ReadableFile file = c.element();
+      ValueProvider<String> filenameProvider =
+          ValueProvider.StaticValueProvider.of(file.getMetadata().resourceId().getFilename());
+      // Create a TextSource, passing null as the delimiter to use the default
+      // delimiters ('\n', '\r', or '\r\n').
+      TextSource textSource = new TextSource(filenameProvider, null, null);
+      try {
+        BoundedSource.BoundedReader<String> reader =
+            textSource
+                .createForSubrangeOfFile(file.getMetadata(), 0, file.getMetadata().sizeBytes())
+                .createReader(c.getPipelineOptions());
+        for (boolean more = reader.start(); more; more = reader.advance()) {
+          c.output(reader.getCurrent());
+        }
+      } catch (IOException e) {
+        throw new RuntimeException(
+            "Unable to read file: " + file.getMetadata().resourceId().toString(), e);
+      }
+    }
+  }
+
+  /** A transform that reads whole lines from text files (CSV records in these tests). */
+  private static class TextFileReadTransform
+      extends PTransform<PCollection<String>, PCollection<String>> {
+    public TextFileReadTransform() {}
+
+    @Override
+    public PCollection<String> expand(PCollection<String> files) {
+      return files
+          // PCollection<String>
+          .apply(FileIO.matchAll().withEmptyMatchTreatment(EmptyMatchTreatment.DISALLOW))
+          // PCollection<Match.Metadata>
+          .apply(FileIO.readMatches())
+          // PCollection<FileIO.ReadableFile>
+          .apply("Read lines", ParDo.of(new FileReadDoFn()));
+      // PCollection<String>: line
+    }
+  }
+
+  private Path createTestFile(String filename, Charset charset, String... lines)
+      throws IOException {
+    Path path = Files.createTempFile(filename, ".csv");
+    try (BufferedWriter writer = Files.newBufferedWriter(path, charset)) {
+      for (String line : lines) {
+        writer.write(line);
+        writer.write('\n');
+      }
+    }
+    return path;
+  }
+}
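
Aside (not part of this patch): the BOM handling exercised by the test above is also reachable through the plain TextIO.read() entry point, which reads files via TextSource. A minimal sketch, assuming a hypothetical local file path and a runner on the classpath:

    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.io.TextIO;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;
    import org.apache.beam.sdk.values.PCollection;

    public class BomStrippingSketch {
      public static void main(String[] args) {
        Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());
        // A UTF-8 BOM at the very start of the file is stripped from the first
        // record; a \uFEFF on any later line is returned as ordinary content.
        PCollection<String> lines = p.apply(TextIO.read().from("/tmp/input.csv"));
        p.run().waitUntilFinish();
      }
    }
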
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java
index 5e4cdcb..e48b6b2 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.sdk.testing;
 
+import static org.apache.beam.sdk.transforms.windowing.Window.into;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.allOf;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -28,12 +29,22 @@
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.state.TimeDomain;
+import org.apache.beam.sdk.state.Timer;
+import org.apache.beam.sdk.state.TimerSpec;
+import org.apache.beam.sdk.state.TimerSpecs;
+import org.apache.beam.sdk.state.ValueState;
 import org.apache.beam.sdk.testing.TestStream.Builder;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.Keys;
 import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.transforms.Values;
 import org.apache.beam.sdk.transforms.WithKeys;
@@ -44,6 +55,7 @@
 import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
 import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
 import org.apache.beam.sdk.transforms.windowing.Never;
 import org.apache.beam.sdk.transforms.windowing.Window;
@@ -51,8 +63,10 @@
 import org.apache.beam.sdk.util.CoderUtils;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionList;
 import org.apache.beam.sdk.values.TimestampedValue;
 import org.apache.beam.sdk.values.TypeDescriptors;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
 import org.junit.Rule;
@@ -263,7 +277,7 @@
     FixedWindows windows = FixedWindows.of(Duration.standardHours(6));
     PCollection<String> windowedValues =
         p.apply(stream)
-            .apply(Window.into(windows))
+            .apply(into(windows))
             .apply(WithKeys.of(1))
             .apply(GroupByKey.create())
             .apply(Values.create())
@@ -387,6 +401,74 @@
   }
 
   @Test
+  @Category({ValidatesRunner.class, UsesTestStream.class, UsesTestStreamWithMultipleStages.class})
+  public void testMultiStage() throws Exception {
+    TestStream<String> testStream =
+        TestStream.create(StringUtf8Coder.of())
+            .addElements("before") // before
+            .advanceWatermarkTo(Instant.ofEpochSecond(0)) // BEFORE
+            .addElements(TimestampedValue.of("after", Instant.ofEpochSecond(10))) // after
+            .advanceWatermarkToInfinity(); // AFTER
+
+    PCollection<String> input = p.apply(testStream);
+
+    PCollection<String> grouped =
+        input
+            .apply(Window.into(FixedWindows.of(Duration.standardSeconds(1))))
+            .apply(
+                MapElements.into(
+                        TypeDescriptors.kvs(TypeDescriptors.strings(), TypeDescriptors.strings()))
+                    .via(e -> KV.of(e, e)))
+            .apply(GroupByKey.create())
+            .apply(Keys.create())
+            .apply("Upper", MapElements.into(TypeDescriptors.strings()).via(String::toUpperCase))
+            .apply("Rewindow", Window.into(new GlobalWindows()));
+
+    PCollection<String> result =
+        PCollectionList.of(ImmutableList.of(input, grouped))
+            .apply(Flatten.pCollections())
+            .apply(
+                "Key",
+                MapElements.into(
+                        TypeDescriptors.kvs(TypeDescriptors.strings(), TypeDescriptors.strings()))
+                    .via(e -> KV.of("key", e)))
+            .apply(
+                ParDo.of(
+                    new DoFn<KV<String, String>, String>() {
+                      @StateId("seen")
+                      private final StateSpec<ValueState<String>> seenSpec =
+                          StateSpecs.value(StringUtf8Coder.of());
+
+                      @TimerId("emit")
+                      private final TimerSpec emitSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME);
+
+                      @ProcessElement
+                      public void process(
+                          ProcessContext context,
+                          @StateId("seen") ValueState<String> seenState,
+                          @TimerId("emit") Timer emitTimer) {
+                        String element = context.element().getValue();
+                        if (seenState.read() == null) {
+                          seenState.write(element);
+                        } else {
+                          seenState.write(seenState.read() + "," + element);
+                        }
+                        emitTimer.set(Instant.ofEpochSecond(100));
+                      }
+
+                      @OnTimer("emit")
+                      public void onEmit(
+                          OnTimerContext context, @StateId("seen") ValueState<String> seenState) {
+                        context.output(seenState.read());
+                      }
+                    }));
+
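+    // Each element should appear both raw and, via the grouped branch, in
+    // upper case; the event-time timer set for t=100s fires once the
+    // watermark reaches infinity, emitting the accumulated state exactly once.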
+    PAssert.that(result).containsInAnyOrder("before,BEFORE,after,AFTER");
+
+    p.run().waitUntilFinish();
+  }
+
+  @Test
   @Category(UsesTestStreamWithProcessingTime.class)
   public void testCoder() throws Exception {
     TestStream<String> testStream =
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java
index 6d8408e..47af8a9 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java
@@ -56,7 +56,9 @@
 import org.apache.beam.sdk.testing.UsesTestStreamWithProcessingTime;
 import org.apache.beam.sdk.testing.ValidatesRunner;
 import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.transforms.windowing.AfterPane;
 import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime;
+import org.apache.beam.sdk.transforms.windowing.AfterWatermark;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
@@ -79,12 +81,14 @@
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.experimental.categories.Category;
+import org.junit.experimental.runners.Enclosed;
 import org.junit.rules.ExpectedException;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
 /** Tests for GroupByKey. */
 @SuppressWarnings({"rawtypes", "unchecked"})
+@RunWith(Enclosed.class)
 public class GroupByKeyTest implements Serializable {
   /** Shared test base class with setup/teardown helpers. */
   public abstract static class SharedTestBase {
@@ -196,6 +200,104 @@
       input.apply(GroupByKey.create());
     }
 
+    // AfterPane.elementCountAtLeast(1) is not OK: it finishes after a single
+    // firing, which would silently drop any later data for the window.
+    @Test
+    public void testGroupByKeyFinishingTriggerRejected() {
+      PCollection<KV<String, String>> input =
+          p.apply(Create.of(KV.of("hello", "goodbye")))
+              .apply(
+                  Window.<KV<String, String>>configure()
+                      .discardingFiredPanes()
+                      .triggering(AfterPane.elementCountAtLeast(1)));
+
+      thrown.expect(IllegalArgumentException.class);
+      thrown.expectMessage("Unsafe trigger");
+      input.apply(GroupByKey.create());
+    }
+
+    // AfterWatermark.pastEndOfWindow() is OK with 0 allowed lateness
+    @Test
+    public void testGroupByKeyFinishingEndOfWindowTriggerOk() {
+      PCollection<KV<String, String>> input =
+          p.apply(Create.of(KV.of("hello", "goodbye")))
+              .apply(
+                  Window.<KV<String, String>>configure()
+                      .discardingFiredPanes()
+                      .triggering(AfterWatermark.pastEndOfWindow())
+                      .withAllowedLateness(Duration.ZERO));
+
+      // OK
+      input.apply(GroupByKey.create());
+    }
+
+    // AfterWatermark.pastEndOfWindow().withEarlyFirings() is OK with 0 allowed lateness
+    @Test
+    public void testGroupByKeyFinishingEndOfWindowEarlyFiringsTriggerOk() {
+      PCollection<KV<String, String>> input =
+          p.apply(Create.of(KV.of("hello", "goodbye")))
+              .apply(
+                  Window.<KV<String, String>>configure()
+                      .discardingFiredPanes()
+                      .triggering(
+                          AfterWatermark.pastEndOfWindow()
+                              .withEarlyFirings(AfterPane.elementCountAtLeast(1)))
+                      .withAllowedLateness(Duration.ZERO));
+
+      // OK
+      input.apply(GroupByKey.create());
+    }
+
+    // AfterWatermark.pastEndOfWindow() is not OK with > 0 allowed lateness
+    @Test
+    public void testGroupByKeyFinishingEndOfWindowTriggerNotOk() {
+      PCollection<KV<String, String>> input =
+          p.apply(Create.of(KV.of("hello", "goodbye")))
+              .apply(
+                  Window.<KV<String, String>>configure()
+                      .discardingFiredPanes()
+                      .triggering(AfterWatermark.pastEndOfWindow())
+                      .withAllowedLateness(Duration.millis(10)));
+
+      thrown.expect(IllegalArgumentException.class);
+      thrown.expectMessage("Unsafe trigger");
+      input.apply(GroupByKey.create());
+    }
+
+    // AfterWatermark.pastEndOfWindow().withEarlyFirings() is not OK with > 0 allowed lateness
+    @Test
+    public void testGroupByKeyFinishingEndOfWindowEarlyFiringsTriggerNotOk() {
+      PCollection<KV<String, String>> input =
+          p.apply(Create.of(KV.of("hello", "goodbye")))
+              .apply(
+                  Window.<KV<String, String>>configure()
+                      .discardingFiredPanes()
+                      .triggering(
+                          AfterWatermark.pastEndOfWindow()
+                              .withEarlyFirings(AfterPane.elementCountAtLeast(1)))
+                      .withAllowedLateness(Duration.millis(10)));
+
+      thrown.expect(IllegalArgumentException.class);
+      thrown.expectMessage("Unsafe trigger");
+      input.apply(GroupByKey.create());
+    }
+
+    // AfterWatermark.pastEndOfWindow().withLateFirings() is always OK
+    @Test
+    public void testGroupByKeyEndOfWindowLateFiringsOk() {
+      PCollection<KV<String, String>> input =
+          p.apply(Create.of(KV.of("hello", "goodbye")))
+              .apply(
+                  Window.<KV<String, String>>configure()
+                      .discardingFiredPanes()
+                      .triggering(
+                          AfterWatermark.pastEndOfWindow()
+                              .withLateFirings(AfterPane.elementCountAtLeast(1)))
+                      .withAllowedLateness(Duration.millis(10)));
+
+      // OK
+      input.apply(GroupByKey.create());
+    }
+
     @Test
     @Category(NeedsRunner.class)
     public void testRemerge() {
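
Side note on the trigger-safety checks above (a sketch, not part of this patch): the standard way to make an otherwise-finishing trigger acceptable to GroupByKey is to wrap it in Repeatedly.forever, which can never finish. Under the same test harness (field `p` is the shared TestPipeline, with org.apache.beam.sdk.transforms.windowing.Repeatedly imported):

    @Test
    public void testGroupByKeyRepeatedlyWrappedTriggerOk() {
      PCollection<KV<String, String>> input =
          p.apply(Create.of(KV.of("hello", "goodbye")))
              .apply(
                  Window.<KV<String, String>>configure()
                      .discardingFiredPanes()
                      .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1)))
                      .withAllowedLateness(Duration.millis(10)));

      // OK: Repeatedly.forever(...) never finishes, so the trigger is safe.
      input.apply(GroupByKey.create());
    }
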
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/TriggerTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/TriggerTest.java
index 36037c0..335d967 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/TriggerTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/TriggerTest.java
@@ -68,6 +68,11 @@
     public Instant getWatermarkThatGuaranteesFiring(BoundedWindow window) {
       return null;
     }
+
+    @Override
+    public boolean mayFinish() {
+      return false;
+    }
   }
 
   private static class Trigger2 extends Trigger {
@@ -85,5 +90,10 @@
     public Instant getWatermarkThatGuaranteesFiring(BoundedWindow window) {
       return null;
     }
+
+    @Override
+    public boolean mayFinish() {
+      return false;
+    }
   }
 }
diff --git a/sdks/java/extensions/euphoria/build.gradle b/sdks/java/extensions/euphoria/build.gradle
index 93f6ebd..c79915b 100644
--- a/sdks/java/extensions/euphoria/build.gradle
+++ b/sdks/java/extensions/euphoria/build.gradle
@@ -30,7 +30,7 @@
   testCompile library.java.hamcrest_library
   testCompile library.java.mockito_core
   testCompile project(path: ":sdks:java:core", configuration: "shadowTest")
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
 
 test {
diff --git a/sdks/java/extensions/jackson/build.gradle b/sdks/java/extensions/jackson/build.gradle
index 3d1e692..b36343c 100644
--- a/sdks/java/extensions/jackson/build.gradle
+++ b/sdks/java/extensions/jackson/build.gradle
@@ -32,5 +32,5 @@
   testCompile library.java.hamcrest_core
   testCompile library.java.hamcrest_library
   testCompile library.java.junit
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/extensions/join-library/build.gradle b/sdks/java/extensions/join-library/build.gradle
index c3c79e9..1257f8f 100644
--- a/sdks/java/extensions/join-library/build.gradle
+++ b/sdks/java/extensions/join-library/build.gradle
@@ -27,5 +27,5 @@
   testCompile library.java.hamcrest_core
   testCompile library.java.hamcrest_library
   testCompile library.java.junit
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/extensions/kryo/build.gradle b/sdks/java/extensions/kryo/build.gradle
index 24c8e0c..dc8c62a 100644
--- a/sdks/java/extensions/kryo/build.gradle
+++ b/sdks/java/extensions/kryo/build.gradle
@@ -43,7 +43,7 @@
     compile "com.esotericsoftware:kryo:${kryoVersion}"
     shadow project(path: ':sdks:java:core', configuration: 'shadow')
     testCompile project(path: ':sdks:java:core', configuration: 'shadowTest')
-    testRuntimeOnly project(':runners:direct-java')
+    testRuntimeOnly project(path: ':runners:direct-java', configuration: 'shadow')
 }
 
 test {
diff --git a/sdks/java/extensions/sketching/build.gradle b/sdks/java/extensions/sketching/build.gradle
index 1f403ee..d923501 100644
--- a/sdks/java/extensions/sketching/build.gradle
+++ b/sdks/java/extensions/sketching/build.gradle
@@ -36,5 +36,5 @@
   testCompile library.java.hamcrest_core
   testCompile library.java.hamcrest_library
   testCompile library.java.junit
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/extensions/sorter/build.gradle b/sdks/java/extensions/sorter/build.gradle
index 94cbba5..4994003 100644
--- a/sdks/java/extensions/sorter/build.gradle
+++ b/sdks/java/extensions/sorter/build.gradle
@@ -30,5 +30,5 @@
   testCompile library.java.hamcrest_library
   testCompile library.java.mockito_core
   testCompile library.java.junit
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/extensions/sql/build.gradle b/sdks/java/extensions/sql/build.gradle
index 78d7383..b75c339 100644
--- a/sdks/java/extensions/sql/build.gradle
+++ b/sdks/java/extensions/sql/build.gradle
@@ -45,7 +45,7 @@
   fmppTemplates library.java.vendored_calcite_1_20_0
   compile project(":sdks:java:core")
   compile project(":sdks:java:extensions:join-library")
-  compile project(":runners:direct-java")
+  compile project(path: ":runners:direct-java", configuration: "shadow")
   compile library.java.commons_csv
   compile library.java.vendored_calcite_1_20_0
   compile "com.alibaba:fastjson:1.2.49"
@@ -53,8 +53,10 @@
   compile "org.codehaus.janino:commons-compiler:3.0.11"
   provided project(":sdks:java:io:kafka")
   provided project(":sdks:java:io:google-cloud-platform")
+  compile project(":sdks:java:io:mongodb")
   provided project(":sdks:java:io:parquet")
   provided library.java.kafka_clients
+  runtimeOnly library.java.hadoop_client
   testCompile library.java.vendored_calcite_1_20_0
   testCompile library.java.vendored_guava_26_0_jre
   testCompile library.java.junit
@@ -62,6 +64,7 @@
   testCompile library.java.hamcrest_library
   testCompile library.java.mockito_core
   testCompile library.java.quickcheck_core
+  testCompile project(path: ":sdks:java:io:mongodb", configuration: "testRuntime")
   testRuntimeClasspath library.java.slf4j_jdk14
 }
 
@@ -156,6 +159,7 @@
 
   include '**/*IT.class'
   exclude '**/KafkaCSVTableIT.java'
+  exclude '**/MongoDbReadWriteIT.java'
   maxParallelForks 4
   classpath = project(":sdks:java:extensions:sql")
           .sourceSets
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTable.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTable.java
new file mode 100644
index 0000000..eaa9661
--- /dev/null
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTable.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.meta.provider.mongodb;
+
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
+
+import java.io.Serializable;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.extensions.sql.impl.BeamTableStatistics;
+import org.apache.beam.sdk.extensions.sql.meta.SchemaBaseBeamTable;
+import org.apache.beam.sdk.extensions.sql.meta.Table;
+import org.apache.beam.sdk.io.mongodb.MongoDbIO;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.JsonToRow;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.values.PBegin;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollection.IsBounded;
+import org.apache.beam.sdk.values.POutput;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.annotations.VisibleForTesting;
+import org.bson.Document;
+import org.bson.json.JsonMode;
+import org.bson.json.JsonWriterSettings;
+
+@Experimental
+public class MongoDbTable extends SchemaBaseBeamTable implements Serializable {
+  // Should match: mongodb://username:password@localhost:27017/database/collection
+  @VisibleForTesting
+  final Pattern locationPattern =
+      Pattern.compile(
+          "(?<credsHostPort>mongodb://(?<usernamePassword>.*(?<password>:.*)?@)?.+:\\d+)/(?<database>.+)/(?<collection>.+)");
+
+  @VisibleForTesting final String dbCollection;
+  @VisibleForTesting final String dbName;
+  @VisibleForTesting final String dbUri;
+
+  MongoDbTable(Table table) {
+    super(table.getSchema());
+
+    String location = table.getLocation();
+    Matcher matcher = locationPattern.matcher(location);
+    checkArgument(
+        matcher.matches(),
+        "MongoDb location must be in the following format: 'mongodb://(username:password@)?localhost:27017/database/collection'");
+    this.dbUri = matcher.group("credsHostPort"); // "mongodb://localhost:27017"
+    this.dbName = matcher.group("database");
+    this.dbCollection = matcher.group("collection");
+  }
+
+  @Override
+  public PCollection<Row> buildIOReader(PBegin begin) {
+    // Read MongoDb Documents
+    PCollection<Document> readDocuments =
+        MongoDbIO.read()
+            .withUri(dbUri)
+            .withDatabase(dbName)
+            .withCollection(dbCollection)
+            .expand(begin);
+
+    return readDocuments.apply(DocumentToRow.withSchema(getSchema()));
+  }
+
+  @Override
+  public POutput buildIOWriter(PCollection<Row> input) {
+    throw new UnsupportedOperationException("Writing to MongoDB is not supported");
+  }
+
+  @Override
+  public IsBounded isBounded() {
+    return IsBounded.BOUNDED;
+  }
+
+  @Override
+  public BeamTableStatistics getTableStatistics(PipelineOptions options) {
+    long count =
+        MongoDbIO.read()
+            .withUri(dbUri)
+            .withDatabase(dbName)
+            .withCollection(dbCollection)
+            .getDocumentCount();
+
+    if (count < 0) {
+      return BeamTableStatistics.BOUNDED_UNKNOWN;
+    }
+
+    return BeamTableStatistics.createBoundedTableStatistics((double) count);
+  }
+
+  public static class DocumentToRow extends PTransform<PCollection<Document>, PCollection<Row>> {
+    private final Schema schema;
+
+    private DocumentToRow(Schema schema) {
+      this.schema = schema;
+    }
+
+    public static DocumentToRow withSchema(Schema schema) {
+      return new DocumentToRow(schema);
+    }
+
+    @Override
+    public PCollection<Row> expand(PCollection<Document> input) {
+      // TODO(BEAM-8498): figure out a way to convert Document directly to Row.
+      return input
+          .apply("Convert Document to JSON", ParDo.of(new DocumentToJsonStringConverter()))
+          .apply("Transform JSON to Row", JsonToRow.withSchema(schema))
+          .setRowSchema(schema);
+    }
+
+    // TODO: add support for complex fields (May require modifying how Calcite parses nested
+    // fields).
+    @VisibleForTesting
+    static class DocumentToJsonStringConverter extends DoFn<Document, String> {
+      @DoFn.ProcessElement
+      public void processElement(ProcessContext context) {
+        context.output(
+            context
+                .element()
+                .toJson(JsonWriterSettings.builder().outputMode(JsonMode.RELAXED).build()));
+      }
+    }
+  }
+}
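
For reference, a standalone sketch (assumed example values) of how the location pattern above decomposes a table LOCATION; the group names match those read in the MongoDbTable constructor:

    import java.util.regex.Matcher;
    import java.util.regex.Pattern;

    public class LocationPatternDemo {
      public static void main(String[] args) {
        Pattern p =
            Pattern.compile(
                "(?<credsHostPort>mongodb://(?<usernamePassword>.*(?<password>:.*)?@)?.+:\\d+)/(?<database>.+)/(?<collection>.+)");
        Matcher m = p.matcher("mongodb://user:pass@localhost:27017/mydb/mycoll");
        if (m.matches()) {
          System.out.println(m.group("credsHostPort")); // mongodb://user:pass@localhost:27017
          System.out.println(m.group("database"));      // mydb
          System.out.println(m.group("collection"));    // mycoll
        }
      }
    }
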
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTableProvider.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTableProvider.java
new file mode 100644
index 0000000..ead09f0
--- /dev/null
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTableProvider.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.meta.provider.mongodb;
+
+import com.google.auto.service.AutoService;
+import org.apache.beam.sdk.extensions.sql.meta.BeamSqlTable;
+import org.apache.beam.sdk.extensions.sql.meta.Table;
+import org.apache.beam.sdk.extensions.sql.meta.provider.InMemoryMetaTableProvider;
+import org.apache.beam.sdk.extensions.sql.meta.provider.TableProvider;
+
+/**
+ * {@link TableProvider} for {@link MongoDbTable}.
+ *
+ * <p>An example MongoDb table declaration:
+ *
+ * <pre>{@code
+ * CREATE TABLE ORDERS(
+ *   name VARCHAR,
+ *   favorite_color VARCHAR,
+ *   favorite_numbers ARRAY<INTEGER>
+ * )
+ * TYPE 'mongodb'
+ * LOCATION 'mongodb://username:password@localhost:27017/database/collection'
+ * }</pre>
+ */
+@AutoService(TableProvider.class)
+public class MongoDbTableProvider extends InMemoryMetaTableProvider {
+
+  @Override
+  public String getTableType() {
+    return "mongodb";
+  }
+
+  @Override
+  public BeamSqlTable buildBeamSqlTable(Table table) {
+    return new MongoDbTable(table);
+  }
+}
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/package-info.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/package-info.java
new file mode 100644
index 0000000..51c9a74
--- /dev/null
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/package-info.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Table schema for MongoDb. */
+@DefaultAnnotation(NonNull.class)
+package org.apache.beam.sdk.extensions.sql.meta.provider.mongodb;
+
+import edu.umd.cs.findbugs.annotations.DefaultAnnotation;
+import edu.umd.cs.findbugs.annotations.NonNull;
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbReadWriteIT.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbReadWriteIT.java
new file mode 100644
index 0000000..9c4d4cd
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbReadWriteIT.java
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.meta.provider.mongodb;
+
+import static org.apache.beam.sdk.schemas.Schema.FieldType.BOOLEAN;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.BYTE;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.DOUBLE;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.FLOAT;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.INT16;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.INT32;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.INT64;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.STRING;
+import static org.junit.Assert.assertEquals;
+
+import com.mongodb.MongoClient;
+import java.util.Arrays;
+import org.apache.beam.sdk.PipelineResult;
+import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv;
+import org.apache.beam.sdk.extensions.sql.impl.rel.BeamSqlRelUtils;
+import org.apache.beam.sdk.io.mongodb.MongoDBIOIT.MongoDBPipelineOptions;
+import org.apache.beam.sdk.io.mongodb.MongoDbIO;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.SimpleFunction;
+import org.apache.beam.sdk.transforms.ToJson;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.Row;
+import org.bson.Document;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Ignore;
+import org.junit.Rule;
+import org.junit.Test;
+
+/**
+ * A test of {@link org.apache.beam.sdk.extensions.sql.meta.provider.mongodb.MongoDbTable} on an
+ * independent Mongo instance.
+ *
+ * <p>This test requires a running instance of MongoDB. Pass in connection information using
+ * PipelineOptions:
+ *
+ * <pre>
+ *  ./gradlew integrationTest -p sdks/java/extensions/sql/integrationTest -DintegrationTestPipelineOptions='[
+ *  "--mongoDBHostName=1.2.3.4",
+ *  "--mongoDBPort=27017",
+ *  "--mongoDBDatabaseName=mypass",
+ *  "--numberOfRecords=1000" ]'
+ *  --tests org.apache.beam.sdk.extensions.sql.meta.provider.mongodb.MongoDbReadWriteIT
+ *  -DintegrationTestRunner=direct
+ * </pre>
+ *
+ * <p>A database, specified in the pipeline options, will be created implicitly if it does not
+ * already exist, and will be dropped once the tests complete.
+ *
+ * <p>Please see the 'build_rules.gradle' file for instructions on running this test with the
+ * Beam performance-testing framework.
+ */
+@Ignore("https://issues.apache.org/jira/browse/BEAM-8586")
+public class MongoDbReadWriteIT {
+  private static final Schema SOURCE_SCHEMA =
+      Schema.builder()
+          .addNullableField("_id", STRING)
+          .addNullableField("c_bigint", INT64)
+          .addNullableField("c_tinyint", BYTE)
+          .addNullableField("c_smallint", INT16)
+          .addNullableField("c_integer", INT32)
+          .addNullableField("c_float", FLOAT)
+          .addNullableField("c_double", DOUBLE)
+          .addNullableField("c_boolean", BOOLEAN)
+          .addNullableField("c_varchar", STRING)
+          .addNullableField("c_arr", FieldType.array(STRING))
+          .build();
+  private static final String collection = "collection";
+  private static MongoDBPipelineOptions options;
+
+  @Rule public final TestPipeline writePipeline = TestPipeline.create();
+  @Rule public final TestPipeline readPipeline = TestPipeline.create();
+
+  @BeforeClass
+  public static void setUp() throws Exception {
+    PipelineOptionsFactory.register(MongoDBPipelineOptions.class);
+    options = TestPipeline.testingPipelineOptions().as(MongoDBPipelineOptions.class);
+  }
+
+  @AfterClass
+  public static void tearDown() throws Exception {
+    dropDatabase();
+  }
+
+  private static void dropDatabase() throws Exception {
+    new MongoClient(options.getMongoDBHostName())
+        .getDatabase(options.getMongoDBDatabaseName())
+        .drop();
+  }
+
+  @Test
+  public void testWriteAndRead() {
+    final String mongoUrl =
+        String.format("mongodb://%s:%d", options.getMongoDBHostName(), options.getMongoDBPort());
+    final String mongoSqlUrl =
+        String.format(
+            "mongodb://%s:%d/%s/%s",
+            options.getMongoDBHostName(),
+            options.getMongoDBPort(),
+            options.getMongoDBDatabaseName(),
+            collection);
+
+    Row testRow =
+        row(
+            SOURCE_SCHEMA,
+            "object_id",
+            9223372036854775807L,
+            (byte) 127,
+            (short) 32767,
+            2147483647,
+            (float) 1.0,
+            1.0,
+            true,
+            "varchar",
+            Arrays.asList("123", "456"));
+
+    writePipeline
+        .apply(Create.of(testRow))
+        .setRowSchema(SOURCE_SCHEMA)
+        .apply("Transform Rows to JSON", ToJson.of())
+        .apply("Produce documents from JSON", MapElements.via(new ObjectToDocumentFn()))
+        .apply(
+            "Write documents to MongoDB",
+            MongoDbIO.write()
+                .withUri(mongoUrl)
+                .withDatabase(options.getMongoDBDatabaseName())
+                .withCollection(collection));
+    PipelineResult writeResult = writePipeline.run();
+    writeResult.waitUntilFinish();
+
+    String createTableStatement =
+        "CREATE EXTERNAL TABLE TEST( \n"
+            + "   _id VARCHAR, \n "
+            + "   c_bigint BIGINT, \n "
+            + "   c_tinyint TINYINT, \n"
+            + "   c_smallint SMALLINT, \n"
+            + "   c_integer INTEGER, \n"
+            + "   c_float FLOAT, \n"
+            + "   c_double DOUBLE, \n"
+            + "   c_boolean BOOLEAN, \n"
+            + "   c_varchar VARCHAR, \n "
+            + "   c_arr ARRAY<VARCHAR> \n"
+            + ") \n"
+            + "TYPE 'mongodb' \n"
+            + "LOCATION '"
+            + mongoSqlUrl
+            + "'";
+
+    BeamSqlEnv sqlEnv = BeamSqlEnv.inMemory(new MongoDbTableProvider());
+    sqlEnv.executeDdl(createTableStatement);
+
+    PCollection<Row> output =
+        BeamSqlRelUtils.toPCollection(readPipeline, sqlEnv.parseQuery("select * from TEST"));
+
+    assertEquals(SOURCE_SCHEMA, output.getSchema());
+
+    PAssert.that(output).containsInAnyOrder(testRow);
+
+    readPipeline.run().waitUntilFinish();
+  }
+
+  private static class ObjectToDocumentFn extends SimpleFunction<String, Document> {
+    @Override
+    public Document apply(String input) {
+      return Document.parse(input);
+    }
+  }
+
+  private Row row(Schema schema, Object... values) {
+    return Row.withSchema(schema).addValues(values).build();
+  }
+}
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTableProviderTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTableProviderTest.java
new file mode 100644
index 0000000..459af56
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTableProviderTest.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.meta.provider.mongodb;
+
+import static org.apache.beam.sdk.schemas.Schema.toSchema;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+import java.util.stream.Stream;
+import org.apache.beam.sdk.extensions.sql.meta.BeamSqlTable;
+import org.apache.beam.sdk.extensions.sql.meta.Table;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableList;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class MongoDbTableProviderTest {
+  private MongoDbTableProvider provider = new MongoDbTableProvider();
+
+  @Test
+  public void testGetTableType() {
+    assertEquals("mongodb", provider.getTableType());
+  }
+
+  @Test
+  public void testBuildBeamSqlTable() {
+    Table table = fakeTable("TEST", "mongodb://localhost:27017/database/collection");
+    BeamSqlTable sqlTable = provider.buildBeamSqlTable(table);
+
+    assertNotNull(sqlTable);
+    assertTrue(sqlTable instanceof MongoDbTable);
+
+    MongoDbTable mongoTable = (MongoDbTable) sqlTable;
+    assertEquals("mongodb://localhost:27017", mongoTable.dbUri);
+    assertEquals("database", mongoTable.dbName);
+    assertEquals("collection", mongoTable.dbCollection);
+  }
+
+  @Test
+  public void testBuildBeamSqlTable_withUsernameOnly() {
+    Table table = fakeTable("TEST", "mongodb://username@localhost:27017/database/collection");
+    BeamSqlTable sqlTable = provider.buildBeamSqlTable(table);
+
+    assertNotNull(sqlTable);
+    assertTrue(sqlTable instanceof MongoDbTable);
+
+    MongoDbTable mongoTable = (MongoDbTable) sqlTable;
+    assertEquals("mongodb://username@localhost:27017", mongoTable.dbUri);
+    assertEquals("database", mongoTable.dbName);
+    assertEquals("collection", mongoTable.dbCollection);
+  }
+
+  @Test
+  public void testBuildBeamSqlTable_withUsernameAndPassword() {
+    Table table =
+        fakeTable("TEST", "mongodb://username:pasword@localhost:27017/database/collection");
+    BeamSqlTable sqlTable = provider.buildBeamSqlTable(table);
+
+    assertNotNull(sqlTable);
+    assertTrue(sqlTable instanceof MongoDbTable);
+
+    MongoDbTable mongoTable = (MongoDbTable) sqlTable;
+    assertEquals("mongodb://username:pasword@localhost:27017", mongoTable.dbUri);
+    assertEquals("database", mongoTable.dbName);
+    assertEquals("collection", mongoTable.dbCollection);
+  }
+
+  @Test
+  public void testBuildBeamSqlTable_withBadLocation_throwsException() {
+    ImmutableList<String> badLocations =
+        ImmutableList.of(
+            "mongodb://localhost:27017/database/",
+            "mongodb://localhost:27017/database",
+            "localhost:27017/database/collection",
+            "mongodb://:27017/database/collection",
+            "mongodb://localhost:27017//collection",
+            "mongodb://localhost/database/collection",
+            "mongodb://localhost:/database/collection");
+
+    for (String badLocation : badLocations) {
+      Table table = fakeTable("TEST", badLocation);
+      assertThrows(IllegalArgumentException.class, () -> provider.buildBeamSqlTable(table));
+    }
+  }
+
+  private static Table fakeTable(String name, String location) {
+    return Table.builder()
+        .name(name)
+        .comment(name + " table")
+        .location(location)
+        .schema(
+            Stream.of(
+                    Schema.Field.nullable("id", Schema.FieldType.INT32),
+                    Schema.Field.nullable("name", Schema.FieldType.STRING))
+                .collect(toSchema()))
+        .type("mongodb")
+        .build();
+  }
+}
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTableTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTableTest.java
new file mode 100644
index 0000000..cccac9c
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTableTest.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.meta.provider.mongodb;
+
+import static org.apache.beam.sdk.schemas.Schema.FieldType.BOOLEAN;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.BYTE;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.DOUBLE;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.FLOAT;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.INT16;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.INT32;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.INT64;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.STRING;
+
+import java.util.Arrays;
+import org.apache.beam.sdk.extensions.sql.meta.provider.mongodb.MongoDbTable.DocumentToRow;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.Row;
+import org.bson.Document;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class MongoDbTableTest {
+
+  private static final Schema SCHEMA =
+      Schema.builder()
+          .addNullableField("long", INT64)
+          .addNullableField("int32", INT32)
+          .addNullableField("int16", INT16)
+          .addNullableField("byte", BYTE)
+          .addNullableField("bool", BOOLEAN)
+          .addNullableField("double", DOUBLE)
+          .addNullableField("float", FLOAT)
+          .addNullableField("string", STRING)
+          .addRowField("nested", Schema.builder().addNullableField("int32", INT32).build())
+          .addNullableField("arr", FieldType.array(STRING))
+          .build();
+  private static final String JSON_ROW =
+      "{ "
+          + "\"long\" : 9223372036854775807, "
+          + "\"int32\" : 2147483647, "
+          + "\"int16\" : 32767, "
+          + "\"byte\" : 127, "
+          + "\"bool\" : true, "
+          + "\"double\" : 1.0, "
+          + "\"float\" : 1.0, "
+          + "\"string\" : \"string\", "
+          + "\"nested\" : {\"int32\" : 2147483645}, "
+          + "\"arr\" : [\"str1\", \"str2\", \"str3\"]"
+          + " }";
+
+  @Rule public transient TestPipeline pipeline = TestPipeline.create();
+
+  @Test
+  public void testDocumentToRowConverter() {
+    PCollection<Row> output =
+        pipeline
+            .apply("Create document from JSON", Create.<Document>of(Document.parse(JSON_ROW)))
+            .apply("CConvert document to Row", DocumentToRow.withSchema(SCHEMA));
+
+    // Make sure proper rows are constructed from JSON.
+    PAssert.that(output)
+        .containsInAnyOrder(
+            row(
+                SCHEMA,
+                9223372036854775807L,
+                2147483647,
+                (short) 32767,
+                (byte) 127,
+                true,
+                1.0,
+                (float) 1.0,
+                "string",
+                row(Schema.builder().addNullableField("int32", INT32).build(), 2147483645),
+                Arrays.asList("str1", "str2", "str3")));
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  private Row row(Schema schema, Object... values) {
+    return Row.withSchema(schema).addValues(values).build();
+  }
+}
diff --git a/sdks/java/extensions/zetasketch/build.gradle b/sdks/java/extensions/zetasketch/build.gradle
index 30e8bc8..e19da15 100644
--- a/sdks/java/extensions/zetasketch/build.gradle
+++ b/sdks/java/extensions/zetasketch/build.gradle
@@ -35,7 +35,7 @@
     testCompile library.java.junit
     testCompile project(":sdks:java:io:google-cloud-platform")
     testRuntimeOnly library.java.slf4j_simple
-    testRuntimeOnly project(":runners:direct-java")
+    testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
     testRuntimeOnly project(":runners:google-cloud-dataflow-java")
 }
 
diff --git a/sdks/java/extensions/zetasketch/src/main/java/org/apache/beam/sdk/extensions/zetasketch/HllCount.java b/sdks/java/extensions/zetasketch/src/main/java/org/apache/beam/sdk/extensions/zetasketch/HllCount.java
index 5a975da..e4851bd 100644
--- a/sdks/java/extensions/zetasketch/src/main/java/org/apache/beam/sdk/extensions/zetasketch/HllCount.java
+++ b/sdks/java/extensions/zetasketch/src/main/java/org/apache/beam/sdk/extensions/zetasketch/HllCount.java
@@ -18,6 +18,8 @@
 package org.apache.beam.sdk.extensions.zetasketch;
 
 import com.google.zetasketch.HyperLogLogPlusPlus;
+import java.nio.ByteBuffer;
+import javax.annotation.Nullable;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.DoFn;
@@ -108,6 +110,23 @@
   private HllCount() {}
 
   /**
+   * Converts the passed-in sketch from {@code ByteBuffer} to {@code byte[]}, mapping {@code null
+   * ByteBuffer}s (representing empty sketches) to empty {@code byte[]}s.
+   *
+   * <p>Utility method to convert sketches materialized with ZetaSQL/BigQuery to valid inputs for
+   * Beam {@code HllCount} transforms.
+   */
+  public static byte[] getSketchFromByteBuffer(@Nullable ByteBuffer bf) {
+    if (bf == null) {
+      return new byte[0];
+    } else {
+      byte[] result = new byte[bf.remaining()];
+      bf.get(result);
+      return result;
+    }
+  }
+
+  /**
    * Provides {@code PTransform}s to aggregate inputs into HLL++ sketches. The four supported input
    * types are {@code Integer}, {@code Long}, {@code String}, and {@code byte[]}.
    *
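
A small usage sketch for the new helper (assumed values; not part of the patch), showing the null-to-empty mapping that makes BigQuery NULL sketch cells safe inputs for HllCount transforms:

    import java.nio.ByteBuffer;
    import org.apache.beam.sdk.extensions.zetasketch.HllCount;

    public class SketchBytesDemo {
      public static void main(String[] args) {
        // A BigQuery BYTES cell surfaces as a java.nio.ByteBuffer; a NULL cell
        // (an empty sketch) surfaces as null and maps to an empty byte[].
        byte[] empty = HllCount.getSketchFromByteBuffer(null);
        byte[] nonEmpty = HllCount.getSketchFromByteBuffer(ByteBuffer.wrap(new byte[] {1, 2, 3}));
        System.out.println(empty.length);    // 0
        System.out.println(nonEmpty.length); // 3
      }
    }
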
diff --git a/sdks/java/extensions/zetasketch/src/test/java/org/apache/beam/sdk/extensions/zetasketch/BigQueryHllSketchCompatibilityIT.java b/sdks/java/extensions/zetasketch/src/test/java/org/apache/beam/sdk/extensions/zetasketch/BigQueryHllSketchCompatibilityIT.java
index 3f7927d..462a715 100644
--- a/sdks/java/extensions/zetasketch/src/test/java/org/apache/beam/sdk/extensions/zetasketch/BigQueryHllSketchCompatibilityIT.java
+++ b/sdks/java/extensions/zetasketch/src/test/java/org/apache/beam/sdk/extensions/zetasketch/BigQueryHllSketchCompatibilityIT.java
@@ -44,6 +44,7 @@
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TypeDescriptor;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Test;
@@ -65,23 +66,32 @@
   private static final List<String> TEST_DATA =
       Arrays.asList("Apple", "Orange", "Banana", "Orange");
 
-  // Data Table: used by testReadSketchFromBigQuery())
+  // Data Table: used by tests reading sketches from BigQuery
   // Schema: only one STRING field named "data".
-  // Content: prepopulated with 4 rows: "Apple", "Orange", "Banana", "Orange"
-  private static final String DATA_TABLE_ID = "hll_data";
   private static final String DATA_FIELD_NAME = "data";
   private static final String DATA_FIELD_TYPE = "STRING";
   private static final String QUERY_RESULT_FIELD_NAME = "sketch";
-  private static final Long EXPECTED_COUNT = 3L;
 
-  // Sketch Table: used by testWriteSketchToBigQuery()
+  // Content: prepopulated with 4 rows: "Apple", "Orange", "Banana", "Orange"
+  private static final String DATA_TABLE_ID_NON_EMPTY = "hll_data_non_empty";
+  private static final Long EXPECTED_COUNT_NON_EMPTY = 3L;
+
+  // Content: empty
+  private static final String DATA_TABLE_ID_EMPTY = "hll_data_empty";
+  private static final Long EXPECTED_COUNT_EMPTY = 0L;
+
+  // Sketch Table: used by tests writing sketches to BigQuery
   // Schema: only one BYTES field named "sketch".
-  // Content: will be overridden by the sketch computed by the test pipeline each time the test runs
-  private static final String SKETCH_TABLE_ID = "hll_sketch";
   private static final String SKETCH_FIELD_NAME = "sketch";
   private static final String SKETCH_FIELD_TYPE = "BYTES";
+
+  // Content: will be overridden by the sketch computed by the test pipeline each time the test runs
+  private static final String SKETCH_TABLE_ID = "hll_sketch";
   // SHA-1 hash of string "[3]", the string representation of a row that has only one field 3 in it
-  private static final String EXPECTED_CHECKSUM = "f1e31df9806ce94c5bdbbfff9608324930f4d3f1";
+  private static final String EXPECTED_CHECKSUM_NON_EMPTY =
+      "f1e31df9806ce94c5bdbbfff9608324930f4d3f1";
+  // SHA-1 hash of string "[0]", the string representation of a row that has only one field 0 in it
+  private static final String EXPECTED_CHECKSUM_EMPTY = "1184f5b8d4b6dd08709cf1513f26744167065e0d";
 
   static {
     ApplicationNameOptions options =
@@ -93,31 +103,40 @@
   }
 
   @BeforeClass
-  public static void prepareDatasetAndDataTable() throws Exception {
+  public static void prepareDatasetAndDataTables() throws Exception {
     BIGQUERY_CLIENT.createNewDataset(PROJECT_ID, DATASET_ID);
 
-    // Create Data Table
     TableSchema dataTableSchema =
         new TableSchema()
             .setFields(
                 Collections.singletonList(
                     new TableFieldSchema().setName(DATA_FIELD_NAME).setType(DATA_FIELD_TYPE)));
-    Table dataTable =
+
+    Table dataTableNonEmpty =
         new Table()
             .setSchema(dataTableSchema)
             .setTableReference(
                 new TableReference()
                     .setProjectId(PROJECT_ID)
                     .setDatasetId(DATASET_ID)
-                    .setTableId(DATA_TABLE_ID));
-    BIGQUERY_CLIENT.createNewTable(PROJECT_ID, DATASET_ID, dataTable);
-
-    // Prepopulate test data to Data Table
+                    .setTableId(DATA_TABLE_ID_NON_EMPTY));
+    BIGQUERY_CLIENT.createNewTable(PROJECT_ID, DATASET_ID, dataTableNonEmpty);
+    // Prepopulates dataTableNonEmpty with TEST_DATA
     List<Map<String, Object>> rows =
         TEST_DATA.stream()
             .map(v -> Collections.singletonMap(DATA_FIELD_NAME, (Object) v))
             .collect(Collectors.toList());
-    BIGQUERY_CLIENT.insertDataToTable(PROJECT_ID, DATASET_ID, DATA_TABLE_ID, rows);
+    BIGQUERY_CLIENT.insertDataToTable(PROJECT_ID, DATASET_ID, DATA_TABLE_ID_NON_EMPTY, rows);
+
+    Table dataTableEmpty =
+        new Table()
+            .setSchema(dataTableSchema)
+            .setTableReference(
+                new TableReference()
+                    .setProjectId(PROJECT_ID)
+                    .setDatasetId(DATASET_ID)
+                    .setTableId(DATA_TABLE_ID_EMPTY));
+    BIGQUERY_CLIENT.createNewTable(PROJECT_ID, DATASET_ID, dataTableEmpty);
   }
 
   @AfterClass
@@ -126,22 +145,41 @@
   }
 
   /**
-   * Test that HLL++ sketch computed in BigQuery can be processed by Beam. Hll sketch is computed by
-   * {@code HLL_COUNT.INIT} in BigQuery and read into Beam; the test verifies that we can run {@link
-   * HllCount.MergePartial} and {@link HllCount.Extract} on the sketch in Beam to get the correct
-   * estimated count.
+   * Tests that a non-empty HLL++ sketch computed in BigQuery can be processed by Beam.
+   *
+   * <p>The Hll sketch is computed by {@code HLL_COUNT.INIT} in BigQuery and read into Beam; the
+   * test verifies that we can run {@link HllCount.MergePartial} and {@link HllCount.Extract} on the
+   * sketch in Beam to get the correct estimated count.
    */
   @Test
-  public void testReadSketchFromBigQuery() {
-    String tableSpec = String.format("%s.%s", DATASET_ID, DATA_TABLE_ID);
+  public void testReadNonEmptySketchFromBigQuery() {
+    readSketchFromBigQuery(DATA_TABLE_ID_NON_EMPTY, EXPECTED_COUNT_NON_EMPTY);
+  }
+
+  /**
+   * Tests that an empty HLL++ sketch computed in BigQuery can be processed by Beam.
+   *
+   * <p>The Hll sketch is computed by {@code HLL_COUNT.INIT} in BigQuery and read into Beam; the
+   * test verifies that we can run {@link HllCount.MergePartial} and {@link HllCount.Extract} on the
+   * sketch in Beam to get the correct estimated count.
+   */
+  @Test
+  public void testReadEmptySketchFromBigQuery() {
+    readSketchFromBigQuery(DATA_TABLE_ID_EMPTY, EXPECTED_COUNT_EMPTY);
+  }
+
+  private void readSketchFromBigQuery(String tableId, Long expectedCount) {
+    String tableSpec = String.format("%s.%s", DATASET_ID, tableId);
     String query =
         String.format(
             "SELECT HLL_COUNT.INIT(%s) AS %s FROM %s",
             DATA_FIELD_NAME, QUERY_RESULT_FIELD_NAME, tableSpec);
+
     SerializableFunction<SchemaAndRecord, byte[]> parseQueryResultToByteArray =
-        (SchemaAndRecord schemaAndRecord) ->
+        input ->
             // BigQuery BYTES type corresponds to Java java.nio.ByteBuffer type
-            ((ByteBuffer) schemaAndRecord.getRecord().get(QUERY_RESULT_FIELD_NAME)).array();
+            HllCount.getSketchFromByteBuffer(
+                (ByteBuffer) input.getRecord().get(QUERY_RESULT_FIELD_NAME));
 
     TestPipelineOptions options =
         TestPipeline.testingPipelineOptions().as(TestPipelineOptions.class);
@@ -156,17 +194,35 @@
                     .withCoder(ByteArrayCoder.of()))
             .apply(HllCount.MergePartial.globally()) // no-op, only for testing MergePartial
             .apply(HllCount.Extract.globally());
-    PAssert.thatSingleton(result).isEqualTo(EXPECTED_COUNT);
+    PAssert.thatSingleton(result).isEqualTo(expectedCount);
     p.run().waitUntilFinish();
   }
 
   /**
-   * Test that HLL++ sketch computed in Beam can be processed by BigQuery. Hll sketch is computed by
-   * {@link HllCount.Init} in Beam and written to BigQuery; the test verifies that we can run {@code
-   * HLL_COUNT.EXTRACT()} on the sketch in BigQuery to get the correct estimated count.
+   * Tests that a non-empty HLL++ sketch computed in Beam can be processed by BigQuery.
+   *
+   * <p>The Hll sketch is computed by {@link HllCount.Init} in Beam and written to BigQuery; the
+   * test verifies that we can run {@code HLL_COUNT.EXTRACT()} on the sketch in BigQuery to get the
+   * correct estimated count.
    */
   @Test
-  public void testWriteSketchToBigQuery() {
+  public void testWriteNonEmptySketchToBigQuery() {
+    writeSketchToBigQuery(TEST_DATA, EXPECTED_CHECKSUM_NON_EMPTY);
+  }
+
+  /**
+   * Tests that an empty HLL++ sketch computed in Beam can be processed by BigQuery.
+   *
+   * <p>The Hll sketch is computed by {@link HllCount.Init} in Beam and written to BigQuery; the
+   * test verifies that we can run {@code HLL_COUNT.EXTRACT()} on the sketch in BigQuery to get the
+   * correct estimated count.
+   */
+  @Test
+  public void testWriteEmptySketchToBigQuery() {
+    writeSketchToBigQuery(Collections.emptyList(), EXPECTED_CHECKSUM_EMPTY);
+  }
+
+  private void writeSketchToBigQuery(List<String> testData, String expectedChecksum) {
     String tableSpec = String.format("%s.%s", DATASET_ID, SKETCH_TABLE_ID);
     String query =
         String.format("SELECT HLL_COUNT.EXTRACT(%s) FROM %s", SKETCH_FIELD_NAME, tableSpec);
@@ -181,16 +237,20 @@
     // After the pipeline finishes, BigqueryMatcher will send a query to retrieve the estimated
     // count and verifies its correctness using checksum.
     options.setOnSuccessMatcher(
-        BigqueryMatcher.createUsingStandardSql(APP_NAME, PROJECT_ID, query, EXPECTED_CHECKSUM));
+        BigqueryMatcher.createUsingStandardSql(APP_NAME, PROJECT_ID, query, expectedChecksum));
 
     Pipeline p = Pipeline.create(options);
-    p.apply(Create.of(TEST_DATA))
+    p.apply(Create.of(testData).withType(TypeDescriptor.of(String.class)))
         .apply(HllCount.Init.forStrings().globally())
         .apply(
             BigQueryIO.<byte[]>write()
                 .to(tableSpec)
                 .withSchema(tableSchema)
-                .withFormatFunction(sketch -> new TableRow().set(SKETCH_FIELD_NAME, sketch))
+                .withFormatFunction(
+                    sketch ->
+                        // Empty sketch is represented by empty byte array in Beam and by null in
+                        // BigQuery
+                        new TableRow().set(SKETCH_FIELD_NAME, sketch.length == 0 ? null : sketch))
                 .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_TRUNCATE));
     p.run().waitUntilFinish();
   }
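
For intuition about the two checksum constants introduced above: BigqueryMatcher verifies the query result by comparing a checksum, and the comments state the constants are SHA-1 digests of the stringified result rows ("[3]" and "[0]"). Assuming the matcher hashes the UTF-8 row string directly, the constants can be re-derived with a short Python sketch:

    import hashlib

    # Prints the SHA-1 hex digest of each stringified result row, for
    # comparison against EXPECTED_CHECKSUM_NON_EMPTY and EXPECTED_CHECKSUM_EMPTY.
    for row in ("[3]", "[0]"):
      print(row, hashlib.sha1(row.encode("utf-8")).hexdigest())
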
diff --git a/sdks/java/extensions/zetasketch/src/test/java/org/apache/beam/sdk/extensions/zetasketch/HllCountTest.java b/sdks/java/extensions/zetasketch/src/test/java/org/apache/beam/sdk/extensions/zetasketch/HllCountTest.java
index 4ef3e6a..137e976 100644
--- a/sdks/java/extensions/zetasketch/src/test/java/org/apache/beam/sdk/extensions/zetasketch/HllCountTest.java
+++ b/sdks/java/extensions/zetasketch/src/test/java/org/apache/beam/sdk/extensions/zetasketch/HllCountTest.java
@@ -17,8 +17,11 @@
  */
 package org.apache.beam.sdk.extensions.zetasketch;
 
+import static org.junit.Assert.assertArrayEquals;
+
 import com.google.zetasketch.HyperLogLogPlusPlus;
 import com.google.zetasketch.shaded.com.google.protobuf.ByteString;
+import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
@@ -484,4 +487,15 @@
     PAssert.thatSingleton(result).isEqualTo(KV.of("k", 0L));
     p.run();
   }
+
+  @Test
+  public void testGetSketchFromByteBufferForEmptySketch() {
+    assertArrayEquals(HllCount.getSketchFromByteBuffer(null), EMPTY_SKETCH);
+  }
+
+  @Test
+  public void testGetSketchFromByteBufferForNonEmptySketch() {
+    ByteBuffer bf = ByteBuffer.wrap(INTS1_SKETCH);
+    assertArrayEquals(HllCount.getSketchFromByteBuffer(bf), INTS1_SKETCH);
+  }
 }
diff --git a/sdks/java/io/amazon-web-services/build.gradle b/sdks/java/io/amazon-web-services/build.gradle
index ca88447..d7e4139 100644
--- a/sdks/java/io/amazon-web-services/build.gradle
+++ b/sdks/java/io/amazon-web-services/build.gradle
@@ -50,7 +50,7 @@
   testCompile 'org.elasticmq:elasticmq-rest-sqs_2.12:0.14.1'
   testCompile 'org.testcontainers:localstack:1.11.2'
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
 
 test {
diff --git a/sdks/java/io/amazon-web-services2/build.gradle b/sdks/java/io/amazon-web-services2/build.gradle
index eb33c56..a52c827 100644
--- a/sdks/java/io/amazon-web-services2/build.gradle
+++ b/sdks/java/io/amazon-web-services2/build.gradle
@@ -42,7 +42,7 @@
   testCompile library.java.junit
   testCompile 'org.testcontainers:testcontainers:1.11.3'
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
 
 test {
diff --git a/sdks/java/io/amqp/build.gradle b/sdks/java/io/amqp/build.gradle
index a4c35a3..6a6b3a5 100644
--- a/sdks/java/io/amqp/build.gradle
+++ b/sdks/java/io/amqp/build.gradle
@@ -35,5 +35,5 @@
   testCompile library.java.activemq_amqp
   testCompile library.java.activemq_junit
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/cassandra/build.gradle b/sdks/java/io/cassandra/build.gradle
index ca3015c4..36dbede 100644
--- a/sdks/java/io/cassandra/build.gradle
+++ b/sdks/java/io/cassandra/build.gradle
@@ -45,5 +45,5 @@
   testCompile group: 'info.archinnov', name: 'achilles-junit', version: "$achilles_version"
   testCompile library.java.jackson_jaxb_annotations
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/clickhouse/build.gradle b/sdks/java/io/clickhouse/build.gradle
index ee8d382..ec102b4 100644
--- a/sdks/java/io/clickhouse/build.gradle
+++ b/sdks/java/io/clickhouse/build.gradle
@@ -60,5 +60,5 @@
   testCompile library.java.hamcrest_library
   testCompile "org.testcontainers:clickhouse:$testcontainers_version"
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-2/build.gradle b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-2/build.gradle
index 0788f89..482fc7d 100644
--- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-2/build.gradle
+++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-2/build.gradle
@@ -48,5 +48,5 @@
   testCompile library.java.vendored_guava_26_0_jre
   testCompile "org.elasticsearch:elasticsearch:$elastic_search_version"
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/build.gradle b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/build.gradle
index 543c0db..2e13700 100644
--- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/build.gradle
+++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/build.gradle
@@ -65,5 +65,5 @@
   testCompile library.java.junit
   testCompile "org.elasticsearch.client:elasticsearch-rest-client:$elastic_search_version"
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-6/build.gradle b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-6/build.gradle
index 93b59b5..b7bf6d0 100644
--- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-6/build.gradle
+++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-6/build.gradle
@@ -65,5 +65,5 @@
   testCompile library.java.junit
   testCompile "org.elasticsearch.client:elasticsearch-rest-client:$elastic_search_version"
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/build.gradle b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/build.gradle
index 168fc5b..53a1ff4 100644
--- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/build.gradle
+++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/build.gradle
@@ -48,5 +48,5 @@
   testCompile library.java.junit
   testCompile "org.elasticsearch.client:elasticsearch-rest-client:6.4.0"
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/file-based-io-tests/build.gradle b/sdks/java/io/file-based-io-tests/build.gradle
index 845a9a8..fa737cb 100644
--- a/sdks/java/io/file-based-io-tests/build.gradle
+++ b/sdks/java/io/file-based-io-tests/build.gradle
@@ -33,4 +33,5 @@
   testCompile library.java.junit
   testCompile library.java.hamcrest_core
   testCompile library.java.jaxb_api
+  testRuntime library.java.hadoop_client
 }
diff --git a/sdks/java/io/hadoop-file-system/build.gradle b/sdks/java/io/hadoop-file-system/build.gradle
index 8ebdc93..26f9db3 100644
--- a/sdks/java/io/hadoop-file-system/build.gradle
+++ b/sdks/java/io/hadoop-file-system/build.gradle
@@ -39,5 +39,5 @@
   testCompile library.java.hadoop_minicluster
   testCompile library.java.hadoop_hdfs_tests
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/hadoop-format/build.gradle b/sdks/java/io/hadoop-format/build.gradle
index 20dba8d..d575d40 100644
--- a/sdks/java/io/hadoop-format/build.gradle
+++ b/sdks/java/io/hadoop-format/build.gradle
@@ -82,7 +82,7 @@
   testCompile library.java.hamcrest_core
   testCompile library.java.hamcrest_library
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
   compile library.java.commons_io_2x
 
   delegate.add("sparkRunner", project(":sdks:java:io:hadoop-format"))
diff --git a/sdks/java/io/hbase/build.gradle b/sdks/java/io/hbase/build.gradle
index 882eb5e..e2dd902 100644
--- a/sdks/java/io/hbase/build.gradle
+++ b/sdks/java/io/hbase/build.gradle
@@ -64,6 +64,6 @@
   }
   testCompile "org.apache.hbase:hbase-hadoop-compat:$hbase_version:tests"
   testCompile "org.apache.hbase:hbase-hadoop2-compat:$hbase_version:tests"
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
 
diff --git a/sdks/java/io/hcatalog/build.gradle b/sdks/java/io/hcatalog/build.gradle
index da0ee24..e32277d 100644
--- a/sdks/java/io/hcatalog/build.gradle
+++ b/sdks/java/io/hcatalog/build.gradle
@@ -67,6 +67,6 @@
   testCompile "org.apache.hive:hive-exec:$hive_version"
   testCompile "org.apache.hive:hive-common:$hive_version"
   testCompile "org.apache.hive:hive-cli:$hive_version"
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
 
diff --git a/sdks/java/io/jdbc/build.gradle b/sdks/java/io/jdbc/build.gradle
index d8ce5f2..ad59ab1 100644
--- a/sdks/java/io/jdbc/build.gradle
+++ b/sdks/java/io/jdbc/build.gradle
@@ -40,5 +40,5 @@
   testCompile "org.apache.derby:derbyclient:10.14.2.0"
   testCompile "org.apache.derby:derbynet:10.14.2.0"
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/jms/build.gradle b/sdks/java/io/jms/build.gradle
index 8056106..3886d23 100644
--- a/sdks/java/io/jms/build.gradle
+++ b/sdks/java/io/jms/build.gradle
@@ -37,5 +37,5 @@
   testCompile library.java.hamcrest_core
   testCompile library.java.hamcrest_library
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/kafka/build.gradle b/sdks/java/io/kafka/build.gradle
index da8655e..d2d79c7 100644
--- a/sdks/java/io/kafka/build.gradle
+++ b/sdks/java/io/kafka/build.gradle
@@ -46,5 +46,5 @@
   testCompile library.java.powermock
   testCompile library.java.powermock_mockito
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/kinesis/build.gradle b/sdks/java/io/kinesis/build.gradle
index a73c770..ccc0dce 100644
--- a/sdks/java/io/kinesis/build.gradle
+++ b/sdks/java/io/kinesis/build.gradle
@@ -50,5 +50,5 @@
   testCompile library.java.powermock_mockito
   testCompile "org.assertj:assertj-core:3.11.1"
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/kudu/build.gradle b/sdks/java/io/kudu/build.gradle
index 8619d67..b32a02d 100644
--- a/sdks/java/io/kudu/build.gradle
+++ b/sdks/java/io/kudu/build.gradle
@@ -45,6 +45,6 @@
   testCompile library.java.hamcrest_library
   testCompile library.java.junit
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
 
diff --git a/sdks/java/io/mongodb/build.gradle b/sdks/java/io/mongodb/build.gradle
index d8ac11c..040f88e 100644
--- a/sdks/java/io/mongodb/build.gradle
+++ b/sdks/java/io/mongodb/build.gradle
@@ -38,5 +38,5 @@
   testCompile "de.flapdoodle.embed:de.flapdoodle.embed.mongo:2.2.0"
   testCompile "de.flapdoodle.embed:de.flapdoodle.embed.process:2.1.2"
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java
index 4bdbbf4..9fd06d3 100644
--- a/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java
+++ b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java
@@ -323,6 +323,13 @@
       return input.apply(org.apache.beam.sdk.io.Read.from(new BoundedMongoDbSource(this)));
     }
 
+    public long getDocumentCount() {
+      checkArgument(uri() != null, "withUri() is required");
+      checkArgument(database() != null, "withDatabase() is required");
+      checkArgument(collection() != null, "withCollection() is required");
+      return new BoundedMongoDbSource(this).getDocumentCount();
+    }
+
     @Override
     public void populateDisplayData(DisplayData.Builder builder) {
       super.populateDisplayData(builder);
@@ -376,6 +383,38 @@
       return new BoundedMongoDbReader(this);
     }
 
+    /**
+     * Returns the number of documents in a collection.
+     *
+     * @return The non-negative number of documents in the collection, or -1 on error.
+     */
+    long getDocumentCount() {
+      try (MongoClient mongoClient =
+          new MongoClient(
+              new MongoClientURI(
+                  spec.uri(),
+                  getOptions(
+                      spec.maxConnectionIdleTime(),
+                      spec.sslEnabled(),
+                      spec.sslInvalidHostNameAllowed())))) {
+        return getDocumentCount(mongoClient, spec.database(), spec.collection());
+      } catch (Exception e) {
+        return -1;
+      }
+    }
+
+    private long getDocumentCount(MongoClient mongoClient, String database, String collection) {
+      MongoDatabase mongoDatabase = mongoClient.getDatabase(database);
+
+      // Run the Mongo collStats command; it reports statistics, including
+      // the document count, for the entire collection.
+      BasicDBObject stat = new BasicDBObject();
+      stat.append("collStats", collection);
+      Document stats = mongoDatabase.runCommand(stat);
+
+      return stats.get("count", Number.class).longValue();
+    }
+
     @Override
     public long getEstimatedSizeBytes(PipelineOptions pipelineOptions) {
       try (MongoClient mongoClient =
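
The collStats-based count above is not Java-specific. As a rough illustration, an equivalent sketch in Python with pymongo (assuming pymongo is available; the function name and arguments here are illustrative, not part of the patch):

    from pymongo import MongoClient

    def get_document_count(uri, database, collection):
      # Mirrors BoundedMongoDbSource.getDocumentCount above: run the
      # collStats command and read its "count" field, returning -1 on error.
      try:
        client = MongoClient(uri)
        stats = client[database].command('collStats', collection)
        return int(stats['count'])
      except Exception:
        return -1
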
diff --git a/sdks/java/io/mqtt/build.gradle b/sdks/java/io/mqtt/build.gradle
index 9d7b188..a384274 100644
--- a/sdks/java/io/mqtt/build.gradle
+++ b/sdks/java/io/mqtt/build.gradle
@@ -37,5 +37,5 @@
   testCompile library.java.hamcrest_core
   testCompile library.java.hamcrest_library
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/parquet/build.gradle b/sdks/java/io/parquet/build.gradle
index 7a038f3..2e14f77 100644
--- a/sdks/java/io/parquet/build.gradle
+++ b/sdks/java/io/parquet/build.gradle
@@ -32,12 +32,11 @@
   compile "org.apache.parquet:parquet-common:$parquet_version"
   compile "org.apache.parquet:parquet-hadoop:$parquet_version"
   compile library.java.avro
-  compile library.java.hadoop_client
-  compile library.java.hadoop_common
+  provided library.java.hadoop_client
   testCompile project(path: ":sdks:java:core", configuration: "shadowTest")
   testCompile library.java.junit
   testCompile library.java.hamcrest_core
   testCompile library.java.hamcrest_library
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/rabbitmq/build.gradle b/sdks/java/io/rabbitmq/build.gradle
index 57415f9..cf47712 100644
--- a/sdks/java/io/rabbitmq/build.gradle
+++ b/sdks/java/io/rabbitmq/build.gradle
@@ -36,5 +36,5 @@
   testCompile library.java.hamcrest_library
   testCompile library.java.slf4j_api
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/redis/build.gradle b/sdks/java/io/redis/build.gradle
index 2626400..e205155 100644
--- a/sdks/java/io/redis/build.gradle
+++ b/sdks/java/io/redis/build.gradle
@@ -32,5 +32,5 @@
   testCompile library.java.hamcrest_library
   testCompile "com.github.kstyrc:embedded-redis:0.6"
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/solr/build.gradle b/sdks/java/io/solr/build.gradle
index 928d6db..c2a9ddf 100644
--- a/sdks/java/io/solr/build.gradle
+++ b/sdks/java/io/solr/build.gradle
@@ -38,5 +38,5 @@
   testCompile "org.apache.solr:solr-core:5.5.4"
   testCompile "com.carrotsearch.randomizedtesting:randomizedtesting-runner:2.3.2"
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/synthetic/build.gradle b/sdks/java/io/synthetic/build.gradle
index 52c794b..e1d2abc 100644
--- a/sdks/java/io/synthetic/build.gradle
+++ b/sdks/java/io/synthetic/build.gradle
@@ -34,5 +34,5 @@
   testCompile library.java.junit
   testCompile library.java.hamcrest_core
   testCompile library.java.hamcrest_library
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/tika/build.gradle b/sdks/java/io/tika/build.gradle
index 813a692..1aae4d1 100644
--- a/sdks/java/io/tika/build.gradle
+++ b/sdks/java/io/tika/build.gradle
@@ -36,5 +36,5 @@
   testCompile library.java.hamcrest_library
   testCompile "org.apache.tika:tika-parsers:$tika_version"
   testCompileOnly "biz.aQute:bndlib:$bndlib_version"
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/xml/build.gradle b/sdks/java/io/xml/build.gradle
index 2cb9077..7b59671 100644
--- a/sdks/java/io/xml/build.gradle
+++ b/sdks/java/io/xml/build.gradle
@@ -33,5 +33,5 @@
   testCompile library.java.hamcrest_core
   testCompile library.java.hamcrest_library
   testRuntimeOnly library.java.slf4j_jdk14
-  testRuntimeOnly project(":runners:direct-java")
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/testing/expansion-service/build.gradle b/sdks/java/testing/expansion-service/build.gradle
index bfbcc23..09d830c 100644
--- a/sdks/java/testing/expansion-service/build.gradle
+++ b/sdks/java/testing/expansion-service/build.gradle
@@ -28,6 +28,7 @@
   compile project(path: ":runners:core-construction-java")
   compile project(path: ":sdks:java:io:parquet")
   compile project(path: ":sdks:java:core", configuration: "shadow")
+  runtimeOnly library.java.hadoop_client
 }
 
 task runTestExpansionService (type: JavaExec) {
diff --git a/sdks/java/testing/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query10.java b/sdks/java/testing/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query10.java
index 89b0cc6..aa133e9 100644
--- a/sdks/java/testing/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query10.java
+++ b/sdks/java/testing/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query10.java
@@ -41,6 +41,7 @@
 import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime;
 import org.apache.beam.sdk.transforms.windowing.AfterWatermark;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo.Timing;
@@ -314,7 +315,7 @@
             name + ".WindowLogFiles",
             Window.<KV<Void, OutputFile>>into(
                     FixedWindows.of(Duration.standardSeconds(configuration.windowSizeSec)))
-                .triggering(AfterWatermark.pastEndOfWindow())
+                .triggering(DefaultTrigger.of())
                 // We expect no late data here, but we'll assume the worst so we can detect any.
                 .withAllowedLateness(Duration.standardDays(1))
                 .discardingFiredPanes())
diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index 5b66730..e21052f 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -35,6 +35,7 @@
 import logging
 import math
 import random
+import threading
 import uuid
 from builtins import object
 from builtins import range
@@ -1104,13 +1105,17 @@
 class RestrictionTracker(object):
   """Manages concurrent access to a restriction.
 
-  Experimental; no backwards-compatibility guarantees.
-
   Keeps track of the claimed part of a restriction for a Splittable DoFn.
 
+  The restriction may be modified by different threads; however, the system
+  ensures sufficient locking such that no methods on the restriction tracker
+  are called concurrently.
+
   See following documents for more details.
   * https://s.apache.org/splittable-do-fn
   * https://s.apache.org/splittable-do-fn-python-sdk
+
+  Experimental; no backwards-compatibility guarantees.
   """
 
   def current_restriction(self):
@@ -1121,54 +1126,22 @@
 
     The current restriction returned by this method may be updated dynamically
     due to concurrent invocation of other methods of the
-    ``RestrictionTracker``, For example, ``checkpoint()``.
+    ``RestrictionTracker``, for example, ``split()``.
 
-    ** Thread safety **
+    This method is required to be implemented.
 
-    Methods of the class ``RestrictionTracker`` including this method may get
-    invoked by different threads, hence must be made thread-safe, e.g. by using
-    a single lock object.
-
-    TODO(BEAM-7473): Remove thread safety requirements from API implementation.
+    Returns: a restriction object.
     """
     raise NotImplementedError
 
   def current_progress(self):
     """Returns a RestrictionProgress object representing the current progress.
+
+    This method is recommended to be implemented: with better progress signals,
+    the runner can do a better job at parallel processing.
     """
     raise NotImplementedError
 
-  def current_watermark(self):
-    """Returns current watermark. By default, not report watermark.
-
-    TODO(BEAM-7473): Provide synchronization guarantee by using a wrapper.
-    """
-    return None
-
-  def checkpoint(self):
-    """Performs a checkpoint of the current restriction.
-
-    Signals that the current ``DoFn.process()`` call should terminate as soon as
-    possible. After this method returns, the tracker MUST refuse all future
-    claim calls, and ``RestrictionTracker.check_done()`` MUST succeed.
-
-    This invocation modifies the value returned by ``current_restriction()``
-    invocation and returns a restriction representing the rest of the work. The
-    old value of ``current_restriction()`` is equivalent to the new value of
-    ``current_restriction()`` and the return value of this method invocation
-    combined.
-
-    ** Thread safety **
-
-    Methods of the class ``RestrictionTracker`` including this method may get
-    invoked by different threads, hence must be made thread-safe, e.g. by using
-    a single lock object.
-
-    TODO(BEAM-7473): Remove thread safety requirements from API implementation.
-    """
-
-    raise NotImplementedError
-
   def check_done(self):
     """Checks whether the restriction has been fully processed.
 
@@ -1179,13 +1152,8 @@
     remaining in the restriction when this method is invoked. Exception raised
     must have an informative error message.
 
-    ** Thread safety **
-
-    Methods of the class ``RestrictionTracker`` including this method may get
-    invoked by different threads, hence must be made thread-safe, e.g. by using
-    a single lock object.
-
-    TODO(BEAM-7473): Remove thread safety requirements from API implementation.
+    This method is required to be implemented in order to ensure that there is
+    no data loss during SDK processing.
 
     Returns: ``True`` if current restriction has been fully processed.
     Raises:
@@ -1215,8 +1183,12 @@
     restrictions returned would be [100, 179), [179, 200) (note: current_offset
     + fraction_of_remainder * remaining_work = 130 + 0.7 * 70 = 179).
 
-    It is very important for pipeline scaling and end to end pipeline execution
-    that try_split is implemented well.
+    ``fraction_of_remainder`` = 0 means a checkpoint is required.
+
+    This method is recommended to be implemented for batch pipelines, given
+    that it is very important for pipeline scaling and end-to-end pipeline
+    execution.
+
+    This method is required to be implemented for streaming pipelines.
 
     Args:
       fraction_of_remainder: A hint as to the fraction of work the primary
@@ -1226,19 +1198,11 @@
     Returns:
       (primary_restriction, residual_restriction) if a split was possible,
       otherwise returns ``None``.
-
-    ** Thread safety **
-
-    Methods of the class ``RestrictionTracker`` including this method may get
-    invoked by different threads, hence must be made thread-safe, e.g. by using
-    a single lock object.
-
-    TODO(BEAM-7473): Remove thread safety requirements from API implementation.
     """
     raise NotImplementedError
 
   def try_claim(self, position):
-    """ Attempts to claim the block of work in the current restriction
+    """Attempts to claim the block of work in the current restriction
     identified by the given position.
 
     If this succeeds, the DoFn MUST execute the entire block of work. If it
@@ -1247,40 +1211,137 @@
     work from ``DoFn.process()`` is also not allowed before the first call of
     this method).
 
+    This method is required to be implemented.
+
     Args:
       position: current position that wants to be claimed.
 
     Returns: ``True`` if the position can be claimed as current_position.
     Otherwise, returns ``False``.
-
-    ** Thread safety **
-
-    Methods of the class ``RestrictionTracker`` including this method may get
-    invoked by different threads, hence must be made thread-safe, e.g. by using
-    a single lock object.
-
-    TODO(BEAM-7473): Remove thread safety requirements from API implementation.
     """
     raise NotImplementedError
 
-  def defer_remainder(self, watermark=None):
-    """ Invokes checkpoint() in an SDF.process().
 
-    TODO(BEAM-7472): Remove defer_remainder() once SDF.process() uses
-    ``ProcessContinuation``.
+class ThreadsafeRestrictionTracker(object):
+  """A thread-safe wrapper which wraps a `RestritionTracker`.
+
+  This wrapper guarantees synchronization of modifying restrictions across
+  multi-thread.
+  """
+
+  def __init__(self, restriction_tracker):
+    if not isinstance(restriction_tracker, RestrictionTracker):
+      raise ValueError(
+          'Initializing ThreadsafeRestrictionTracker requires a '
+          'RestrictionTracker.')
+    self._restriction_tracker = restriction_tracker
+    # Records an absolute timestamp when defer_remainder is called.
+    self._deferred_timestamp = None
+    self._lock = threading.RLock()
+    self._deferred_residual = None
+    self._deferred_watermark = None
+
+  def current_restriction(self):
+    with self._lock:
+      return self._restriction_tracker.current_restriction()
+
+  def try_claim(self, position):
+    with self._lock:
+      return self._restriction_tracker.try_claim(position)
+
+  def defer_remainder(self, deferred_time=None):
+    """Performs self-checkpoint on current processing restriction with an
+    expected resuming time.
+
+    Self-checkpoint could happen during processing elements. When executing an
+    DoFn.process(), you may want to stop processing an element and resuming
+    later if current element has been processed quit a long time or you also
+    want to have some outputs from other elements. ``defer_remainder()`` can be
+    called on per element if needed.
 
     Args:
-      watermark
+      deferred_time: A relative ``timestamp.Duration`` that indicates the
+        ideal time gap between now and resuming, or an absolute
+        ``timestamp.Timestamp`` for the resuming execution time. If
+        deferred_time is None, the deferred work will be executed as soon as
+        possible.
     """
-    raise NotImplementedError
+
+    # Record current time for calculating deferred_time later.
+    self._deferred_timestamp = timestamp.Timestamp.now()
+    if (deferred_time and
+        not isinstance(deferred_time, timestamp.Duration) and
+        not isinstance(deferred_time, timestamp.Timestamp)):
+      raise ValueError('The timestamp of defer_remainder() should be a '
+                       'Duration or a Timestamp, or None.')
+    self._deferred_watermark = deferred_time
+    checkpoint = self.try_split(0)
+    if checkpoint:
+      _, self._deferred_residual = checkpoint
+
+  def check_done(self):
+    with self._lock:
+      return self._restriction_tracker.check_done()
+
+  def current_progress(self):
+    with self._lock:
+      return self._restriction_tracker.current_progress()
+
+  def try_split(self, fraction_of_remainder):
+    with self._lock:
+      return self._restriction_tracker.try_split(fraction_of_remainder)
 
   def deferred_status(self):
-    """ Returns deferred_residual with deferred_watermark.
+    """Returns deferred work which is produced by ``defer_remainder()``.
 
-    TODO(BEAM-7472): Remove defer_status() once SDF.process() uses
-    ``ProcessContinuation``.
+    When a self-checkpoint has been performed, the system needs to fulfill the
+    DelayedBundleApplication with the deferred work for a ProcessBundleResponse.
+    The system calls this API to get the deferred residual together with its
+    watermark, which helps the runner schedule the future work.
+
+    Returns: (deferred_residual, time_delay) if there is any residual, else
+    None.
     """
-    raise NotImplementedError
+    if self._deferred_residual:
+      # If _deferred_watermark is None, create Duration(0).
+      if not self._deferred_watermark:
+        self._deferred_watermark = timestamp.Duration()
+      # If an absolute timestamp is provided, calculate the delta between
+      # the absolute time and the time deferred_status() is called.
+      elif isinstance(self._deferred_watermark, timestamp.Timestamp):
+        self._deferred_watermark = (self._deferred_watermark -
+                                    timestamp.Timestamp.now())
+      # If a Duration is provided, the deferred time should be:
+      # provided duration - time elapsed since defer_remainder() was called.
+      elif isinstance(self._deferred_watermark, timestamp.Duration):
+        self._deferred_watermark -= (timestamp.Timestamp.now() -
+                                     self._deferred_timestamp)
+      return self._deferred_residual, self._deferred_watermark
+
+
+class RestrictionTrackerView(object):
+  """A DoFn view of thread-safe RestrictionTracker.
+
+  The RestrictionTrackerView wraps a ThreadsafeRestrictionTracker and only
+  exposes APIs that will be called by a ``DoFn.process()``. During execution
+  time, the RestrictionTrackerView will be fed into the ``DoFn.process`` as a
+  restriction_tracker.
+  """
+
+  def __init__(self, threadsafe_restriction_tracker):
+    if not isinstance(threadsafe_restriction_tracker,
+                      ThreadsafeRestrictionTracker):
+      raise ValueError('Initializing RestrictionTrackerView requires a '
+                       'ThreadsafeRestrictionTracker.')
+    self._threadsafe_restriction_tracker = threadsafe_restriction_tracker
+
+  def current_restriction(self):
+    return self._threadsafe_restriction_tracker.current_restriction()
+
+  def try_claim(self, position):
+    return self._threadsafe_restriction_tracker.try_claim(position)
+
+  def defer_remainder(self, deferred_time=None):
+    self._threadsafe_restriction_tracker.defer_remainder(deferred_time)
 
 
 class RestrictionProgress(object):
@@ -1400,17 +1461,8 @@
                 SourceBundle(residual_weight, self._source, split_pos,
                              stop_pos))
 
-    def deferred_status(self):
-      return None
-
-    def current_watermark(self):
-      return None
-
-    def get_delegate_range_tracker(self):
-      return self._delegate_range_tracker
-
-    def get_tracking_source(self):
-      return self._source
+    def check_done(self):
+      return self._delegate_range_tracker.fraction_consumed() >= 1.0
 
   class _SDFBoundedSourceRestrictionProvider(core.RestrictionProvider):
     """A `RestrictionProvider` that is used by SDF for `BoundedSource`."""
@@ -1463,8 +1515,13 @@
           restriction_tracker=core.DoFn.RestrictionParam(
               _SDFBoundedSourceWrapper._SDFBoundedSourceRestrictionProvider(
                   source, chunk_size))):
-        return restriction_tracker.get_tracking_source().read(
-            restriction_tracker.get_delegate_range_tracker())
+        current_restriction = restriction_tracker.current_restriction()
+        assert isinstance(current_restriction, SourceBundle)
+        tracking_source = current_restriction.source
+        start = current_restriction.start_position
+        stop = current_restriction.stop_position
+        return tracking_source.read(tracking_source.get_range_tracker(start,
+                                                                      stop))
 
     return SDFBoundedSourceDoFn(self.source)
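
Putting the new pieces together: a minimal sketch (not part of the patch itself, but checked against the classes it adds) of how a plain tracker is wrapped in the thread-safe wrapper, handed to the DoFn as a view, and how a self-checkpoint later surfaces through deferred_status():

    from apache_beam.io import iobase
    from apache_beam.io.restriction_trackers import OffsetRange
    from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
    from apache_beam.utils import timestamp

    # The SDK wraps the user's tracker and hands only the view to process().
    tracker = iobase.ThreadsafeRestrictionTracker(
        OffsetRestrictionTracker(OffsetRange(0, 10)))
    view = iobase.RestrictionTrackerView(tracker)

    assert view.try_claim(0)
    # Self-checkpoint, asking to resume roughly 30 seconds from now.
    view.defer_remainder(timestamp.Duration(30))

    # The runner-facing side later reads the residual and the remaining delay
    # to build a DelayedBundleApplication.
    residual, delay = tracker.deferred_status()
    assert residual == OffsetRange(1, 10)
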
 
diff --git a/sdks/python/apache_beam/io/iobase_test.py b/sdks/python/apache_beam/io/iobase_test.py
index 7adb764..0a6afae 100644
--- a/sdks/python/apache_beam/io/iobase_test.py
+++ b/sdks/python/apache_beam/io/iobase_test.py
@@ -19,6 +19,7 @@
 
 from __future__ import absolute_import
 
+import time
 import unittest
 
 import mock
@@ -28,6 +29,9 @@
 from apache_beam.io.concat_source_test import RangeSource
 from apache_beam.io import iobase
 from apache_beam.io.iobase import SourceBundle
+from apache_beam.io.restriction_trackers import OffsetRange
+from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
+from apache_beam.utils import timestamp
 from apache_beam.options.pipeline_options import DebugOptions
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
@@ -191,5 +195,87 @@
     self._run_sdf_wrapper_pipeline(RangeSource(0, 4), [0, 1, 2, 3])
 
 
+class ThreadsafeRestrictionTrackerTest(unittest.TestCase):
+
+  def test_initialization(self):
+    with self.assertRaises(ValueError):
+      iobase.ThreadsafeRestrictionTracker(RangeSource(0, 1))
+
+  def test_defer_remainder_with_wrong_time_type(self):
+    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+        OffsetRestrictionTracker(OffsetRange(0, 10)))
+    with self.assertRaises(ValueError):
+      threadsafe_tracker.defer_remainder(10)
+
+  def test_self_checkpoint_immediately(self):
+    restriction_tracker = OffsetRestrictionTracker(OffsetRange(0, 10))
+    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+        restriction_tracker)
+    threadsafe_tracker.defer_remainder()
+    deferred_residual, deferred_time = threadsafe_tracker.deferred_status()
+    expected_residual = OffsetRange(0, 10)
+    self.assertEqual(deferred_residual, expected_residual)
+    self.assertTrue(isinstance(deferred_time, timestamp.Duration))
+    self.assertEqual(deferred_time, 0)
+
+  def test_self_checkpoint_with_relative_time(self):
+    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+        OffsetRestrictionTracker(OffsetRange(0, 10)))
+    threadsafe_tracker.defer_remainder(timestamp.Duration(100))
+    time.sleep(2)
+    _, deferred_time = threadsafe_tracker.deferred_status()
+    self.assertTrue(isinstance(deferred_time, timestamp.Duration))
+    # The expectation = 100 - 2 - some_delta
+    self.assertTrue(deferred_time <= 98)
+
+  def test_self_checkpoint_with_absolute_time(self):
+    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+        OffsetRestrictionTracker(OffsetRange(0, 10)))
+    now = timestamp.Timestamp.now()
+    schedule_time = now + timestamp.Duration(100)
+    self.assertTrue(isinstance(schedule_time, timestamp.Timestamp))
+    threadsafe_tracker.defer_remainder(schedule_time)
+    time.sleep(2)
+    _, deferred_time = threadsafe_tracker.deferred_status()
+    self.assertTrue(isinstance(deferred_time, timestamp.Duration))
+    # The expectation =
+    # schedule_time - the time when deferred_status is called - some_delta
+    self.assertTrue(deferred_time <= 98)
+
+
+class RestrictionTrackerViewTest(unittest.TestCase):
+
+  def test_initialization(self):
+    with self.assertRaises(ValueError):
+      iobase.RestrictionTrackerView(
+          OffsetRestrictionTracker(OffsetRange(0, 10)))
+
+  def test_api_expose(self):
+    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+        OffsetRestrictionTracker(OffsetRange(0, 10)))
+    tracker_view = iobase.RestrictionTrackerView(threadsafe_tracker)
+    current_restriction = tracker_view.current_restriction()
+    self.assertEqual(current_restriction, OffsetRange(0, 10))
+    self.assertTrue(tracker_view.try_claim(0))
+    tracker_view.defer_remainder()
+    deferred_remainder, deferred_watermark = (
+        threadsafe_tracker.deferred_status())
+    self.assertEqual(deferred_remainder, OffsetRange(1, 10))
+    self.assertEqual(deferred_watermark, timestamp.Duration())
+
+  def test_non_expose_apis(self):
+    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+        OffsetRestrictionTracker(OffsetRange(0, 10)))
+    tracker_view = iobase.RestrictionTrackerView(threadsafe_tracker)
+    with self.assertRaises(AttributeError):
+      tracker_view.check_done()
+    with self.assertRaises(AttributeError):
+      tracker_view.current_progress()
+    with self.assertRaises(AttributeError):
+      tracker_view.try_split()
+    with self.assertRaises(AttributeError):
+      tracker_view.deferred_status()
+
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/io/restriction_trackers.py b/sdks/python/apache_beam/io/restriction_trackers.py
index 0ba5b23..20bb5c1 100644
--- a/sdks/python/apache_beam/io/restriction_trackers.py
+++ b/sdks/python/apache_beam/io/restriction_trackers.py
@@ -19,7 +19,6 @@
 from __future__ import absolute_import
 from __future__ import division
 
-import threading
 from builtins import object
 
 from apache_beam.io.iobase import RestrictionProgress
@@ -86,104 +85,69 @@
     assert isinstance(offset_range, OffsetRange)
     self._range = offset_range
     self._current_position = None
-    self._current_watermark = None
     self._last_claim_attempt = None
-    self._deferred_residual = None
     self._checkpointed = False
-    self._lock = threading.RLock()
 
   def check_done(self):
-    with self._lock:
-      if self._last_claim_attempt < self._range.stop - 1:
-        raise ValueError(
-            'OffsetRestrictionTracker is not done since work in range [%s, %s) '
-            'has not been claimed.'
-            % (self._last_claim_attempt if self._last_claim_attempt is not None
-               else self._range.start,
-               self._range.stop))
+    if self._last_claim_attempt < self._range.stop - 1:
+      raise ValueError(
+          'OffsetRestrictionTracker is not done since work in range [%s, %s) '
+          'has not been claimed.'
+          % (self._last_claim_attempt if self._last_claim_attempt is not None
+             else self._range.start,
+             self._range.stop))
 
   def current_restriction(self):
-    with self._lock:
-      return self._range
-
-  def current_watermark(self):
-    return self._current_watermark
+    return self._range
 
   def current_progress(self):
-    with self._lock:
-      if self._current_position is None:
-        fraction = 0.0
-      elif self._range.stop == self._range.start:
-        # If self._current_position is not None, we must be done.
-        fraction = 1.0
-      else:
-        fraction = (
-            float(self._current_position - self._range.start)
-            / (self._range.stop - self._range.start))
+    if self._current_position is None:
+      fraction = 0.0
+    elif self._range.stop == self._range.start:
+      # If self._current_position is not None, we must be done.
+      fraction = 1.0
+    else:
+      fraction = (
+          float(self._current_position - self._range.start)
+          / (self._range.stop - self._range.start))
     return RestrictionProgress(fraction=fraction)
 
   def start_position(self):
-    with self._lock:
-      return self._range.start
+    return self._range.start
 
   def stop_position(self):
-    with self._lock:
-      return self._range.stop
-
-  def default_size(self):
-    return self._range.size()
+    return self._range.stop
 
   def try_claim(self, position):
-    with self._lock:
-      if self._last_claim_attempt and position <= self._last_claim_attempt:
-        raise ValueError(
-            'Positions claimed should strictly increase. Trying to claim '
-            'position %d while last claim attempt was %d.'
-            % (position, self._last_claim_attempt))
+    if self._last_claim_attempt and position <= self._last_claim_attempt:
+      raise ValueError(
+          'Positions claimed should strictly increase. Trying to claim '
+          'position %d while last claim attempt was %d.'
+          % (position, self._last_claim_attempt))
 
-      self._last_claim_attempt = position
-      if position < self._range.start:
-        raise ValueError(
-            'Position to be claimed cannot be smaller than the start position '
-            'of the range. Tried to claim position %r for the range [%r, %r)'
-            % (position, self._range.start, self._range.stop))
+    self._last_claim_attempt = position
+    if position < self._range.start:
+      raise ValueError(
+          'Position to be claimed cannot be smaller than the start position '
+          'of the range. Tried to claim position %r for the range [%r, %r)'
+          % (position, self._range.start, self._range.stop))
 
-      if position >= self._range.start and position < self._range.stop:
-        self._current_position = position
-        return True
+    if position >= self._range.start and position < self._range.stop:
+      self._current_position = position
+      return True
 
-      return False
+    return False
 
   def try_split(self, fraction_of_remainder):
-    with self._lock:
-      if not self._checkpointed:
-        if self._current_position is None:
-          cur = self._range.start - 1
-        else:
-          cur = self._current_position
-        split_point = (
-            cur + int(max(1, (self._range.stop - cur) * fraction_of_remainder)))
-        if split_point < self._range.stop:
-          self._range, residual_range = self._range.split_at(split_point)
-          return self._range, residual_range
-
-  # TODO(SDF): Replace all calls with try_claim(0).
-  def checkpoint(self):
-    with self._lock:
-      # If self._current_position is 'None' no records have been claimed so
-      # residual should start from self._range.start.
+    if not self._checkpointed:
       if self._current_position is None:
-        end_position = self._range.start
+        cur = self._range.start - 1
       else:
-        end_position = self._current_position + 1
-      self._range, residual_range = self._range.split_at(end_position)
-      return residual_range
-
-  def defer_remainder(self, watermark=None):
-    with self._lock:
-      self._deferred_watermark = watermark or self._current_watermark
-      self._deferred_residual = self.checkpoint()
-
-  def deferred_status(self):
-    if self._deferred_residual:
-      return (self._deferred_residual, self._deferred_watermark)
+        cur = self._current_position
+      split_point = (
+          cur + int(max(1, (self._range.stop - cur) * fraction_of_remainder)))
+      if split_point < self._range.stop:
+        if fraction_of_remainder == 0:
+          self._checkpointed = True
+        self._range, residual_range = self._range.split_at(split_point)
+        return self._range, residual_range
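
To make the split arithmetic concrete, the docstring example from iobase.py can be replayed against this tracker (a sketch using only the classes defined in this file):

    from apache_beam.io.restriction_trackers import OffsetRange
    from apache_beam.io.restriction_trackers import OffsetRestrictionTracker

    # Range [100, 200) with offsets up to 130 claimed, then a split at 70%
    # of the remainder: split point = 130 + int(max(1, (200 - 130) * 0.7)).
    tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
    assert tracker.try_claim(130)
    primary, residual = tracker.try_split(0.7)
    assert primary == OffsetRange(100, 179)
    assert residual == OffsetRange(179, 200)
    # try_split(0) is the new checkpoint: it splits off everything unclaimed.
    _, checkpoint = tracker.try_split(0)
    assert checkpoint == OffsetRange(131, 179)
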
diff --git a/sdks/python/apache_beam/io/restriction_trackers_test.py b/sdks/python/apache_beam/io/restriction_trackers_test.py
index 459b039..4a57d98 100644
--- a/sdks/python/apache_beam/io/restriction_trackers_test.py
+++ b/sdks/python/apache_beam/io/restriction_trackers_test.py
@@ -81,14 +81,14 @@
 
   def test_checkpoint_unstarted(self):
     tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
-    checkpoint = tracker.checkpoint()
+    _, checkpoint = tracker.try_split(0)
     self.assertEqual(OffsetRange(100, 100), tracker.current_restriction())
     self.assertEqual(OffsetRange(100, 200), checkpoint)
 
   def test_checkpoint_just_started(self):
     tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
     self.assertTrue(tracker.try_claim(100))
-    checkpoint = tracker.checkpoint()
+    _, checkpoint = tracker.try_split(0)
     self.assertEqual(OffsetRange(100, 101), tracker.current_restriction())
     self.assertEqual(OffsetRange(101, 200), checkpoint)
 
@@ -96,7 +96,7 @@
     tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
     self.assertTrue(tracker.try_claim(105))
     self.assertTrue(tracker.try_claim(110))
-    checkpoint = tracker.checkpoint()
+    _, checkpoint = tracker.try_split(0)
     self.assertEqual(OffsetRange(100, 111), tracker.current_restriction())
     self.assertEqual(OffsetRange(111, 200), checkpoint)
 
@@ -105,9 +105,9 @@
     self.assertTrue(tracker.try_claim(105))
     self.assertTrue(tracker.try_claim(110))
     self.assertTrue(tracker.try_claim(199))
-    checkpoint = tracker.checkpoint()
+    checkpoint = tracker.try_split(0)
     self.assertEqual(OffsetRange(100, 200), tracker.current_restriction())
-    self.assertEqual(OffsetRange(200, 200), checkpoint)
+    self.assertEqual(None, checkpoint)
 
   def test_checkpoint_after_failed_claim(self):
     tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
@@ -116,7 +116,7 @@
     self.assertTrue(tracker.try_claim(160))
     self.assertFalse(tracker.try_claim(240))
 
-    checkpoint = tracker.checkpoint()
+    _, checkpoint = tracker.try_split(0)
     self.assertTrue(OffsetRange(100, 161), tracker.current_restriction())
     self.assertTrue(OffsetRange(161, 200), checkpoint)
 
diff --git a/sdks/python/apache_beam/metrics/cells.pxd b/sdks/python/apache_beam/metrics/cells.pxd
new file mode 100644
index 0000000..0204da8
--- /dev/null
+++ b/sdks/python/apache_beam/metrics/cells.pxd
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+cimport cython
+cimport libc.stdint
+
+
+cdef class MetricCell(object):
+  cdef object _lock
+  cpdef bint update(self, value) except -1
+
+
+cdef class CounterCell(MetricCell):
+  cdef readonly libc.stdint.int64_t value
+
+  @cython.locals(ivalue=libc.stdint.int64_t)
+  cpdef bint update(self, value) except -1
+
+
+cdef class DistributionCell(MetricCell):
+  cdef readonly DistributionData data
+
+  @cython.locals(ivalue=libc.stdint.int64_t)
+  cdef inline bint _update(self, value) except -1
+
+
+cdef class GaugeCell(MetricCell):
+  cdef readonly object data
+
+
+cdef class DistributionData(object):
+  cdef readonly libc.stdint.int64_t sum
+  cdef readonly libc.stdint.int64_t count
+  cdef readonly libc.stdint.int64_t min
+  cdef readonly libc.stdint.int64_t max
diff --git a/sdks/python/apache_beam/metrics/cells.py b/sdks/python/apache_beam/metrics/cells.py
index e7336e4..d30dd2a 100644
--- a/sdks/python/apache_beam/metrics/cells.py
+++ b/sdks/python/apache_beam/metrics/cells.py
@@ -30,12 +30,16 @@
 
 from google.protobuf import timestamp_pb2
 
-from apache_beam.metrics.metricbase import Counter
-from apache_beam.metrics.metricbase import Distribution
-from apache_beam.metrics.metricbase import Gauge
 from apache_beam.portability.api import beam_fn_api_pb2
 from apache_beam.portability.api import metrics_pb2
 
+try:
+  import cython
+except ImportError:
+  class fake_cython:
+    compiled = False
+  globals()['cython'] = fake_cython
+
 __all__ = ['DistributionResult', 'GaugeResult']
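
The guarded import above lets the same source run both compiled and interpreted: when the cython module is absent, a stand-in with compiled = False routes updates through the lock. A standalone sketch of the pattern (the Cell class is illustrative):

    import threading

    try:
      import cython
    except ImportError:
      class fake_cython:
        compiled = False
      globals()['cython'] = fake_cython

    class Cell(object):
      def __init__(self):
        self._lock = threading.Lock()
        self.value = 0

      def update(self, value):
        if cython.compiled:
          # Compiled: the GIL already serializes this C-level increment.
          self.value += value
        else:
          with self._lock:
            self.value += value
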
 
 
@@ -52,11 +56,17 @@
   def __init__(self):
     self._lock = threading.Lock()
 
+  def update(self, value):
+    raise NotImplementedError
+
   def get_cumulative(self):
     raise NotImplementedError
 
+  def __reduce__(self):
+    raise NotImplementedError
 
-class CounterCell(Counter, MetricCell):
+
+class CounterCell(MetricCell):
   """For internal use only; no backwards-compatibility guarantees.
 
   Tracks the current value and delta of a counter metric.
@@ -80,27 +90,41 @@
     return result
 
   def inc(self, n=1):
-    with self._lock:
-      self.value += n
+    self.update(n)
+
+  def dec(self, n=1):
+    self.update(-n)
+
+  def update(self, value):
+    if cython.compiled:
+      ivalue = value
+      # We hold the GIL, no need for another lock.
+      self.value += ivalue
+    else:
+      with self._lock:
+        self.value += value
 
   def get_cumulative(self):
     with self._lock:
       return self.value
 
-  def to_runner_api_monitoring_info(self):
-    """Returns a Metric with this counter value for use in a MonitoringInfo."""
-    # TODO(ajamato): Update this code to be consistent with Gauges
-    # and Distributions. Since there is no CounterData class this method
-    # was added to CounterCell. Consider adding a CounterData class or
-    # removing the GaugeData and DistributionData classes.
-    return metrics_pb2.Metric(
-        counter_data=metrics_pb2.CounterData(
-            int64_value=self.get_cumulative()
-        )
-    )
+  def to_runner_api_user_metric(self, metric_name):
+    return beam_fn_api_pb2.Metrics.User(
+        metric_name=metric_name.to_runner_api(),
+        counter_data=beam_fn_api_pb2.Metrics.User.CounterData(
+            value=self.value))
+
+  def to_runner_api_monitoring_info(self, name, transform_id):
+    from apache_beam.metrics import monitoring_infos
+    return monitoring_infos.int64_user_counter(
+        name.namespace, name.name,
+        metrics_pb2.Metric(
+            counter_data=metrics_pb2.CounterData(
+                int64_value=self.get_cumulative())),
+        ptransform=transform_id)
 
 
-class DistributionCell(Distribution, MetricCell):
+class DistributionCell(MetricCell):
   """For internal use only; no backwards-compatibility guarantees.
 
   Tracks the current value and delta for a distribution metric.
@@ -124,26 +148,43 @@
     return result
 
   def update(self, value):
-    with self._lock:
+    if cython.compiled:
+      # We will hold the GIL throughout the entire _update.
       self._update(value)
+    else:
+      with self._lock:
+        self._update(value)
 
   def _update(self, value):
-    value = int(value)
-    self.data.count += 1
-    self.data.sum += value
-    self.data.min = (value
-                     if self.data.min is None or self.data.min > value
-                     else self.data.min)
-    self.data.max = (value
-                     if self.data.max is None or self.data.max < value
-                     else self.data.max)
+    if cython.compiled:
+      ivalue = value
+    else:
+      ivalue = int(value)
+    self.data.count = self.data.count + 1
+    self.data.sum = self.data.sum + ivalue
+    if ivalue < self.data.min:
+      self.data.min = ivalue
+    if ivalue > self.data.max:
+      self.data.max = ivalue
 
   def get_cumulative(self):
     with self._lock:
       return self.data.get_cumulative()
 
+  def to_runner_api_user_metric(self, metric_name):
+    return beam_fn_api_pb2.Metrics.User(
+        metric_name=metric_name.to_runner_api(),
+        distribution_data=self.get_cumulative().to_runner_api())
 
-class GaugeCell(Gauge, MetricCell):
+  def to_runner_api_monitoring_info(self, name, transform_id):
+    from apache_beam.metrics import monitoring_infos
+    return monitoring_infos.int64_user_distribution(
+        name.namespace, name.name,
+        self.get_cumulative().to_runner_api_monitoring_info(),
+        ptransform=transform_id)
+
+
+class GaugeCell(MetricCell):
   """For internal use only; no backwards-compatibility guarantees.
 
   Tracks the current value and delta for a gauge metric.
@@ -167,6 +208,9 @@
     return result
 
   def set(self, value):
+    self.update(value)
+
+  def update(self, value):
     value = int(value)
     with self._lock:
       # Set the value directly without checking timestamp, because
@@ -178,6 +222,18 @@
     with self._lock:
       return self.data.get_cumulative()
 
+  def to_runner_api_user_metric(self, metric_name):
+    return beam_fn_api_pb2.Metrics.User(
+        metric_name=metric_name.to_runner_api(),
+        gauge_data=self.get_cumulative().to_runner_api())
+
+  def to_runner_api_monitoring_info(self, name, transform_id):
+    from apache_beam.metrics import monitoring_infos
+    return monitoring_infos.int64_user_gauge(
+        name.namespace, name.name,
+        self.get_cumulative().to_runner_api_monitoring_info(),
+        ptransform=transform_id)
+
 
 class DistributionResult(object):
   """The result of a Distribution metric."""
@@ -198,7 +254,7 @@
     return not self == other
 
   def __repr__(self):
-    return '<DistributionResult(sum={}, count={}, min={}, max={})>'.format(
+    return 'DistributionResult(sum={}, count={}, min={}, max={})'.format(
         self.sum,
         self.count,
         self.min,
@@ -206,11 +262,11 @@
 
   @property
   def max(self):
-    return self.data.max
+    return self.data.max if self.data.count else None
 
   @property
   def min(self):
-    return self.data.min
+    return self.data.min if self.data.count else None
 
   @property
   def count(self):
@@ -340,10 +396,15 @@
   by other than the DistributionCell that contains it.
   """
   def __init__(self, sum, count, min, max):
-    self.sum = sum
-    self.count = count
-    self.min = min
-    self.max = max
+    if count:
+      self.sum = sum
+      self.count = count
+      self.min = min
+      self.max = max
+    else:
+      self.sum = self.count = 0
+      self.min = 2**63 - 1
+      self.max = -2**63
 
   def __eq__(self, other):
     return (self.sum == other.sum and
@@ -359,7 +420,7 @@
     return not self == other
 
   def __repr__(self):
-    return '<DistributionData(sum={}, count={}, min={}, max={})>'.format(
+    return 'DistributionData(sum={}, count={}, min={}, max={})'.format(
         self.sum,
         self.count,
         self.min,
@@ -372,15 +433,11 @@
     if other is None:
       return self
 
-    new_min = (None if self.min is None and other.min is None else
-               min(x for x in (self.min, other.min) if x is not None))
-    new_max = (None if self.max is None and other.max is None else
-               max(x for x in (self.max, other.max) if x is not None))
     return DistributionData(
         self.sum + other.sum,
         self.count + other.count,
-        new_min,
-        new_max)
+        self.min if self.min < other.min else other.min,
+        self.max if self.max > other.max else other.max)
 
   @staticmethod
   def singleton(value):
@@ -449,7 +506,7 @@
   """
   @staticmethod
   def identity_element():
-    return DistributionData(0, 0, None, None)
+    return DistributionData(0, 0, 2**63 - 1, -2**63)
 
   def combine(self, x, y):
     return x.combine(y)
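
Replacing the None sentinels with int64 extremes lets combine() drop all the None handling and keeps DistributionData's fields typeable under Cython; emptiness is now tracked via count == 0, which DistributionResult.min/.max above translate back to None. A sketch of the invariant using the identity element from this hunk:

    INT64_MAX = 2**63 - 1   # identity min
    INT64_MIN = -2**63      # identity max

    def combine(a, b):
      # a and b are (sum, count, min, max) tuples mirroring DistributionData.
      return (a[0] + b[0], a[1] + b[1],
              a[2] if a[2] < b[2] else b[2],
              a[3] if a[3] > b[3] else b[3])

    identity = (0, 0, INT64_MAX, INT64_MIN)
    assert combine(identity, (12, 2, 2, 10)) == (12, 2, 2, 10)
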
diff --git a/sdks/python/apache_beam/metrics/execution.pxd b/sdks/python/apache_beam/metrics/execution.pxd
index 74b34fb..6e1cbb0 100644
--- a/sdks/python/apache_beam/metrics/execution.pxd
+++ b/sdks/python/apache_beam/metrics/execution.pxd
@@ -16,10 +16,30 @@
 #
 
 cimport cython
+cimport libc.stdint
+
+from apache_beam.metrics.cells cimport MetricCell
+
+
+cdef object get_current_tracker
+
+
+cdef class _TypedMetricName(object):
+  cdef readonly object cell_type
+  cdef readonly object metric_name
+  cdef readonly object fast_name
+  cdef libc.stdint.int64_t _hash
+
+
+cdef object _DEFAULT
+
+
+cdef class MetricUpdater(object):
+  cdef _TypedMetricName typed_metric_name
+  cdef object default
 
 
 cdef class MetricsContainer(object):
   cdef object step_name
-  cdef public object counters
-  cdef public object distributions
-  cdef public object gauges
+  cdef public dict metrics
+  cpdef MetricCell get_metric_cell(self, metric_key)
diff --git a/sdks/python/apache_beam/metrics/execution.py b/sdks/python/apache_beam/metrics/execution.py
index 91fe2f8..6918914 100644
--- a/sdks/python/apache_beam/metrics/execution.py
+++ b/sdks/python/apache_beam/metrics/execution.py
@@ -33,14 +33,13 @@
 from __future__ import absolute_import
 
 from builtins import object
-from collections import defaultdict
 
 from apache_beam.metrics import monitoring_infos
 from apache_beam.metrics.cells import CounterCell
 from apache_beam.metrics.cells import DistributionCell
 from apache_beam.metrics.cells import GaugeCell
-from apache_beam.portability.api import beam_fn_api_pb2
 from apache_beam.runners.worker import statesampler
+from apache_beam.runners.worker.statesampler import get_current_tracker
 
 
 class MetricKey(object):
@@ -150,88 +149,117 @@
 MetricsEnvironment = _MetricsEnvironment()
 
 
+class _TypedMetricName(object):
+  """Like MetricName, but also stores the cell type of the metric."""
+  def __init__(self, cell_type, metric_name):
+    self.cell_type = cell_type
+    self.metric_name = metric_name
+    if isinstance(metric_name, str):
+      self.fast_name = metric_name
+    else:
+      self.fast_name = '%d_%s%s' % (
+          len(metric_name.name), metric_name.name, metric_name.namespace)
+    # Cached for speed, as this is used as a key for every counter update.
+    self._hash = hash((cell_type, self.fast_name))
+
+  def __eq__(self, other):
+    return self is other or (
+        self.cell_type == other.cell_type and self.fast_name == other.fast_name)
+
+  def __ne__(self, other):
+    return not self == other
+
+  def __hash__(self):
+    return self._hash
+
+  def __reduce__(self):
+    return _TypedMetricName, (self.cell_type, self.metric_name)
+
+
+_DEFAULT = None
+
+
+class MetricUpdater(object):
+  """A callable that updates the metric as quickly as possible."""
+  def __init__(self, cell_type, metric_name, default=None):
+    self.typed_metric_name = _TypedMetricName(cell_type, metric_name)
+    self.default = default
+
+  def __call__(self, value=_DEFAULT):
+    if value is _DEFAULT:
+      if self.default is _DEFAULT:
+        raise ValueError('Missing value for update of %s'
+                         % self.typed_metric_name.metric_name)
+      value = self.default
+    tracker = get_current_tracker()
+    if tracker is not None:
+      tracker.update_metric(self.typed_metric_name, value)
+
+  def __reduce__(self):
+    return MetricUpdater, (
+        self.typed_metric_name.cell_type,
+        self.typed_metric_name.metric_name,
+        self.default)
+
+
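_TypedMetricName exists so that the per-update dictionary key is a single object with a precomputed hash, and MetricUpdater captures that key once at construction; each call then costs only a tracker lookup plus one dict probe. A usage sketch (the updater is a silent no-op when no state-sampler tracker is active):

    from apache_beam.metrics import cells
    from apache_beam.metrics.execution import MetricUpdater
    from apache_beam.metrics.metricbase import MetricName

    inc_errors = MetricUpdater(
        cells.CounterCell, MetricName('my_ns', 'errors'), default=1)
    inc_errors()    # increments by the default of 1
    inc_errors(5)   # explicit value
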
 class MetricsContainer(object):
   """Holds the metrics of a single step and a single bundle."""
   def __init__(self, step_name):
     self.step_name = step_name
-    self.counters = defaultdict(lambda: CounterCell())
-    self.distributions = defaultdict(lambda: DistributionCell())
-    self.gauges = defaultdict(lambda: GaugeCell())
+    self.metrics = dict()
 
   def get_counter(self, metric_name):
-    return self.counters[metric_name]
+    return self.get_metric_cell(_TypedMetricName(CounterCell, metric_name))
 
   def get_distribution(self, metric_name):
-    return self.distributions[metric_name]
+    return self.get_metric_cell(_TypedMetricName(DistributionCell, metric_name))
 
   def get_gauge(self, metric_name):
-    return self.gauges[metric_name]
+    return self.get_metric_cell(_TypedMetricName(GaugeCell, metric_name))
+
+  def get_metric_cell(self, typed_metric_name):
+    cell = self.metrics.get(typed_metric_name, None)
+    if cell is None:
+      cell = self.metrics[typed_metric_name] = typed_metric_name.cell_type()
+    return cell
 
   def get_cumulative(self):
     """Return MetricUpdates with cumulative values of all metrics in container.
 
     This returns all the cumulative values for all metrics.
     """
-    counters = {MetricKey(self.step_name, k): v.get_cumulative()
-                for k, v in self.counters.items()}
+    counters = {MetricKey(self.step_name, k.metric_name): v.get_cumulative()
+                for k, v in self.metrics.items()
+                if k.cell_type == CounterCell}
 
-    distributions = {MetricKey(self.step_name, k): v.get_cumulative()
-                     for k, v in self.distributions.items()}
+    distributions = {
+        MetricKey(self.step_name, k.metric_name): v.get_cumulative()
+        for k, v in self.metrics.items()
+        if k.cell_type == DistributionCell}
 
-    gauges = {MetricKey(self.step_name, k): v.get_cumulative()
-              for k, v in self.gauges.items()}
+    gauges = {MetricKey(self.step_name, k.metric_name): v.get_cumulative()
+              for k, v in self.metrics.items()
+              if k.cell_type == GaugeCell}
 
     return MetricUpdates(counters, distributions, gauges)
 
   def to_runner_api(self):
-    return (
-        [beam_fn_api_pb2.Metrics.User(
-            metric_name=k.to_runner_api(),
-            counter_data=beam_fn_api_pb2.Metrics.User.CounterData(
-                value=v.get_cumulative()))
-         for k, v in self.counters.items()] +
-        [beam_fn_api_pb2.Metrics.User(
-            metric_name=k.to_runner_api(),
-            distribution_data=v.get_cumulative().to_runner_api())
-         for k, v in self.distributions.items()] +
-        [beam_fn_api_pb2.Metrics.User(
-            metric_name=k.to_runner_api(),
-            gauge_data=v.get_cumulative().to_runner_api())
-         for k, v in self.gauges.items()]
-    )
+    return [cell.to_runner_api_user_metric(key.metric_name)
+            for key, cell in self.metrics.items()]
 
   def to_runner_api_monitoring_infos(self, transform_id):
     """Returns a list of MonitoringInfos for the metrics in this container."""
-    all_user_metrics = []
-    for k, v in self.counters.items():
-      all_user_metrics.append(monitoring_infos.int64_user_counter(
-          k.namespace, k.name,
-          v.to_runner_api_monitoring_info(),
-          ptransform=transform_id
-      ))
-
-    for k, v in self.distributions.items():
-      all_user_metrics.append(monitoring_infos.int64_user_distribution(
-          k.namespace, k.name,
-          v.get_cumulative().to_runner_api_monitoring_info(),
-          ptransform=transform_id
-      ))
-
-    for k, v in self.gauges.items():
-      all_user_metrics.append(monitoring_infos.int64_user_gauge(
-          k.namespace, k.name,
-          v.get_cumulative().to_runner_api_monitoring_info(),
-          ptransform=transform_id
-      ))
+    all_user_metrics = [
+        cell.to_runner_api_monitoring_info(key.metric_name, transform_id)
+        for key, cell in self.metrics.items()]
     return {monitoring_infos.to_key(mi) : mi for mi in all_user_metrics}
 
   def reset(self):
-    for counter in self.counters.values():
-      counter.reset()
-    for distribution in self.distributions.values():
-      distribution.reset()
-    for gauge in self.gauges.values():
-      gauge.reset()
+    for metric in self.metrics.values():
+      metric.reset()
+
+  def __reduce__(self):
+    raise NotImplementedError
 
 
 class MetricUpdates(object):
diff --git a/sdks/python/apache_beam/metrics/execution_test.py b/sdks/python/apache_beam/metrics/execution_test.py
index 9af1696..fc363a4 100644
--- a/sdks/python/apache_beam/metrics/execution_test.py
+++ b/sdks/python/apache_beam/metrics/execution_test.py
@@ -73,12 +73,6 @@
 
 
 class TestMetricsContainer(unittest.TestCase):
-  def test_create_new_counter(self):
-    mc = MetricsContainer('astep')
-    self.assertFalse(MetricName('namespace', 'name') in mc.counters)
-    mc.get_counter(MetricName('namespace', 'name'))
-    self.assertTrue(MetricName('namespace', 'name') in mc.counters)
-
   def test_add_to_counter(self):
     mc = MetricsContainer('astep')
     counter = mc.get_counter(MetricName('namespace', 'name'))
diff --git a/sdks/python/apache_beam/metrics/metric.py b/sdks/python/apache_beam/metrics/metric.py
index acd4771..8bbe191 100644
--- a/sdks/python/apache_beam/metrics/metric.py
+++ b/sdks/python/apache_beam/metrics/metric.py
@@ -29,7 +29,8 @@
 import inspect
 from builtins import object
 
-from apache_beam.metrics.execution import MetricsEnvironment
+from apache_beam.metrics import cells
+from apache_beam.metrics.execution import MetricUpdater
 from apache_beam.metrics.metricbase import Counter
 from apache_beam.metrics.metricbase import Distribution
 from apache_beam.metrics.metricbase import Gauge
@@ -101,11 +102,7 @@
     def __init__(self, metric_name):
       super(Metrics.DelegatingCounter, self).__init__()
       self.metric_name = metric_name
-
-    def inc(self, n=1):
-      container = MetricsEnvironment.current_container()
-      if container is not None:
-        container.get_counter(self.metric_name).inc(n)
+      self.inc = MetricUpdater(cells.CounterCell, metric_name, default=1)
 
   class DelegatingDistribution(Distribution):
     """Metrics Distribution Delegates functionality to MetricsEnvironment."""
@@ -113,11 +110,7 @@
     def __init__(self, metric_name):
       super(Metrics.DelegatingDistribution, self).__init__()
       self.metric_name = metric_name
-
-    def update(self, value):
-      container = MetricsEnvironment.current_container()
-      if container is not None:
-        container.get_distribution(self.metric_name).update(value)
+      self.update = MetricUpdater(cells.DistributionCell, metric_name)
 
   class DelegatingGauge(Gauge):
     """Metrics Gauge that Delegates functionality to MetricsEnvironment."""
@@ -125,11 +118,7 @@
     def __init__(self, metric_name):
       super(Metrics.DelegatingGauge, self).__init__()
       self.metric_name = metric_name
-
-    def set(self, value):
-      container = MetricsEnvironment.current_container()
-      if container is not None:
-        container.get_gauge(self.metric_name).set(value)
+      self.set = MetricUpdater(cells.GaugeCell, metric_name)
 
 
 class MetricResults(object):
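
Note the pattern in the three Delegating* classes above: assigning a MetricUpdater instance to self.inc/self.update/self.set inside __init__ replaces a bound method with a plain callable attribute, removing both the method dispatch and the old per-call MetricsEnvironment.current_container() lookup. User code is unchanged:

    from apache_beam.metrics.metric import Metrics

    counter = Metrics.counter('my_ns', 'my_counter')
    counter.inc()      # now a direct MetricUpdater.__call__
    counter.inc(10)
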
diff --git a/sdks/python/apache_beam/metrics/metric_test.py b/sdks/python/apache_beam/metrics/metric_test.py
index 6e8ee08..cb18dc7 100644
--- a/sdks/python/apache_beam/metrics/metric_test.py
+++ b/sdks/python/apache_beam/metrics/metric_test.py
@@ -130,31 +130,36 @@
     statesampler.set_current_tracker(sampler)
     state1 = sampler.scoped_state('mystep', 'myState',
                                   metrics_container=MetricsContainer('mystep'))
-    sampler.start()
-    with state1:
-      counter_ns = 'aCounterNamespace'
-      distro_ns = 'aDistributionNamespace'
-      name = 'a_name'
-      counter = Metrics.counter(counter_ns, name)
-      distro = Metrics.distribution(distro_ns, name)
-      counter.inc(10)
-      counter.dec(3)
-      distro.update(10)
-      distro.update(2)
-      self.assertTrue(isinstance(counter, Metrics.DelegatingCounter))
-      self.assertTrue(isinstance(distro, Metrics.DelegatingDistribution))
 
-      del distro
-      del counter
+    try:
+      sampler.start()
+      with state1:
+        counter_ns = 'aCounterNamespace'
+        distro_ns = 'aDistributionNamespace'
+        name = 'a_name'
+        counter = Metrics.counter(counter_ns, name)
+        distro = Metrics.distribution(distro_ns, name)
+        counter.inc(10)
+        counter.dec(3)
+        distro.update(10)
+        distro.update(2)
+        self.assertTrue(isinstance(counter, Metrics.DelegatingCounter))
+        self.assertTrue(isinstance(distro, Metrics.DelegatingDistribution))
 
-      container = MetricsEnvironment.current_container()
-      self.assertEqual(
-          container.counters[MetricName(counter_ns, name)].get_cumulative(),
-          7)
-      self.assertEqual(
-          container.distributions[MetricName(distro_ns, name)].get_cumulative(),
-          DistributionData(12, 2, 2, 10))
-    sampler.stop()
+        del distro
+        del counter
+
+        container = MetricsEnvironment.current_container()
+        self.assertEqual(
+            container.get_counter(
+                MetricName(counter_ns, name)).get_cumulative(),
+            7)
+        self.assertEqual(
+            container.get_distribution(
+                MetricName(distro_ns, name)).get_cumulative(),
+            DistributionData(12, 2, 2, 10))
+    finally:
+      sampler.stop()
 
 
 if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/runners/common.pxd b/sdks/python/apache_beam/runners/common.pxd
index 2ffe432..37e05bf 100644
--- a/sdks/python/apache_beam/runners/common.pxd
+++ b/sdks/python/apache_beam/runners/common.pxd
@@ -42,6 +42,8 @@
   cdef object key_arg_name
   cdef object restriction_provider
   cdef object restriction_provider_arg_name
+  cdef object watermark_estimator
+  cdef object watermark_estimator_arg_name
 
 
 cdef class DoFnSignature(object):
@@ -91,7 +93,9 @@
   cdef bint cache_globally_windowed_args
   cdef object process_method
   cdef bint is_splittable
-  cdef object restriction_tracker
+  cdef object threadsafe_restriction_tracker
+  cdef object watermark_estimator
+  cdef object watermark_estimator_param
   cdef WindowedValue current_windowed_value
   cdef bint is_key_param_required
 
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 3e14f3b..8632cfd 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
 # cython: profile=True
 
 """Worker operations executor.
@@ -167,6 +166,8 @@
     self.key_arg_name = None
     self.restriction_provider = None
     self.restriction_provider_arg_name = None
+    self.watermark_estimator = None
+    self.watermark_estimator_arg_name = None
 
     for kw, v in zip(self.args[-len(self.defaults):], self.defaults):
       if isinstance(v, core.DoFn.StateParam):
@@ -184,6 +185,9 @@
       elif isinstance(v, core.DoFn.RestrictionParam):
         self.restriction_provider = v.restriction_provider
         self.restriction_provider_arg_name = kw
+      elif isinstance(v, core.DoFn.WatermarkEstimatorParam):
+        self.watermark_estimator = v.watermark_estimator
+        self.watermark_estimator_arg_name = kw
 
   def invoke_timer_callback(self,
                             user_state_context,
@@ -264,6 +268,9 @@
   def get_restriction_provider(self):
     return self.process_method.restriction_provider
 
+  def get_watermark_estimator(self):
+    return self.process_method.watermark_estimator
+
   def _validate(self):
     self._validate_process()
     self._validate_bundle_method(self.start_bundle_method)
@@ -458,7 +465,11 @@
         signature.is_stateful_dofn())
     self.user_state_context = user_state_context
     self.is_splittable = signature.is_splittable_dofn()
-    self.restriction_tracker = None
+    self.watermark_estimator = self.signature.get_watermark_estimator()
+    self.watermark_estimator_param = (
+        self.signature.process_method.watermark_estimator_arg_name
+        if self.watermark_estimator else None)
+    self.threadsafe_restriction_tracker = None
     self.current_windowed_value = None
     self.bundle_finalizer_param = bundle_finalizer_param
     self.is_key_param_required = False
@@ -569,15 +580,24 @@
         raise ValueError(
             'A RestrictionTracker %r was provided but DoFn does not have a '
             'RestrictionTrackerParam defined' % restriction_tracker)
-      additional_kwargs[restriction_tracker_param] = restriction_tracker
+      from apache_beam.io import iobase
+      self.threadsafe_restriction_tracker = iobase.ThreadsafeRestrictionTracker(
+          restriction_tracker)
+      additional_kwargs[restriction_tracker_param] = (
+          iobase.RestrictionTrackerView(self.threadsafe_restriction_tracker))
+
+      if self.watermark_estimator:
+        # The watermark estimator needs to be reset for every element.
+        self.watermark_estimator.reset()
+        additional_kwargs[self.watermark_estimator_param] = (
+            self.watermark_estimator)
       try:
         self.current_windowed_value = windowed_value
-        self.restriction_tracker = restriction_tracker
         return self._invoke_process_per_window(
             windowed_value, additional_args, additional_kwargs,
             output_processor)
       finally:
-        self.restriction_tracker = None
+        self.threadsafe_restriction_tracker = None
         self.current_windowed_value = None
 
     elif self.has_windowed_inputs and len(windowed_value.windows) != 1:
@@ -664,24 +684,34 @@
           windowed_value, self.process_method(*args_for_process))
 
     if self.is_splittable:
-      deferred_status = self.restriction_tracker.deferred_status()
+      # TODO: Consider calling check_done right after SDF.process() finishes.
+      # To do this, we need to know that the currently invoked DoFn is
+      # ProcessSizedElementAndRestriction.
+      self.threadsafe_restriction_tracker.check_done()
+      deferred_status = self.threadsafe_restriction_tracker.deferred_status()
+      output_watermark = None
+      if self.watermark_estimator:
+        output_watermark = self.watermark_estimator.current_watermark()
       if deferred_status:
         deferred_restriction, deferred_watermark = deferred_status
         element = windowed_value.value
         size = self.signature.get_restriction_provider().restriction_size(
             element, deferred_restriction)
-        return (
+        return ((
             windowed_value.with_value(((element, deferred_restriction), size)),
-            deferred_watermark)
+            output_watermark), deferred_watermark)
 
   def try_split(self, fraction):
-    restriction_tracker = self.restriction_tracker
+    restriction_tracker = self.threadsafe_restriction_tracker
     current_windowed_value = self.current_windowed_value
     if restriction_tracker and current_windowed_value:
       # Temporary workaround for [BEAM-7473]: get current_watermark before
       # split, in case watermark gets advanced before getting split results.
       # In worst case, current_watermark is always stale, which is ok.
-      current_watermark = restriction_tracker.current_watermark()
+      if self.watermark_estimator:
+        current_watermark = self.watermark_estimator.current_watermark()
+      else:
+        current_watermark = None
       split = restriction_tracker.try_split(fraction)
       if split:
         primary, residual = split
@@ -690,15 +720,13 @@
         primary_size = restriction_provider.restriction_size(element, primary)
         residual_size = restriction_provider.restriction_size(element, residual)
         return (
-            (self.current_windowed_value.with_value(
-                ((element, primary), primary_size)),
-             None),
-            (self.current_windowed_value.with_value(
-                ((element, residual), residual_size)),
-             current_watermark))
+            ((self.current_windowed_value.with_value((
+                (element, primary), primary_size)), None), None),
+            ((self.current_windowed_value.with_value((
+                (element, residual), residual_size)), current_watermark), None))
 
   def current_element_progress(self):
-    restriction_tracker = self.restriction_tracker
+    restriction_tracker = self.threadsafe_restriction_tracker
     if restriction_tracker:
       return restriction_tracker.current_progress()
 
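With these changes the DoFn no longer sees the raw tracker: the invoker wraps it in a ThreadsafeRestrictionTracker and hands the process method only a RestrictionTrackerView, so claims from the processing thread and splits arriving from other threads are serialized, and the per-element watermark estimator supplies the output watermark on deferral and on split. A sketch of an SDF written against the new surface (MyProvider stands in for a real RestrictionProvider):

    import apache_beam as beam
    from apache_beam.transforms import core
    from apache_beam.utils import timestamp

    class EmitChars(beam.DoFn):
      def process(
          self,
          element,
          tracker=beam.DoFn.RestrictionParam(MyProvider()),
          estimator=beam.DoFn.WatermarkEstimatorParam(
              core.WatermarkEstimator())):
        cur = tracker.current_restriction().start  # replaces start_position()
        while tracker.try_claim(cur):
          estimator.set_watermark(timestamp.Timestamp(micros=cur))
          yield element[cur]
          cur += 1
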
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
index 4a5fef4..50b79e8 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
@@ -58,6 +58,7 @@
 from apache_beam.runners.runner import PipelineRunner
 from apache_beam.runners.runner import PipelineState
 from apache_beam.runners.runner import PValueCache
+from apache_beam.runners.utils import is_interactive
 from apache_beam.transforms import window
 from apache_beam.transforms.display import DisplayData
 from apache_beam.typehints import typehints
@@ -364,6 +365,16 @@
 
   def run_pipeline(self, pipeline, options):
     """Remotely executes entire pipeline or parts reachable from node."""
+    # Label the job goog-dataflow-notebook when it is started from a notebook.
+    _, is_in_notebook = is_interactive()
+    if is_in_notebook:
+      notebook_version = ('goog-dataflow-notebook=' +
+                          beam.version.__version__.replace('.', '_'))
+      if options.view_as(GoogleCloudOptions).labels:
+        options.view_as(GoogleCloudOptions).labels.append(notebook_version)
+      else:
+        options.view_as(GoogleCloudOptions).labels = [notebook_version]
+
     # Import here to avoid adding the dependency for local running scenarios.
     try:
       # pylint: disable=wrong-import-order, wrong-import-position
@@ -622,9 +633,15 @@
     debug_options = options.view_as(DebugOptions)
     use_fn_api = (debug_options.experiments and
                   'beam_fn_api' in debug_options.experiments)
+    use_streaming_engine = (
+        debug_options.experiments and
+        'enable_streaming_engine' in debug_options.experiments and
+        'enable_windmill_service' in debug_options.experiments)
+
     step = self._add_step(
         TransformNames.READ, transform_node.full_label, transform_node)
-    if standard_options.streaming and not use_fn_api:
+    if (standard_options.streaming and
+        (not use_fn_api or not use_streaming_engine)):
       step.add_property(PropertyNames.FORMAT, 'pubsub')
       step.add_property(PropertyNames.PUBSUB_SUBSCRIPTION, '_starting_signal/')
     else:
@@ -1195,6 +1212,52 @@
          PropertyNames.STEP_NAME: input_step.proto.name,
          PropertyNames.OUTPUT_NAME: input_step.get_output(input_tag)})
 
+  def run_TestStream(self, transform_node, options):
+    from apache_beam.portability.api import beam_runner_api_pb2
+    from apache_beam.testing.test_stream import ElementEvent
+    from apache_beam.testing.test_stream import ProcessingTimeEvent
+    from apache_beam.testing.test_stream import WatermarkEvent
+    standard_options = options.view_as(StandardOptions)
+    if not standard_options.streaming:
+      raise ValueError('TestStream is currently available for use '
+                       'only in streaming pipelines.')
+
+    transform = transform_node.transform
+    step = self._add_step(TransformNames.READ, transform_node.full_label,
+                          transform_node)
+    step.add_property(PropertyNames.FORMAT, 'test_stream')
+    test_stream_payload = beam_runner_api_pb2.TestStreamPayload()
+    # TestStream source doesn't do any decoding of elements,
+    # so we won't set test_stream_payload.coder_id.
+    output_coder = transform._infer_output_coder()  # pylint: disable=protected-access
+    for event in transform.events:
+      new_event = test_stream_payload.events.add()
+      if isinstance(event, ElementEvent):
+        for tv in event.timestamped_values:
+          element = new_event.element_event.elements.add()
+          element.encoded_element = output_coder.encode(tv.value)
+          element.timestamp = tv.timestamp.micros
+      elif isinstance(event, ProcessingTimeEvent):
+        new_event.processing_time_event.advance_duration = (
+            event.advance_by.micros)
+      elif isinstance(event, WatermarkEvent):
+        new_event.watermark_event.new_watermark = event.new_watermark.micros
+    serialized_payload = self.byte_array_to_json_string(
+        test_stream_payload.SerializeToString())
+    step.add_property(PropertyNames.SERIALIZED_TEST_STREAM, serialized_payload)
+
+    step.encoding = self._get_encoded_output_coder(transform_node)
+    step.add_property(PropertyNames.OUTPUT_INFO, [{
+        PropertyNames.USER_NAME:
+            ('%s.%s' % (transform_node.full_label, PropertyNames.OUT)),
+        PropertyNames.ENCODING: step.encoding,
+        PropertyNames.OUTPUT_NAME: PropertyNames.OUT
+    }])
+
+  # We must mark this method as not a test, or else nose would pick it up as
+  # a test because of its name.
+  run_TestStream.__test__ = False
+
   @classmethod
   def serialize_windowing_strategy(cls, windowing):
     from apache_beam.runners import pipeline_context
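
run_TestStream above translates each TestStream event into a TestStreamPayload entry, pre-encoding elements with the transform's inferred output coder (the service-side source does no decoding) and attaching the serialized payload to the step. From the user side, a pipeline along these lines becomes runnable on Dataflow once --streaming is set (streaming_options is assumed to carry that flag):

    import apache_beam as beam
    from apache_beam.testing.test_stream import TestStream

    events = (TestStream()
              .add_elements(['a', 'b'])
              .advance_watermark_to(10)
              .advance_processing_time(5)
              .add_elements(['c']))

    with beam.Pipeline(options=streaming_options) as p:
      _ = p | events | beam.Map(print)
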
diff --git a/sdks/python/apache_beam/runners/dataflow/internal/names.py b/sdks/python/apache_beam/runners/dataflow/internal/names.py
index eacce15..ec6f988 100644
--- a/sdks/python/apache_beam/runners/dataflow/internal/names.py
+++ b/sdks/python/apache_beam/runners/dataflow/internal/names.py
@@ -45,8 +45,8 @@
 
 # TODO(BEAM-5939): Remove these shared names once Dataflow worker is updated.
 PICKLED_MAIN_SESSION_FILE = 'pickled_main_session'
-STAGED_PIPELINE_FILENAME = "pipeline.pb"
-STAGED_PIPELINE_URL_METADATA_FIELD = "pipeline_url"
+STAGED_PIPELINE_FILENAME = 'pipeline.pb'
+STAGED_PIPELINE_URL_METADATA_FIELD = 'pipeline_url'
 
 # Package names for different distributions
 BEAM_PACKAGE_NAME = 'apache-beam'
@@ -61,7 +61,8 @@
 class TransformNames(object):
   """For internal use only; no backwards-compatibility guarantees.
 
-  Transform strings as they are expected in the CloudWorkflow protos."""
+  Transform strings as they are expected in the CloudWorkflow protos.
+  """
   COLLECTION_TO_SINGLETON = 'CollectionToSingleton'
   COMBINE = 'CombineValues'
   CREATE_PCOLLECTION = 'CreateCollection'
@@ -75,7 +76,8 @@
 class PropertyNames(object):
   """For internal use only; no backwards-compatibility guarantees.
 
-  Property strings as they are expected in the CloudWorkflow protos."""
+  Property strings as they are expected in the CloudWorkflow protos.
+  """
   BIGQUERY_CREATE_DISPOSITION = 'create_disposition'
   BIGQUERY_DATASET = 'dataset'
   BIGQUERY_EXPORT_FORMAT = 'bigquery_export_format'
@@ -113,6 +115,7 @@
   SERIALIZED_FN = 'serialized_fn'
   SHARD_NAME_TEMPLATE = 'shard_template'
   SOURCE_STEP_INPUT = 'custom_source_step_input'
+  SERIALIZED_TEST_STREAM = 'serialized_test_stream'
   STEP_NAME = 'step_name'
   USER_FN = 'user_fn'
   USER_NAME = 'user_name'
diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
index 946ef34..fd04d4c 100644
--- a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
+++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
@@ -51,6 +51,9 @@
   def create_tracker(self, restriction):
     return OffsetRestrictionTracker(restriction)
 
+  def restriction_size(self, element, restriction):
+    return restriction.size()
+
 
 class ReadFiles(DoFn):
 
@@ -63,12 +66,11 @@
       restriction_tracker=DoFn.RestrictionParam(ReadFilesProvider()),
       *args, **kwargs):
     file_name = element
-    assert isinstance(restriction_tracker, OffsetRestrictionTracker)
 
     with open(file_name, 'rb') as file:
-      pos = restriction_tracker.start_position()
-      if restriction_tracker.start_position() > 0:
-        file.seek(restriction_tracker.start_position() - 1)
+      pos = restriction_tracker.current_restriction().start
+      if restriction_tracker.current_restriction().start > 0:
+        file.seek(restriction_tracker.current_restriction().start - 1)
         line = file.readline()
         pos = pos - 1 + len(line)
 
@@ -104,6 +106,9 @@
   def split(self, element, restriction):
     return [restriction,]
 
+  def restriction_size(self, element, restriction):
+    return restriction.size()
+
 
 class ExpandStrings(DoFn):
 
@@ -118,10 +123,9 @@
     side.extend(side1)
     side.extend(side2)
     side.extend(side3)
-    assert isinstance(restriction_tracker, OffsetRestrictionTracker)
     side = list(side)
-    for i in range(restriction_tracker.start_position(),
-                   restriction_tracker.stop_position()):
+    for i in range(restriction_tracker.current_restriction().start,
+                   restriction_tracker.current_restriction().stop):
       if restriction_tracker.try_claim(i):
         if not side:
           yield (
diff --git a/sdks/python/apache_beam/runners/interactive/display/display_manager.py b/sdks/python/apache_beam/runners/interactive/display/display_manager.py
index 84025f9..c6ead9d 100644
--- a/sdks/python/apache_beam/runners/interactive/display/display_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/display/display_manager.py
@@ -32,17 +32,18 @@
 
 try:
   import IPython  # pylint: disable=import-error
+  from IPython import get_ipython  # pylint: disable=import-error
+  from IPython.display import display as ip_display  # pylint: disable=import-error
   # _display_progress defines how outputs are printed on the frontend.
-  _display_progress = IPython.display.display
+  _display_progress = ip_display
 
   def _formatter(string, pp, cycle):  # pylint: disable=unused-argument
     pp.text(string)
-  plain = get_ipython().display_formatter.formatters['text/plain']  # pylint: disable=undefined-variable
-  plain.for_type(str, _formatter)
+  if get_ipython():
+    plain = get_ipython().display_formatter.formatters['text/plain']  # pylint: disable=undefined-variable
+    plain.for_type(str, _formatter)
 
-# NameError is added here because get_ipython() throws "not defined" NameError
-# if not started with IPython.
-except (ImportError, NameError):
+except ImportError:
   IPython = None
   _display_progress = print
 
diff --git a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py
index 23605b5..aabf959 100644
--- a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py
+++ b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py
@@ -44,12 +44,12 @@
 
     Examples:
       graph = pipeline_graph.PipelineGraph(pipeline_proto)
-      graph.display_graph()
+      graph.get_dot()
 
       or
 
       graph = pipeline_graph.PipelineGraph(pipeline)
-      graph.display_graph()
+      graph.get_dot()
 
     Args:
       pipeline: (Pipeline proto) or (Pipeline) pipeline to be rendered.
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_environment.py b/sdks/python/apache_beam/runners/interactive/interactive_environment.py
index 2b0dc7a..2dbc102 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_environment.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_environment.py
@@ -30,6 +30,7 @@
 
 import apache_beam as beam
 from apache_beam.runners import runner
+from apache_beam.runners.utils import is_interactive
 
 _interactive_beam_env = None
 
@@ -93,17 +94,7 @@
                       'install apache-beam[interactive]` to install necessary '
                       'dependencies to enable all data visualization features.')
 
-    self._is_in_ipython = False
-    self._is_in_notebook = False
-    # Check if the runtime is within an interactive environment, i.e., ipython.
-    try:
-      from IPython import get_ipython  # pylint: disable=import-error
-      if get_ipython():
-        self._is_in_ipython = True
-        if 'IPKernelApp' in get_ipython().config:
-          self._is_in_notebook = True
-    except ImportError:
-      pass
+    self._is_in_ipython, self._is_in_notebook = is_interactive()
     if not self._is_in_ipython:
       logging.warning('You cannot use Interactive Beam features when you are '
                       'not in an interactive environment such as a Jupyter '
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py b/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py
index 342f400..6fa257b 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py
@@ -126,7 +126,6 @@
 
   def test_determine_terminal_state(self):
     for state in (runner.PipelineState.DONE,
-                  runner.PipelineState.STOPPED,
                   runner.PipelineState.FAILED,
                   runner.PipelineState.CANCELLED,
                   runner.PipelineState.UPDATED,
@@ -136,7 +135,7 @@
       self.assertTrue(ie.current_env().is_terminated(self._p))
     for state in (runner.PipelineState.UNKNOWN,
                   runner.PipelineState.STARTING,
-
+                  runner.PipelineState.STOPPED,
                   runner.PipelineState.RUNNING,
                   runner.PipelineState.DRAINING,
                   runner.PipelineState.PENDING,
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_runner.py b/sdks/python/apache_beam/runners/interactive/interactive_runner.py
index 94c0de7..7b6df70 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_runner.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_runner.py
@@ -48,7 +48,8 @@
                underlying_runner=None,
                cache_dir=None,
                cache_format='text',
-               render_option=None):
+               render_option=None,
+               skip_display=False):
     """Constructor of InteractiveRunner.
 
     Args:
@@ -58,12 +59,16 @@
           PCollection caches. Available options are 'text' and 'tfrecord'.
       render_option: (str) this parameter decides how the pipeline graph is
           rendered. See display.pipeline_graph_renderer for available options.
+      skip_display: (bool) whether to skip display operations when running the
+          pipeline. Useful for large pipelines when display output is not
+          needed.
     """
     self._underlying_runner = (underlying_runner
                                or direct_runner.DirectRunner())
     self._cache_manager = cache.FileBasedCacheManager(cache_dir, cache_format)
     self._renderer = pipeline_graph_renderer.get_renderer(render_option)
     self._in_session = False
+    self._skip_display = skip_display
 
   def is_fnapi_compatible(self):
     # TODO(BEAM-8436): return self._underlying_runner.is_fnapi_compatible()
@@ -140,15 +145,19 @@
         self._underlying_runner,
         options)
 
-    display = display_manager.DisplayManager(
-        pipeline_proto=pipeline_proto,
-        pipeline_analyzer=analyzer,
-        cache_manager=self._cache_manager,
-        pipeline_graph_renderer=self._renderer)
-    display.start_periodic_update()
+    if not self._skip_display:
+      display = display_manager.DisplayManager(
+          pipeline_proto=pipeline_proto,
+          pipeline_analyzer=analyzer,
+          cache_manager=self._cache_manager,
+          pipeline_graph_renderer=self._renderer)
+      display.start_periodic_update()
+
     result = pipeline_to_execute.run()
     result.wait_until_finish()
-    display.stop_periodic_update()
+
+    if not self._skip_display:
+      display.stop_periodic_update()
 
     return PipelineResult(result, self, self._analyzer.pipeline_info(),
                           self._cache_manager, pcolls_to_pcoll_id)
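
For headless or large runs, the new flag short-circuits all of the display machinery; a minimal sketch:

    from apache_beam.runners.interactive.interactive_runner import (
        InteractiveRunner)

    # No DisplayManager is created, started, or stopped during run_pipeline.
    runner = InteractiveRunner(skip_display=True)
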
diff --git a/sdks/python/apache_beam/runners/portability/abstract_job_service.py b/sdks/python/apache_beam/runners/portability/abstract_job_service.py
index 982fad1..5dd497a 100644
--- a/sdks/python/apache_beam/runners/portability/abstract_job_service.py
+++ b/sdks/python/apache_beam/runners/portability/abstract_job_service.py
@@ -23,13 +23,6 @@
 from apache_beam.portability.api import beam_job_api_pb2
 from apache_beam.portability.api import beam_job_api_pb2_grpc
 
-TERMINAL_STATES = [
-    beam_job_api_pb2.JobState.DONE,
-    beam_job_api_pb2.JobState.STOPPED,
-    beam_job_api_pb2.JobState.FAILED,
-    beam_job_api_pb2.JobState.CANCELLED,
-]
-
 
 class AbstractJobServiceServicer(beam_job_api_pb2_grpc.JobServiceServicer):
   """Manages one or more pipelines, possibly concurrently.
@@ -131,6 +124,11 @@
   def get_pipeline(self):
     return self._pipeline_proto
 
+  @staticmethod
+  def is_terminal_state(state):
+    from apache_beam.runners.portability import portable_runner
+    return state in portable_runner.TERMINAL_STATES
+
   def to_runner_api(self):
     return beam_job_api_pb2.JobInfo(
         job_id=self._job_id,
diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
index da2ebdb..377ceb7 100644
--- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
@@ -319,10 +319,10 @@
           line = f.readline()
       self.assertSetEqual(lines_actual, lines_expected)
 
-    def test_sdf_with_sdf_initiated_checkpointing(self):
+    def test_sdf_with_watermark_tracking(self):
       raise unittest.SkipTest("BEAM-2939")
 
-    def test_sdf_synthetic_source(self):
+    def test_sdf_with_sdf_initiated_checkpointing(self):
       raise unittest.SkipTest("BEAM-2939")
 
     def test_callbacks_with_exception(self):
diff --git a/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py b/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py
index b69da66..d2a890f 100644
--- a/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py
+++ b/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py
@@ -216,7 +216,7 @@
         'IN_PROGRESS': beam_job_api_pb2.JobState.RUNNING,
         'COMPLETED': beam_job_api_pb2.JobState.DONE,
     }.get(flink_status, beam_job_api_pb2.JobState.UNSPECIFIED)
-    if beam_state in abstract_job_service.TERMINAL_STATES:
+    if self.is_terminal_state(beam_state):
       self.delete_jar()
     return beam_state
 
@@ -224,7 +224,7 @@
     sleep_secs = 1.0
     current_state = self.get_state()
     yield current_state
-    while current_state not in abstract_job_service.TERMINAL_STATES:
+    while not self.is_terminal_state(current_state):
       sleep_secs = min(60, sleep_secs * 1.2)
       time.sleep(sleep_secs)
       previous_state, current_state = current_state, self.get_state()
@@ -233,7 +233,7 @@
 
   def get_message_stream(self):
     for state in self.get_state_stream():
-      if state in abstract_job_service.TERMINAL_STATES:
+      if self.is_terminal_state(state):
         response = self.get('v1/jobs/%s/exceptions' % self._flink_job_id)
         for ix, exc in enumerate(response['all-exceptions']):
           yield beam_job_api_pb2.JobMessage(
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index 2204a24..b7929cb 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -41,6 +41,7 @@
 from tenacity import stop_after_attempt
 
 import apache_beam as beam
+from apache_beam.io import iobase
 from apache_beam.io import restriction_trackers
 from apache_beam.metrics import monitoring_infos
 from apache_beam.metrics.execution import MetricKey
@@ -56,9 +57,11 @@
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
 from apache_beam.tools import utils
+from apache_beam.transforms import core
 from apache_beam.transforms import environments
 from apache_beam.transforms import userstate
 from apache_beam.transforms import window
+from apache_beam.utils import timestamp
 
 if statesampler.FAST_SAMPLER:
   DEFAULT_SAMPLING_PERIOD_MS = statesampler.DEFAULT_SAMPLING_PERIOD_MS
@@ -423,17 +426,14 @@
       assert_that(actual, is_buffered_correctly)
 
   def test_sdf(self):
-
     class ExpandingStringsDoFn(beam.DoFn):
       def process(
           self,
           element,
           restriction_tracker=beam.DoFn.RestrictionParam(
               ExpandStringsProvider())):
-        assert isinstance(
-            restriction_tracker,
-            restriction_trackers.OffsetRestrictionTracker), restriction_tracker
-        cur = restriction_tracker.start_position()
+        assert isinstance(restriction_tracker, iobase.RestrictionTrackerView)
+        cur = restriction_tracker.current_restriction().start
         while restriction_tracker.try_claim(cur):
           yield element[cur]
           cur += 1
@@ -446,6 +446,56 @@
           | beam.ParDo(ExpandingStringsDoFn()))
       assert_that(actual, equal_to(list(''.join(data))))
 
+  def test_sdf_with_check_done_failed(self):
+    class ExpandingStringsDoFn(beam.DoFn):
+      def process(
+          self,
+          element,
+          restriction_tracker=beam.DoFn.RestrictionParam(
+              ExpandStringsProvider())):
+        assert isinstance(restriction_tracker, iobase.RestrictionTrackerView)
+        cur = restriction_tracker.current_restriction().start
+        while restriction_tracker.try_claim(cur):
+          yield element[cur]
+          cur += 1
+          return
+    with self.assertRaises(Exception):
+      with self.create_pipeline() as p:
+        data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz']
+        _ = (
+            p
+            | beam.Create(data)
+            | beam.ParDo(ExpandingStringsDoFn()))
+
+  def test_sdf_with_watermark_tracking(self):
+
+    class ExpandingStringsDoFn(beam.DoFn):
+      def process(
+          self,
+          element,
+          restriction_tracker=beam.DoFn.RestrictionParam(
+              ExpandStringsProvider()),
+          watermark_estimator=beam.DoFn.WatermarkEstimatorParam(
+              core.WatermarkEstimator())):
+        cur = restriction_tracker.current_restriction().start
+        start = cur
+        while restriction_tracker.try_claim(cur):
+          watermark_estimator.set_watermark(timestamp.Timestamp(micros=cur))
+          assert watermark_estimator.current_watermark().micros == start
+          yield element[cur]
+          if cur % 2 == 1:
+            restriction_tracker.defer_remainder(timestamp.Duration(micros=5))
+            return
+          cur += 1
+
+    with self.create_pipeline() as p:
+      data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz']
+      actual = (
+          p
+          | beam.Create(data)
+          | beam.ParDo(ExpandingStringsDoFn()))
+      assert_that(actual, equal_to(list(''.join(data))))
+
   def test_sdf_with_sdf_initiated_checkpointing(self):
 
     counter = beam.metrics.Metrics.counter('ns', 'my_counter')
@@ -456,10 +506,8 @@
           element,
           restriction_tracker=beam.DoFn.RestrictionParam(
               ExpandStringsProvider())):
-        assert isinstance(
-            restriction_tracker,
-            restriction_trackers.OffsetRestrictionTracker), restriction_tracker
-        cur = restriction_tracker.start_position()
+        assert isinstance(restriction_tracker, iobase.RestrictionTrackerView)
+        cur = restriction_tracker.current_restriction().start
         while restriction_tracker.try_claim(cur):
           counter.inc()
           yield element[cur]
@@ -1123,6 +1171,9 @@
   def test_sdf_with_sdf_initiated_checkpointing(self):
     raise unittest.SkipTest("This test is for a single worker only.")
 
+  def test_sdf_with_watermark_tracking(self):
+    raise unittest.SkipTest("This test is for a single worker only.")
+
 
 class FnApiRunnerTestWithGrpcAndMultiWorkers(FnApiRunnerTest):
 
@@ -1142,6 +1193,9 @@
   def test_sdf_with_sdf_initiated_checkpointing(self):
     raise unittest.SkipTest("This test is for a single worker only.")
 
+  def test_sdf_with_watermark_tracking(self):
+    raise unittest.SkipTest("This test is for a single worker only.")
+
 
 class FnApiRunnerTestWithBundleRepeat(FnApiRunnerTest):
 
@@ -1172,6 +1226,9 @@
   def test_sdf_with_sdf_initiated_checkpointing(self):
     raise unittest.SkipTest("This test is for a single worker only.")
 
+  def test_sdf_with_watermark_tracking(self):
+    raise unittest.SkipTest("This test is for a single worker only.")
+
 
 class FnApiRunnerSplitTest(unittest.TestCase):
 
@@ -1340,7 +1397,7 @@
           element,
           restriction_tracker=beam.DoFn.RestrictionParam(EnumerateProvider())):
         to_emit = []
-        cur = restriction_tracker.start_position()
+        cur = restriction_tracker.current_restriction().start
         while restriction_tracker.try_claim(cur):
           to_emit.append((element, cur))
           element_counter.increment()
diff --git a/sdks/python/apache_beam/runners/portability/local_job_service.py b/sdks/python/apache_beam/runners/portability/local_job_service.py
index 4305810..6aad1af 100644
--- a/sdks/python/apache_beam/runners/portability/local_job_service.py
+++ b/sdks/python/apache_beam/runners/portability/local_job_service.py
@@ -195,7 +195,7 @@
     self._state = None
     self._state_queues = []
     self._log_queues = []
-    self.state = beam_job_api_pb2.JobState.STARTING
+    self.state = beam_job_api_pb2.JobState.STOPPED
     self.daemon = True
     self.result = None
 
@@ -220,10 +220,12 @@
     return self._artifact_staging_endpoint
 
   def run(self):
+    self.state = beam_job_api_pb2.JobState.STARTING
     self._run_thread = threading.Thread(target=self._run_job)
     self._run_thread.start()
 
   def _run_job(self):
+    self.state = beam_job_api_pb2.JobState.RUNNING
     with JobLogHandler(self._log_queues):
       try:
         result = fn_api_runner.FnApiRunner(
@@ -239,7 +241,7 @@
         raise
 
   def cancel(self):
-    if self.state not in abstract_job_service.TERMINAL_STATES:
+    if not self.is_terminal_state(self.state):
       self.state = beam_job_api_pb2.JobState.CANCELLING
       # TODO(robertwb): Actually cancel...
       self.state = beam_job_api_pb2.JobState.CANCELLED
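
Together with the constructor change above, a job now walks an explicit lifecycle instead of being born STARTING. The states a successful run reports, in order (sketch):

    # JobState.STOPPED    __init__: prepared but not yet started
    # JobState.STARTING   run(): worker thread being launched
    # JobState.RUNNING    _run_job(): pipeline executing
    # JobState.DONE       or FAILED / CANCELLED, per is_terminal_state()
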
@@ -253,7 +255,7 @@
     while True:
       current_state = state_queue.get(block=True)
       yield current_state
-      if current_state in abstract_job_service.TERMINAL_STATES:
+      if self.is_terminal_state(current_state):
         break
 
   def get_message_stream(self):
@@ -264,7 +266,7 @@
 
     current_state = self.state
     yield current_state
-    while current_state not in abstract_job_service.TERMINAL_STATES:
+    while not self.is_terminal_state(current_state):
       msg = log_queue.get(block=True)
       yield msg
       if isinstance(msg, int):
diff --git a/sdks/python/apache_beam/runners/portability/portable_runner.py b/sdks/python/apache_beam/runners/portability/portable_runner.py
index 2cffd47..6de9e5c 100644
--- a/sdks/python/apache_beam/runners/portability/portable_runner.py
+++ b/sdks/python/apache_beam/runners/portability/portable_runner.py
@@ -58,7 +58,7 @@
 
 TERMINAL_STATES = [
     beam_job_api_pb2.JobState.DONE,
-    beam_job_api_pb2.JobState.STOPPED,
+    beam_job_api_pb2.JobState.DRAINED,
     beam_job_api_pb2.JobState.FAILED,
     beam_job_api_pb2.JobState.CANCELLED,
 ]
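
STOPPED now means paused-and-resumable, so it leaves the terminal set while DRAINED joins it; the job services above consult this single list through is_terminal_state. A quick sanity check against the new definition:

    from apache_beam.portability.api import beam_job_api_pb2
    from apache_beam.runners.portability import portable_runner

    assert (beam_job_api_pb2.JobState.DRAINED
            in portable_runner.TERMINAL_STATES)
    assert (beam_job_api_pb2.JobState.STOPPED
            not in portable_runner.TERMINAL_STATES)
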
@@ -178,10 +178,9 @@
         del transform_proto.subtransforms[:]
 
     # Preemptively apply combiner lifting, until all runners support it.
-    # Also apply sdf expansion.
     # These optimizations commute and are idempotent.
     pre_optimize = options.view_as(DebugOptions).lookup_experiment(
-        'pre_optimize', 'lift_combiners,expand_sdf').lower()
+        'pre_optimize', 'lift_combiners').lower()
     if not options.view_as(StandardOptions).streaming:
       flink_known_urns = frozenset([
           common_urns.composites.RESHUFFLE.urn,
@@ -210,7 +209,7 @@
         phases = []
         for phase_name in pre_optimize.split(','):
           # For now, these are all we allow.
-          if phase_name in ('lift_combiners', 'expand_sdf'):
+          if phase_name in ('lift_combiners',):
             phases.append(getattr(fn_api_runner_transforms, phase_name))
           else:
             raise ValueError(
@@ -274,10 +273,12 @@
                  for k, v in all_options.items()
                  if v is not None}
 
+    prepare_request = beam_job_api_pb2.PrepareJobRequest(
+        job_name='job', pipeline=proto_pipeline,
+        pipeline_options=job_utils.dict_to_struct(p_options))
+    logging.debug('PrepareJobRequest: %s', prepare_request)
     prepare_response = job_service.Prepare(
-        beam_job_api_pb2.PrepareJobRequest(
-            job_name='job', pipeline=proto_pipeline,
-            pipeline_options=job_utils.dict_to_struct(p_options)),
+        prepare_request,
         timeout=portable_options.job_server_timeout)
     if prepare_response.artifact_staging_endpoint.url:
       stager = portable_stager.PortableStager(
diff --git a/sdks/python/apache_beam/runners/portability/portable_runner_test.py b/sdks/python/apache_beam/runners/portability/portable_runner_test.py
index 24c6b87..992aa08 100644
--- a/sdks/python/apache_beam/runners/portability/portable_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/portable_runner_test.py
@@ -45,7 +45,10 @@
 from apache_beam.runners.portability.portable_runner import PortableRunner
 from apache_beam.runners.worker import worker_pool_main
 from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
 from apache_beam.transforms import environments
+from apache_beam.transforms import userstate
 
 
 class PortableRunnerTest(fn_api_runner_test.FnApiRunnerTest):
@@ -188,6 +191,46 @@
   def test_metrics(self):
     self.skipTest('Metrics not supported.')
 
+  def test_pardo_state_with_custom_key_coder(self):
+    """Tests that state requests work correctly when the key coder is an
+    SDK-specific coder, i.e. a non-standard coder. This is additionally
+    enforced by Java's ProcessBundleDescriptorsTest and by Flink's
+    ExecutableStageDoFnOperator, which detects invalid encodings by checking
+    for the correct key group of the encoded key."""
+    index_state_spec = userstate.CombiningValueStateSpec('index', sum)
+
+    # Test parameters.
+    # Ensure a decent number of elements so that all partitions are served.
+    n = 200
+    duplicates = 1
+
+    split = n // (duplicates + 1)
+    inputs = [(i % split, str(i % split)) for i in range(0, n)]
+
+    # Use a DoFn which has to fall back to FastPrimitivesCoder because its
+    # output type cannot be inferred.
+    class Input(beam.DoFn):
+      def process(self, impulse):
+        for i in inputs:
+          yield i
+
+    class AddIndex(beam.DoFn):
+      def process(self, kv,
+                  index=beam.DoFn.StateParam(index_state_spec)):
+        k, v = kv
+        index.add(1)
+        yield k, v, index.read()
+
+    expected = [(i % split, str(i % split), i // split + 1)
+                for i in range(0, n)]
+
+    with self.create_pipeline() as p:
+      assert_that(p
+                  | beam.Impulse()
+                  | beam.ParDo(Input())
+                  | beam.ParDo(AddIndex()),
+                  equal_to(expected))
+
   # Inherits all other tests from fn_api_runner_test.FnApiRunnerTest
 
 
diff --git a/sdks/python/apache_beam/runners/portability/spark_runner.py b/sdks/python/apache_beam/runners/portability/spark_runner.py
new file mode 100644
index 0000000..ca03310
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/spark_runner.py
@@ -0,0 +1,84 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A runner for executing portable pipelines on Spark."""
+
+from __future__ import absolute_import
+from __future__ import print_function
+
+import re
+
+from apache_beam.options import pipeline_options
+from apache_beam.runners.portability import job_server
+from apache_beam.runners.portability import portable_runner
+
+# https://spark.apache.org/docs/latest/submitting-applications.html#master-urls
+LOCAL_MASTER_PATTERN = r'^local(\[.+\])?$'
+
+
+class SparkRunner(portable_runner.PortableRunner):
+  def run_pipeline(self, pipeline, options):
+    spark_options = options.view_as(SparkRunnerOptions)
+    portable_options = options.view_as(pipeline_options.PortableOptions)
+    if (re.match(LOCAL_MASTER_PATTERN, spark_options.spark_master_url)
+        and not portable_options.environment_type
+        and not portable_options.output_executable_path):
+      portable_options.environment_type = 'LOOPBACK'
+    return super(SparkRunner, self).run_pipeline(pipeline, options)
+
+  def default_job_server(self, options):
+    # TODO(BEAM-8139) submit a Spark jar to a cluster
+    return job_server.StopOnExitJobServer(SparkJarJobServer(options))
+
+
+class SparkRunnerOptions(pipeline_options.PipelineOptions):
+  @classmethod
+  def _add_argparse_args(cls, parser):
+    parser.add_argument('--spark_master_url',
+                        default='local[4]',
+                        help='Spark master URL (spark://HOST:PORT). '
+                             'Use "local" (single-threaded) or "local[*]" '
+                             '(multi-threaded) to start a local cluster for '
+                             'execution.')
+    parser.add_argument('--spark_job_server_jar',
+                        help='Path or URL to a Beam Spark jobserver jar.')
+    parser.add_argument('--artifacts_dir', default=None)
+
+
+class SparkJarJobServer(job_server.JavaJarJobServer):
+  def __init__(self, options):
+    super(SparkJarJobServer, self).__init__()
+    options = options.view_as(SparkRunnerOptions)
+    self._jar = options.spark_job_server_jar
+    self._master_url = options.spark_master_url
+    self._artifacts_dir = options.artifacts_dir
+
+  def path_to_jar(self):
+    if self._jar:
+      return self._jar
+    else:
+      return self.path_to_beam_jar('runners:spark:job-server:shadowJar')
+
+  def java_arguments(self, job_port, artifacts_dir):
+    return [
+        '--spark-master-url', self._master_url,
+        '--artifacts-dir', (self._artifacts_dir
+                            if self._artifacts_dir else artifacts_dir),
+        '--job-port', job_port,
+        '--artifact-port', 0,
+        '--expansion-port', 0
+    ]
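
For reference, a minimal sketch of exercising this new runner from user code is shown below. It assumes the Spark job-server jar has already been built (or that `--spark_job_server_jar` points at one); with the default `local[4]` master and no `environment_type` set, `run_pipeline` above falls back to LOOPBACK. The pipeline content is illustrative only.

```python
# Minimal usage sketch for the new SparkRunner (assumes the job-server jar
# is available locally).
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to

options = PipelineOptions(['--runner=SparkRunner'])
with beam.Pipeline(options=options) as p:
  squares = p | beam.Create([1, 2, 3]) | beam.Map(lambda x: x * x)
  assert_that(squares, equal_to([1, 4, 9]))
```
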
diff --git a/sdks/python/apache_beam/runners/runner.py b/sdks/python/apache_beam/runners/runner.py
index 1000480..fe9c492 100644
--- a/sdks/python/apache_beam/runners/runner.py
+++ b/sdks/python/apache_beam/runners/runner.py
@@ -38,6 +38,7 @@
     'apache_beam.runners.interactive.interactive_runner.InteractiveRunner',
     'apache_beam.runners.portability.flink_runner.FlinkRunner',
     'apache_beam.runners.portability.portable_runner.PortableRunner',
+    'apache_beam.runners.portability.spark_runner.SparkRunner',
     'apache_beam.runners.test.TestDirectRunner',
     'apache_beam.runners.test.TestDataflowRunner',
 )
@@ -328,7 +329,7 @@
 
   @classmethod
   def is_terminal(cls, state):
-    return state in [cls.STOPPED, cls.DONE, cls.FAILED, cls.CANCELLED,
+    return state in [cls.DONE, cls.FAILED, cls.CANCELLED,
                      cls.UPDATED, cls.DRAINED]
 
 
diff --git a/sdks/python/apache_beam/runners/utils.py b/sdks/python/apache_beam/runners/utils.py
new file mode 100644
index 0000000..8952423
--- /dev/null
+++ b/sdks/python/apache_beam/runners/utils.py
@@ -0,0 +1,47 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Common utility module shared by runners.
+
+For internal use only; no backwards-compatibility guarantees.
+"""
+from __future__ import absolute_import
+
+
+def is_interactive():
+  """Determines if current code execution is in interactive environment.
+
+  Returns:
+    is_in_ipython: (bool) True if the current code executes within an ipython
+        session.
+    is_in_notebook: (bool) True if the current code executes within an ipython
+        notebook.
+
+  If is_in_notebook is True, then is_in_ipython must also be True.
+  """
+  is_in_ipython = False
+  is_in_notebook = False
+  # Check if the runtime is within an interactive environment, i.e., ipython.
+  try:
+    from IPython import get_ipython  # pylint: disable=import-error
+    if get_ipython():
+      is_in_ipython = True
+      if 'IPKernelApp' in get_ipython().config:
+        is_in_notebook = True
+  except ImportError:
+    pass  # If dependencies are not available, then not interactive for sure.
+  return is_in_ipython, is_in_notebook
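
A hypothetical caller of the new helper might branch on the returned pair as in the sketch below; the mode strings are placeholders, not part of this module.

```python
# Hypothetical usage of is_interactive(); the branch bodies are placeholders.
from apache_beam.runners.utils import is_interactive

is_in_ipython, is_in_notebook = is_interactive()
if is_in_notebook:
  mode = 'notebook'  # implies is_in_ipython is also True
elif is_in_ipython:
  mode = 'ipython shell'
else:
  mode = 'non-interactive'
print('Running in %s mode' % mode)
```
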
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 8439c8f..b3440df 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -32,6 +32,7 @@
 from builtins import object
 
 from future.utils import itervalues
+from google.protobuf import duration_pb2
 from google.protobuf import timestamp_pb2
 
 import apache_beam as beam
@@ -704,8 +705,7 @@
               ) = split
               if element_primary:
                 split_response.primary_roots.add().CopyFrom(
-                    self.delayed_bundle_application(
-                        *element_primary).application)
+                    self.bundle_application(*element_primary))
               if element_residual:
                 split_response.residual_roots.add().CopyFrom(
                     self.delayed_bundle_application(*element_residual))
@@ -718,22 +718,39 @@
     return split_response
 
   def delayed_bundle_application(self, op, deferred_remainder):
-    transform_id, main_input_tag, main_input_coder, outputs = op.input_info
     # TODO(SDF): For non-root nodes, need main_input_coder + residual_coder.
-    element_and_restriction, watermark = deferred_remainder
-    if watermark:
-      proto_watermark = timestamp_pb2.Timestamp()
-      proto_watermark.FromMicroseconds(watermark.micros)
-      output_watermarks = {output: proto_watermark for output in outputs}
+    ((element_and_restriction, output_watermark),
+     deferred_watermark) = deferred_remainder
+    if deferred_watermark:
+      assert isinstance(deferred_watermark, timestamp.Duration)
+      proto_deferred_watermark = duration_pb2.Duration()
+      proto_deferred_watermark.FromMicroseconds(deferred_watermark.micros)
+    else:
+      proto_deferred_watermark = None
+    return beam_fn_api_pb2.DelayedBundleApplication(
+        requested_time_delay=proto_deferred_watermark,
+        application=self.construct_bundle_application(
+            op, output_watermark, element_and_restriction))
+
+  def bundle_application(self, op, primary):
+    ((element_and_restriction, output_watermark),
+     _) = primary
+    return self.construct_bundle_application(
+        op, output_watermark, element_and_restriction)
+
+  def construct_bundle_application(self, op, output_watermark, element):
+    transform_id, main_input_tag, main_input_coder, outputs = op.input_info
+    if output_watermark:
+      proto_output_watermark = timestamp_pb2.Timestamp()
+      proto_output_watermark.FromMicroseconds(output_watermark.micros)
+      output_watermarks = {output: proto_output_watermark for output in outputs}
     else:
       output_watermarks = None
-    return beam_fn_api_pb2.DelayedBundleApplication(
-        application=beam_fn_api_pb2.BundleApplication(
-            transform_id=transform_id,
-            input_id=main_input_tag,
-            output_watermarks=output_watermarks,
-            element=main_input_coder.get_impl().encode_nested(
-                element_and_restriction)))
+    return beam_fn_api_pb2.BundleApplication(
+        transform_id=transform_id,
+        input_id=main_input_tag,
+        output_watermarks=output_watermarks,
+        element=main_input_coder.get_impl().encode_nested(element))
 
   def metrics(self):
     # DEPRECATED
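
The deferred-watermark conversion performed in `delayed_bundle_application` above can be sketched in isolation; the five-second delay here is an arbitrary example value.

```python
# Isolated sketch of converting a relative Beam Duration into the proto
# Duration carried in DelayedBundleApplication.requested_time_delay.
from google.protobuf import duration_pb2

from apache_beam.utils import timestamp

deferred_watermark = timestamp.Duration(seconds=5)
proto_deferred_watermark = duration_pb2.Duration()
proto_deferred_watermark.FromMicroseconds(deferred_watermark.micros)
assert proto_deferred_watermark.seconds == 5
```
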
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 208fe75..74a3e99 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -182,17 +182,7 @@
     self._responses.put(response)
 
   def _request_register(self, request):
-
-    def task():
-      for process_bundle_descriptor in getattr(
-          request, request.WhichOneof('request')).process_bundle_descriptor:
-        self._fns[process_bundle_descriptor.id] = process_bundle_descriptor
-
-      return beam_fn_api_pb2.InstructionResponse(
-          instruction_id=request.instruction_id,
-          register=beam_fn_api_pb2.RegisterResponse())
-
-    self._execute(task, request)
+    self._request_execute(request)
 
   def _request_process_bundle(self, request):
 
@@ -241,6 +231,9 @@
     self._progress_thread_pool.submit(task)
 
   def _request_finalize_bundle(self, request):
+    self._request_execute(request)
+
+  def _request_execute(self, request):
 
     def task():
       # Get one available worker.
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_fast.pxd b/sdks/python/apache_beam/runners/worker/statesampler_fast.pxd
index 799bd0d..aebf9f6 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_fast.pxd
+++ b/sdks/python/apache_beam/runners/worker/statesampler_fast.pxd
@@ -43,6 +43,9 @@
 
   cdef int32_t current_state_index
 
+  cpdef ScopedState current_state(self)
+  cdef inline ScopedState current_state_c(self)
+
   cpdef _scoped_state(
       self, counter_name, name_context, output_counter, metrics_container)
 
@@ -56,7 +59,7 @@
   cdef readonly object name_context
   cdef readonly int64_t _nsecs
   cdef int32_t old_state_index
-  cdef readonly MetricsContainer _metrics_container
+  cdef readonly MetricsContainer metrics_container
 
   cpdef __enter__(self)
 
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx b/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
index 325ec99..8d2346a 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
+++ b/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
@@ -159,8 +159,12 @@
       (<ScopedState>state)._nsecs = 0
     self.started = self.finished = False
 
-  def current_state(self):
-    return self.scoped_states_by_index[self.current_state_index]
+  cpdef ScopedState current_state(self):
+    return self.current_state_c()
+
+  cdef inline ScopedState current_state_c(self):
+    # Faster than cpdef due to self always being a Python subclass.
+    return <ScopedState>self.scoped_states_by_index[self.current_state_index]
 
   cpdef _scoped_state(self, counter_name, name_context, output_counter,
                       metrics_container):
@@ -189,6 +193,11 @@
     pythread.PyThread_release_lock(self.lock)
     return scoped_state
 
+  def update_metric(self, typed_metric_name, value):
+    # Each of these is a cdef lookup.
+    self.current_state_c().metrics_container.get_metric_cell(
+        typed_metric_name).update(value)
+
 
 cdef class ScopedState(object):
   """Context manager class managing transitions for a given sampler state."""
@@ -205,7 +214,7 @@
     self.name_context = step_name_context
     self.state_index = state_index
     self.counter = counter
-    self._metrics_container = metrics_container
+    self.metrics_container = metrics_container
 
   @property
   def nsecs(self):
@@ -232,7 +241,3 @@
     self.sampler.current_state_index = self.old_state_index
     self.sampler.state_transition_count += 1
     pythread.PyThread_release_lock(self.sampler.lock)
-
-  @property
-  def metrics_container(self):
-    return self._metrics_container
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_slow.py b/sdks/python/apache_beam/runners/worker/statesampler_slow.py
index 0091828..fb2592c 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_slow.py
+++ b/sdks/python/apache_beam/runners/worker/statesampler_slow.py
@@ -50,6 +50,10 @@
     return ScopedState(
         self, counter_name, name_context, output_counter, metrics_container)
 
+  def update_metric(self, typed_metric_name, value):
+    self.current_state().metrics_container.get_metric_cell(
+        typed_metric_name).update(value)
+
   def _enter_state(self, state):
     self.state_transition_count += 1
     self._state_stack.append(state)
diff --git a/sdks/python/apache_beam/testing/data/trigger_transcripts.yaml b/sdks/python/apache_beam/testing/data/trigger_transcripts.yaml
index 0e01ad9..cac0c74 100644
--- a/sdks/python/apache_beam/testing/data/trigger_transcripts.yaml
+++ b/sdks/python/apache_beam/testing/data/trigger_transcripts.yaml
@@ -139,6 +139,32 @@
            final: false, index: 3, nonspeculative_index: 1}
 
 ---
+name: discarding_early_fixed
+window_fn: FixedWindows(10)
+trigger_fn: AfterWatermark(early=AfterCount(2))
+timestamp_combiner: OUTPUT_AT_EOW
+accumulation_mode: discarding
+transcript:
+- input: [1, 2, 3]
+- expect:
+  - {window: [0, 9], values: [1, 2, 3], timestamp: 9, early: true, index: 0}
+- input: [4]    # no output
+- input: [14]   # no output
+- input: [5]
+- expect:
+  - {window: [0, 9], values: [4, 5], timestamp: 9, early: true, index: 1}
+- input: [18]
+- expect:
+  - {window: [10, 19], values: [14, 18], timestamp: 19, early: true, index: 0}
+- input: [6]
+- watermark: 100
+- expect:
+  - {window: [0, 9], values: [6], timestamp: 9, early: false, late: false,
+     final: true, index: 2, nonspeculative_index: 0}
+  - {window: [10, 19], values: [], timestamp: 19, early: false, late: false,
+     final: true, index: 1, nonspeculative_index: 0}
+
+---
 name: garbage_collection
 broken_on:
   - SwitchingDirectRunner  # claims pipeline stall
diff --git a/sdks/python/apache_beam/testing/pipeline_verifiers_test.py b/sdks/python/apache_beam/testing/pipeline_verifiers_test.py
index 16ffee9..ec17ef6 100644
--- a/sdks/python/apache_beam/testing/pipeline_verifiers_test.py
+++ b/sdks/python/apache_beam/testing/pipeline_verifiers_test.py
@@ -66,12 +66,11 @@
 
   def test_pipeline_state_matcher_fails(self):
     """Test PipelineStateMatcher fails when using default expected state
-    and job actually finished in CANCELLED/DRAINED/FAILED/STOPPED/UNKNOWN
+    and job actually finished in CANCELLED/DRAINED/FAILED/UNKNOWN
     """
     failed_state = [PipelineState.CANCELLED,
                     PipelineState.DRAINED,
                     PipelineState.FAILED,
-                    PipelineState.STOPPED,
                     PipelineState.UNKNOWN]
 
     for state in failed_state:
diff --git a/sdks/python/apache_beam/testing/synthetic_pipeline.py b/sdks/python/apache_beam/testing/synthetic_pipeline.py
index 50740ba..fbef112 100644
--- a/sdks/python/apache_beam/testing/synthetic_pipeline.py
+++ b/sdks/python/apache_beam/testing/synthetic_pipeline.py
@@ -523,7 +523,7 @@
       element,
       restriction_tracker=beam.DoFn.RestrictionParam(
           SyntheticSDFSourceRestrictionProvider())):
-    cur = restriction_tracker.start_position()
+    cur = restriction_tracker.current_restriction().start
     while restriction_tracker.try_claim(cur):
       r = np.random.RandomState(cur)
       time.sleep(element['sleep_per_input_record_sec'])
diff --git a/sdks/python/apache_beam/testing/test_stream.py b/sdks/python/apache_beam/testing/test_stream.py
index 02a8607..9d9284c 100644
--- a/sdks/python/apache_beam/testing/test_stream.py
+++ b/sdks/python/apache_beam/testing/test_stream.py
@@ -31,6 +31,8 @@
 from apache_beam import coders
 from apache_beam import core
 from apache_beam import pvalue
+from apache_beam.portability import common_urns
+from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.transforms import PTransform
 from apache_beam.transforms import window
 from apache_beam.transforms.window import TimestampedValue
@@ -66,6 +68,28 @@
     # TODO(BEAM-5949): Needed for Python 2 compatibility.
     return not self == other
 
+  @abstractmethod
+  def to_runner_api(self, element_coder):
+    raise NotImplementedError
+
+  @staticmethod
+  def from_runner_api(proto, element_coder):
+    if proto.HasField('element_event'):
+      return ElementEvent(
+          [TimestampedValue(
+              element_coder.decode(tv.encoded_element),
+              timestamp.Timestamp(micros=1000 * tv.timestamp))
+           for tv in proto.element_event.elements])
+    elif proto.HasField('watermark_event'):
+      return WatermarkEvent(timestamp.Timestamp(
+          micros=1000 * proto.watermark_event.new_watermark))
+    elif proto.HasField('processing_time_event'):
+      return ProcessingTimeEvent(timestamp.Duration(
+          micros=1000 * proto.processing_time_event.advance_duration))
+    else:
+      raise ValueError(
+          'Unknown TestStream Event type: %s' % proto.WhichOneof('event'))
+
 
 class ElementEvent(Event):
   """Element-producing test stream event."""
@@ -82,6 +106,15 @@
   def __lt__(self, other):
     return self.timestamped_values < other.timestamped_values
 
+  def to_runner_api(self, element_coder):
+    return beam_runner_api_pb2.TestStreamPayload.Event(
+        element_event=beam_runner_api_pb2.TestStreamPayload.Event.AddElements(
+            elements=[
+                beam_runner_api_pb2.TestStreamPayload.TimestampedElement(
+                    encoded_element=element_coder.encode(tv.value),
+                    timestamp=tv.timestamp.micros // 1000)
+                for tv in self.timestamped_values]))
+
 
 class WatermarkEvent(Event):
   """Watermark-advancing test stream event."""
@@ -98,6 +131,11 @@
   def __lt__(self, other):
     return self.new_watermark < other.new_watermark
 
+  def to_runner_api(self, unused_element_coder):
+    advance = beam_runner_api_pb2.TestStreamPayload.Event.AdvanceWatermark(
+        new_watermark=self.new_watermark.micros // 1000)
+    return beam_runner_api_pb2.TestStreamPayload.Event(watermark_event=advance)
+
 
 class ProcessingTimeEvent(Event):
   """Processing time-advancing test stream event."""
@@ -114,6 +152,12 @@
   def __lt__(self, other):
     return self.advance_by < other.advance_by
 
+  def to_runner_api(self, unused_element_coder):
+    advance = (
+        beam_runner_api_pb2.TestStreamPayload.Event.AdvanceProcessingTime(
+            advance_duration=self.advance_by.micros // 1000))
+    return beam_runner_api_pb2.TestStreamPayload.Event(
+        processing_time_event=advance)
+
 
 class TestStream(PTransform):
   """Test stream that generates events on an unbounded PCollection of elements.
@@ -123,11 +167,12 @@
   output.
   """
 
-  def __init__(self, coder=coders.FastPrimitivesCoder):
+  def __init__(self, coder=coders.FastPrimitivesCoder(), events=()):
+    super(TestStream, self).__init__()
     assert coder is not None
     self.coder = coder
     self.current_watermark = timestamp.MIN_TIMESTAMP
-    self.events = []
+    self.events = list(events)
 
   def get_windowing(self, unused_inputs):
     return core.Windowing(window.GlobalWindows())
@@ -206,3 +251,19 @@
     """
     self._add(ProcessingTimeEvent(advance_by))
     return self
+
+  def to_runner_api_parameter(self, context):
+    return (
+        common_urns.primitives.TEST_STREAM.urn,
+        beam_runner_api_pb2.TestStreamPayload(
+            coder_id=context.coders.get_id(self.coder),
+            events=[e.to_runner_api(self.coder) for e in self.events]))
+
+  @PTransform.register_urn(
+      common_urns.primitives.TEST_STREAM.urn,
+      beam_runner_api_pb2.TestStreamPayload)
+  def from_runner_api_parameter(payload, context):
+    coder = context.coders.get_by_id(payload.coder_id)
+    return TestStream(
+        coder=coder,
+        events=[Event.from_runner_api(e, coder) for e in payload.events])
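
A round trip through the new hooks can be sketched as follows; equality of events is assumed from the `__eq__`/`__ne__` definitions elsewhere in this module.

```python
# Sketch: serialize one ElementEvent to its runner API proto and decode it
# back. Both directions use millisecond granularity, so the timestamp
# survives the round trip.
from apache_beam.coders import FastPrimitivesCoder
from apache_beam.testing.test_stream import ElementEvent
from apache_beam.testing.test_stream import Event
from apache_beam.transforms.window import TimestampedValue

coder = FastPrimitivesCoder()
event = ElementEvent([TimestampedValue('a', 1)])
proto = event.to_runner_api(coder)
assert proto.HasField('element_event')
restored = Event.from_runner_api(proto, coder)
assert restored == event
```
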
diff --git a/sdks/python/apache_beam/testing/test_stream_test.py b/sdks/python/apache_beam/testing/test_stream_test.py
index 3297f63..c8bc9ff 100644
--- a/sdks/python/apache_beam/testing/test_stream_test.py
+++ b/sdks/python/apache_beam/testing/test_stream_test.py
@@ -101,9 +101,11 @@
                    .advance_processing_time(10)
                    .advance_watermark_to(300)
                    .add_elements([TimestampedValue('late', 12)])
-                   .add_elements([TimestampedValue('last', 310)]))
+                   .add_elements([TimestampedValue('last', 310)])
+                   .advance_watermark_to_infinity())
 
     class RecordFn(beam.DoFn):
+
       def process(self, element=beam.DoFn.ElementParam,
                   timestamp=beam.DoFn.TimestampParam):
         yield (element, timestamp)
@@ -135,7 +137,8 @@
                    .advance_processing_time(10)
                    .advance_watermark_to(300)
                    .add_elements([TimestampedValue('late', 12)])
-                   .add_elements([TimestampedValue('last', 310)]))
+                   .add_elements([TimestampedValue('last', 310)])
+                   .advance_watermark_to_infinity())
 
     options = PipelineOptions()
     options.view_as(StandardOptions).streaming = True
@@ -175,7 +178,8 @@
     test_stream = (TestStream()
                    .advance_watermark_to(10)
                    .add_elements(['a'])
-                   .advance_watermark_to(20))
+                   .advance_watermark_to(20)
+                   .advance_watermark_to_infinity())
 
     options = PipelineOptions()
     options.view_as(StandardOptions).streaming = True
@@ -217,7 +221,8 @@
     test_stream = (TestStream()
                    .advance_watermark_to(10)
                    .add_elements(['a'])
-                   .advance_processing_time(5.1))
+                   .advance_processing_time(5.1)
+                   .advance_watermark_to_infinity())
 
     options = PipelineOptions()
     options.view_as(StandardOptions).streaming = True
@@ -255,12 +260,14 @@
     main_stream = (p
                    | 'main TestStream' >> TestStream()
                    .advance_watermark_to(10)
-                   .add_elements(['e']))
+                   .add_elements(['e'])
+                   .advance_watermark_to_infinity())
     side = (p
             | beam.Create([2, 1, 4])
             | beam.Map(lambda t: window.TimestampedValue(t, t)))
 
     class RecordFn(beam.DoFn):
+
       def process(self,
                   elm=beam.DoFn.ElementParam,
                   ts=beam.DoFn.TimestampParam,
@@ -316,6 +323,7 @@
                    .add_elements(['a'])
                    .advance_watermark_to(4)
                    .add_elements(['b'])
+                   .advance_watermark_to_infinity()
                    | 'main window' >> beam.WindowInto(window.FixedWindows(1)))
     side = (p
             | beam.Create([2, 1, 4])
@@ -323,6 +331,7 @@
             | beam.WindowInto(window.FixedWindows(2)))
 
     class RecordFn(beam.DoFn):
+
       def process(self,
                   elm=beam.DoFn.ElementParam,
                   ts=beam.DoFn.TimestampParam,
@@ -334,8 +343,8 @@
 
     # assert per window
     expected_window_to_elements = {
-        window.IntervalWindow(2, 3):[('a', Timestamp(2), [2])],
-        window.IntervalWindow(4, 5):[('b', Timestamp(4), [4])]
+        window.IntervalWindow(2, 3): [('a', Timestamp(2), [2])],
+        window.IntervalWindow(4, 5): [('b', Timestamp(4), [4])]
     }
     assert_that(
         records,
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 148caae..06fd201 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -63,6 +63,7 @@
 from apache_beam.typehints.decorators import get_type_hints
 from apache_beam.typehints.trivial_inference import element_type
 from apache_beam.typehints.typehints import is_consistent_with
+from apache_beam.utils import timestamp
 from apache_beam.utils import urns
 
 try:
@@ -91,7 +92,8 @@
     'Flatten',
     'Create',
     'Impulse',
-    'RestrictionProvider'
+    'RestrictionProvider',
+    'WatermarkEstimator'
     ]
 
 # Type variables
@@ -242,6 +244,8 @@
   def create_tracker(self, restriction):
     """Produces a new ``RestrictionTracker`` for the given restriction.
 
+    This API is required to be implemented.
+
     Args:
       restriction: an object that defines a restriction as identified by a
         Splittable ``DoFn`` that utilizes the current ``RestrictionProvider``.
@@ -252,7 +256,10 @@
     raise NotImplementedError
 
   def initial_restriction(self, element):
-    """Produces an initial restriction for the given element."""
+    """Produces an initial restriction for the given element.
+
+    This API is required to be implemented.
+    """
     raise NotImplementedError
 
   def split(self, element, restriction):
@@ -262,6 +269,9 @@
     reading input element for each of the returned restrictions should be the
     same as the total set of elements produced by reading the input element for
     the input restriction.
+
+    This API is optional if ``split_and_size`` has been implemented.
+
     """
     yield restriction
 
@@ -281,11 +291,16 @@
 
     By default, asks a newly-created restriction tracker for the default size
     of the restriction.
+
+    This API is required to be implemented.
     """
-    return self.create_tracker(restriction).default_size()
+    raise NotImplementedError
 
   def split_and_size(self, element, restriction):
     """Like split, but also does sizing, returning (restriction, size) pairs.
+
+    This API is optional if ``split`` and ``restriction_size`` have been
+    implemented.
     """
     for part in self.split(element, restriction):
       yield part, self.restriction_size(element, part)
@@ -379,6 +394,43 @@
     return None
 
 
+class WatermarkEstimator(object):
+  """A WatermarkEstimator which is used for tracking output_watermark in a
+  DoFn.process(), typically tracking per <element, restriction> pair in SDF in
+  streaming.
+
+  There are 3 APIs in this class: set_watermark, current_watermark and reset
+  with default implementations.
+
+  TODO(BEAM-8537): Create WatermarkEstimatorProvider to support different types.
+  """
+  def __init__(self):
+    self._watermark = None
+
+  def set_watermark(self, watermark):
+    """Update tracking output_watermark with latest output_watermark.
+    This function is called inside an SDF.Process() to track the watermark of
+    output element.
+
+    Args:
+      watermark: the `timestamp.Timestamp` of current output element.
+    """
+    if not isinstance(watermark, timestamp.Timestamp):
+      raise ValueError('watermark should be an instance of timestamp.Timestamp')
+    if self._watermark is None:
+      self._watermark = watermark
+    else:
+      self._watermark = min(self._watermark, watermark)
+
+  def current_watermark(self):
+    """Get current output_watermark. This function is called by system."""
+    return self._watermark
+
+  def reset(self):
+    """ Reset current tracking watermark to None."""
+    self._watermark = None
+
+
 class _DoFnParam(object):
   """DoFn parameter."""
 
@@ -459,6 +511,17 @@
     del self._callbacks[:]
 
 
+class _WatermarkEstimatorParam(_DoFnParam):
+  """WatermarkEstomator DoFn parameter."""
+
+  def __init__(self, watermark_estimator):
+    if not isinstance(watermark_estimator, WatermarkEstimator):
+      raise ValueError('DoFn.WatermarkEstimatorParam expected a '
+                       'WatermarkEstimator object.')
+    self.watermark_estimator = watermark_estimator
+    self.param_id = 'WatermarkEstimator'
+
+
 class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
   """A function object used by a transform with custom processing.
 
@@ -477,7 +540,7 @@
   TimestampParam = _DoFnParam('TimestampParam')
   WindowParam = _DoFnParam('WindowParam')
   PaneInfoParam = _DoFnParam('PaneInfoParam')
-  WatermarkReporterParam = _DoFnParam('WatermarkReporterParam')
+  WatermarkEstimatorParam = _WatermarkEstimatorParam
   BundleFinalizerParam = _BundleFinalizerParam
   KeyParam = _DoFnParam('KeyParam')
 
@@ -489,7 +552,7 @@
   TimerParam = _TimerDoFnParam
 
   DoFnProcessParams = [ElementParam, SideInputParam, TimestampParam,
-                       WindowParam, WatermarkReporterParam, PaneInfoParam,
+                       WindowParam, WatermarkEstimatorParam, PaneInfoParam,
                        BundleFinalizerParam, KeyParam, StateParam, TimerParam]
 
   RestrictionParam = _RestrictionDoFnParam
@@ -522,7 +585,7 @@
     ``DoFn.RestrictionParam``: an ``iobase.RestrictionTracker`` will be
     provided here to allow treatment as a Splittable ``DoFn``. The restriction
     tracker will be derived from the restriction provider in the parameter.
-    ``DoFn.WatermarkReporterParam``: a function that can be used to report
+    ``DoFn.WatermarkEstimatorParam``: a ``WatermarkEstimator`` that tracks the
     output watermark of Splittable ``DoFn`` implementations.
 
     Args:
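
A hedged sketch of how the renamed parameter might be declared in a DoFn signature; only the parameter wiring comes from this change, and the runner is assumed to substitute the estimator at execution time.

```python
# Illustrative DoFn using DoFn.WatermarkEstimatorParam. The body is a
# placeholder; the runner is assumed to bind the estimator at runtime.
import apache_beam as beam
from apache_beam.transforms.core import WatermarkEstimator
from apache_beam.utils.timestamp import Timestamp

class TrackingDoFn(beam.DoFn):
  def process(
      self,
      element,
      watermark_estimator=beam.DoFn.WatermarkEstimatorParam(
          WatermarkEstimator())):
    # set_watermark() retains the minimum timestamp observed so far.
    watermark_estimator.set_watermark(Timestamp(seconds=element))
    yield element
```
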
diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py
new file mode 100644
index 0000000..1a27bd2
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/core_test.py
@@ -0,0 +1,54 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Unit tests for core module."""
+
+from __future__ import absolute_import
+
+import unittest
+
+from apache_beam.transforms.core import WatermarkEstimator
+from apache_beam.utils.timestamp import Timestamp
+
+
+class WatermarkEstimatorTest(unittest.TestCase):
+
+  def test_set_watermark(self):
+    watermark_estimator = WatermarkEstimator()
+    self.assertEqual(watermark_estimator.current_watermark(), None)
+    # set_watermark should only accept timestamp.Timestamp.
+    with self.assertRaises(ValueError):
+      watermark_estimator.set_watermark(0)
+
+    # watermark_estimator should always keep minimal timestamp.
+    watermark_estimator.set_watermark(Timestamp(100))
+    self.assertEqual(watermark_estimator.current_watermark(), 100)
+    watermark_estimator.set_watermark(Timestamp(150))
+    self.assertEqual(watermark_estimator.current_watermark(), 100)
+    watermark_estimator.set_watermark(Timestamp(50))
+    self.assertEqual(watermark_estimator.current_watermark(), 50)
+
+  def test_reset(self):
+    watermark_estimator = WatermarkEstimator()
+    watermark_estimator.set_watermark(Timestamp(100))
+    self.assertEqual(watermark_estimator.current_watermark(), 100)
+    watermark_estimator.reset()
+    self.assertEqual(watermark_estimator.current_watermark(), None)
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/sdks/python/apache_beam/transforms/trigger_test.py b/sdks/python/apache_beam/transforms/trigger_test.py
index dbc4bcd..22ecda3 100644
--- a/sdks/python/apache_beam/transforms/trigger_test.py
+++ b/sdks/python/apache_beam/transforms/trigger_test.py
@@ -20,6 +20,7 @@
 from __future__ import absolute_import
 
 import collections
+import json
 import os.path
 import pickle
 import unittest
@@ -31,6 +32,7 @@
 import yaml
 
 import apache_beam as beam
+from apache_beam import coders
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import StandardOptions
 from apache_beam.runners import pipeline_context
@@ -502,7 +504,10 @@
     while hasattr(cls, unique_name):
       counter += 1
       unique_name = 'test_%s_%d' % (name, counter)
-    setattr(cls, unique_name, lambda self: self._run_log_test(spec))
+    test_method = lambda self: self._run_log_test(spec)
+    test_method.__name__ = unique_name
+    test_method.__test__ = True
+    setattr(cls, unique_name, test_method)
 
   # We must prepend an underscore to this name so that the open-source unittest
   # runner does not execute this method directly as a test.
@@ -606,24 +611,25 @@
         window_fn, trigger_fn, accumulation_mode, timestamp_combiner,
         transcript, spec)
 
-  def _windowed_value_info(self, windowed_value):
-    # Currently some runners operate at the millisecond level, and some at the
-    # microsecond level.  Trigger transcript timestamps are expressed as
-    # integral units of the finest granularity, whatever that may be.
-    # In these tests we interpret them as integral seconds and then truncate
-    # the results to integral seconds to allow for portability across
-    # different sub-second resolutions.
-    window, = windowed_value.windows
-    return {
-        'window': [int(window.start), int(window.max_timestamp())],
-        'values': sorted(windowed_value.value),
-        'timestamp': int(windowed_value.timestamp),
-        'index': windowed_value.pane_info.index,
-        'nonspeculative_index': windowed_value.pane_info.nonspeculative_index,
-        'early': windowed_value.pane_info.timing == PaneInfoTiming.EARLY,
-        'late': windowed_value.pane_info.timing == PaneInfoTiming.LATE,
-        'final': windowed_value.pane_info.is_last,
-    }
+
+def _windowed_value_info(windowed_value):
+  # Currently some runners operate at the millisecond level, and some at the
+  # microsecond level.  Trigger transcript timestamps are expressed as
+  # integral units of the finest granularity, whatever that may be.
+  # In these tests we interpret them as integral seconds and then truncate
+  # the results to integral seconds to allow for portability across
+  # different sub-second resolutions.
+  window, = windowed_value.windows
+  return {
+      'window': [int(window.start), int(window.max_timestamp())],
+      'values': sorted(windowed_value.value),
+      'timestamp': int(windowed_value.timestamp),
+      'index': windowed_value.pane_info.index,
+      'nonspeculative_index': windowed_value.pane_info.nonspeculative_index,
+      'early': windowed_value.pane_info.timing == PaneInfoTiming.EARLY,
+      'late': windowed_value.pane_info.timing == PaneInfoTiming.LATE,
+      'final': windowed_value.pane_info.is_last,
+  }
 
 
 class TriggerDriverTranscriptTest(TranscriptTest):
@@ -645,7 +651,7 @@
         for timer_window, (name, time_domain, t_timestamp) in to_fire:
           for wvalue in driver.process_timer(
               timer_window, name, time_domain, t_timestamp, state):
-            output.append(self._windowed_value_info(wvalue))
+            output.append(_windowed_value_info(wvalue))
         to_fire = state.get_and_clear_timers(watermark)
 
     for action, params in transcript:
@@ -661,7 +667,7 @@
             WindowedValue(t, t, window_fn.assign(WindowFn.AssignContext(t, t)))
             for t in params]
         output = [
-            self._windowed_value_info(wv)
+            _windowed_value_info(wv)
             for wv in driver.process_elements(state, bundle, watermark)]
         fire_timers()
 
@@ -690,7 +696,7 @@
     self.assertEqual([], output, msg='Unexpected output: %s' % output)
 
 
-class TestStreamTranscriptTest(TranscriptTest):
+class BaseTestStreamTranscriptTest(TranscriptTest):
   """A suite of TestStream-based tests based on trigger transcript entries.
   """
 
@@ -702,14 +708,17 @@
     if runner_name in spec.get('broken_on', ()):
       self.skipTest('Known to be broken on %s' % runner_name)
 
-    test_stream = TestStream()
+    # Elements are encoded as JSON strings to allow other languages to
+    # decode elements while executing the test stream.
+    # TODO(BEAM-8600): Eliminate these gymnastics.
+    test_stream = TestStream(coder=coders.StrUtf8Coder()).with_output_types(str)
     for action, params in transcript:
       if action == 'expect':
-        test_stream.add_elements([('expect', params)])
+        test_stream.add_elements([json.dumps(('expect', params))])
       else:
-        test_stream.add_elements([('expect', [])])
+        test_stream.add_elements([json.dumps(('expect', []))])
         if action == 'input':
-          test_stream.add_elements([('input', e) for e in params])
+          test_stream.add_elements([json.dumps(('input', e)) for e in params])
         elif action == 'watermark':
           test_stream.advance_watermark_to(params)
         elif action == 'clock':
@@ -718,7 +727,9 @@
           pass  # Requires inspection of implementation details.
         else:
           raise ValueError('Unexpected action: %s' % action)
-    test_stream.add_elements([('expect', [])])
+    test_stream.add_elements([json.dumps(('expect', []))])
+
+    read_test_stream = test_stream | beam.Map(json.loads)
 
     class Check(beam.DoFn):
       """A StatefulDoFn that verifies outputs are produced as expected.
@@ -731,12 +742,40 @@
 
       The key is ignored, but all items must be on the same key to share state.
       """
+      def __init__(self, allow_out_of_order=True):
+        # Some runners don't support cross-stage TestStream semantics.
+        self.allow_out_of_order = allow_out_of_order
+
       def process(
-          self, element, seen=beam.DoFn.StateParam(
+          self,
+          element,
+          seen=beam.DoFn.StateParam(
               beam.transforms.userstate.BagStateSpec(
                   'seen',
+                  beam.coders.FastPrimitivesCoder())),
+          expected=beam.DoFn.StateParam(
+              beam.transforms.userstate.BagStateSpec(
+                  'expected',
                   beam.coders.FastPrimitivesCoder()))):
         _, (action, data) = element
+
+        if self.allow_out_of_order:
+          if action == 'expect' and not list(seen.read()):
+            if data:
+              expected.add(data)
+            return
+          elif action == 'actual' and list(expected.read()):
+            seen.add(data)
+            all_data = list(seen.read())
+            all_expected = list(expected.read())
+            if len(all_data) == len(all_expected[0]):
+              expected.clear()
+              for expect in all_expected[1:]:
+                expected.add(expect)
+              action, data = 'expect', all_expected[0]
+            else:
+              return
+
         if action == 'actual':
           seen.add(data)
 
@@ -768,12 +807,14 @@
         else:
           raise ValueError('Unexpected action: %s' % action)
 
-    with TestPipeline(options=PipelineOptions(streaming=True)) as p:
+    with TestPipeline() as p:
+      # TODO(BEAM-8601): Pass this during pipeline construction.
+      p.options.view_as(StandardOptions).streaming = True
       # Split the test stream into a branch of to-be-processed elements, and
       # a branch of expected results.
       inputs, expected = (
           p
-          | test_stream
+          | read_test_stream
           | beam.MapTuple(
               lambda tag, value: beam.pvalue.TaggedOutput(tag, ('key', value))
               ).with_outputs('input', 'expect'))
@@ -794,7 +835,7 @@
                      t=beam.DoFn.TimestampParam,
                      p=beam.DoFn.PaneInfoParam: (
                          k,
-                         self._windowed_value_info(WindowedValue(
+                         _windowed_value_info(WindowedValue(
                              vs, windows=[window], timestamp=t, pane_info=p))))
           # Place outputs back into the global window to allow flattening
           # and share a single state in Check.
@@ -805,7 +846,17 @@
       tagged_outputs = (
           outputs | beam.MapTuple(lambda key, value: (key, ('actual', value))))
       # pylint: disable=expression-not-assigned
-      (tagged_expected, tagged_outputs) | beam.Flatten() | beam.ParDo(Check())
+      ([tagged_expected, tagged_outputs]
+       | beam.Flatten()
+       | beam.ParDo(Check(self.allow_out_of_order)))
+
+
+class TestStreamTranscriptTest(BaseTestStreamTranscriptTest):
+  allow_out_of_order = False
+
+
+class WeakTestStreamTranscriptTest(BaseTestStreamTranscriptTest):
+  allow_out_of_order = True
 
 
 TRANSCRIPT_TEST_FILE = os.path.join(
@@ -814,6 +865,7 @@
 if os.path.exists(TRANSCRIPT_TEST_FILE):
   TriggerDriverTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)
   TestStreamTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)
+  WeakTestStreamTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)
 
 
 if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/utils/subprocess_server.py b/sdks/python/apache_beam/utils/subprocess_server.py
index e63f64e..65dbcae 100644
--- a/sdks/python/apache_beam/utils/subprocess_server.py
+++ b/sdks/python/apache_beam/utils/subprocess_server.py
@@ -174,8 +174,9 @@
     elif '.dev' in beam_version:
       # TODO: Attempt to use nightly snapshots?
       raise RuntimeError(
-          'Please build the server with \n  cd %s; ./gradlew %s' % (
-              os.path.abspath(project_root), gradle_target))
+          ('%s not found. '
+           'Please build the server with \n  cd %s; ./gradlew %s') % (
+               local_path, os.path.abspath(project_root), gradle_target))
     else:
       return cls.path_to_maven_jar(
           artifact_id, cls.BEAM_GROUP_ID, beam_version, cls.APACHE_REPOSITORY)
diff --git a/sdks/python/apache_beam/utils/timestamp.py b/sdks/python/apache_beam/utils/timestamp.py
index 9bccdfd..a3f3abf 100644
--- a/sdks/python/apache_beam/utils/timestamp.py
+++ b/sdks/python/apache_beam/utils/timestamp.py
@@ -25,6 +25,7 @@
 
 import datetime
 import functools
+import time
 from builtins import object
 
 import dateutil.parser
@@ -76,6 +77,10 @@
     return Timestamp(seconds)
 
   @staticmethod
+  def now():
+    return Timestamp(seconds=time.time())
+
+  @staticmethod
   def _epoch_datetime_utc():
     return datetime.datetime.fromtimestamp(0, pytz.utc)
 
@@ -173,6 +178,8 @@
     return self + other
 
   def __sub__(self, other):
+    if isinstance(other, Timestamp):
+      return Duration(micros=self.micros - other.micros)
     other = Duration.of(other)
     return Timestamp(micros=self.micros - other.micros)
 
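
Together, the new `Timestamp.now()` and the `Timestamp - Timestamp` overload allow elapsed time to be expressed in Beam's own types:

```python
# Combining the two additions above: subtracting Timestamps yields a Duration.
import time

from apache_beam.utils.timestamp import Duration
from apache_beam.utils.timestamp import Timestamp

start = Timestamp.now()
time.sleep(0.01)
elapsed = Timestamp.now() - start
assert isinstance(elapsed, Duration)
assert elapsed.micros > 0
```
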
diff --git a/sdks/python/apache_beam/utils/timestamp_test.py b/sdks/python/apache_beam/utils/timestamp_test.py
index d26d561..2a4d454 100644
--- a/sdks/python/apache_beam/utils/timestamp_test.py
+++ b/sdks/python/apache_beam/utils/timestamp_test.py
@@ -100,6 +100,7 @@
     self.assertEqual(Timestamp(123) - Duration(456), -333)
     self.assertEqual(Timestamp(1230) % 456, 318)
     self.assertEqual(Timestamp(1230) % Duration(456), 318)
+    self.assertEqual(Timestamp(123) - Timestamp(100), 23)
 
     # Check that direct comparison of Timestamp and Duration is allowed.
     self.assertTrue(Duration(123) == Timestamp(123))
@@ -116,6 +117,7 @@
     self.assertEqual((Timestamp(123) - Duration(456)).__class__, Timestamp)
     self.assertEqual((Timestamp(1230) % 456).__class__, Duration)
     self.assertEqual((Timestamp(1230) % Duration(456)).__class__, Duration)
+    self.assertEqual((Timestamp(123) - Timestamp(100)).__class__, Duration)
 
     # Unsupported operations.
     with self.assertRaises(TypeError):
@@ -159,6 +161,10 @@
     self.assertEqual('Timestamp(-999999999)',
                      str(Timestamp(-999999999)))
 
+  def test_now(self):
+    now = Timestamp.now()
+    self.assertTrue(isinstance(now, Timestamp))
+
 
 class DurationTest(unittest.TestCase):
 
diff --git a/sdks/python/container/base_image_requirements.txt b/sdks/python/container/base_image_requirements.txt
index d3931e7..359d6e5 100644
--- a/sdks/python/container/base_image_requirements.txt
+++ b/sdks/python/container/base_image_requirements.txt
@@ -40,7 +40,7 @@
 pydot==1.4.1
 pytz==2019.1
 pyvcf==0.6.8;python_version<"3.0"
-pyyaml==3.13
+pyyaml==5.1
 typing==3.6.6
 
 # Setup packages
diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh
index e3794ba..bc77fd8 100755
--- a/sdks/python/scripts/generate_pydoc.sh
+++ b/sdks/python/scripts/generate_pydoc.sh
@@ -183,6 +183,7 @@
   '_TimerDoFnParam',
   '_BundleFinalizerParam',
   '_RestrictionDoFnParam',
+  '_WatermarkEstimatorParam',
 
   # Sphinx cannot find this py:class reference target
   'typing.Generic',
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index ccf90f6..7eea64c 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -214,6 +214,7 @@
     ext_modules=cythonize([
         'apache_beam/**/*.pyx',
         'apache_beam/coders/coder_impl.py',
+        'apache_beam/metrics/cells.py',
         'apache_beam/metrics/execution.py',
         'apache_beam/runners/common.py',
         'apache_beam/runners/worker/logger.py',
diff --git a/sdks/python/test-suites/portable/common.gradle b/sdks/python/test-suites/portable/common.gradle
index f04d28d..3cb4362 100644
--- a/sdks/python/test-suites/portable/common.gradle
+++ b/sdks/python/test-suites/portable/common.gradle
@@ -79,3 +79,22 @@
 task flinkValidatesRunner() {
   dependsOn 'flinkCompatibilityMatrixLoopback'
 }
+
+// TODO(BEAM-8598): Enable on pre-commit.
+task flinkTriggerTranscript() {
+  dependsOn 'setupVirtualenv'
+  dependsOn ':runners:flink:1.9:job-server:shadowJar'
+  doLast {
+    exec {
+      executable 'sh'
+      args '-c', """
+          . ${envdir}/bin/activate \\
+          && cd ${pythonRootDir} \\
+          && pip install -e .[test] \\
+          && python setup.py nosetests \\
+              --tests apache_beam.transforms.trigger_test:WeakTestStreamTranscriptTest \\
+              --test-pipeline-options='--runner=FlinkRunner --environment_type=LOOPBACK --flink_job_server_jar=${project(":runners:flink:1.9:job-server:").shadowJar.archivePath}'
+          """
+    }
+  }
+}
diff --git a/sdks/python/test-suites/portable/py2/build.gradle b/sdks/python/test-suites/portable/py2/build.gradle
index 3c1548d..5d967e4 100644
--- a/sdks/python/test-suites/portable/py2/build.gradle
+++ b/sdks/python/test-suites/portable/py2/build.gradle
@@ -39,6 +39,8 @@
   dependsOn ':runners:flink:1.9:job-server:shadowJar'
   dependsOn portableWordCountFlinkRunnerBatch
   dependsOn portableWordCountFlinkRunnerStreaming
+  dependsOn ':runners:spark:job-server:shadowJar'
+  dependsOn portableWordCountSparkRunnerBatch
 }
 
 // TODO: Move the rest of this file into ../common.gradle.
diff --git a/sdks/python/test-suites/portable/py35/build.gradle b/sdks/python/test-suites/portable/py35/build.gradle
index 1b2cb4f..88b4e2f 100644
--- a/sdks/python/test-suites/portable/py35/build.gradle
+++ b/sdks/python/test-suites/portable/py35/build.gradle
@@ -36,4 +36,6 @@
     dependsOn ':runners:flink:1.9:job-server:shadowJar'
     dependsOn portableWordCountFlinkRunnerBatch
     dependsOn portableWordCountFlinkRunnerStreaming
+    dependsOn ':runners:spark:job-server:shadowJar'
+    dependsOn portableWordCountSparkRunnerBatch
 }
diff --git a/sdks/python/test-suites/portable/py36/build.gradle b/sdks/python/test-suites/portable/py36/build.gradle
index 475e110..496777d 100644
--- a/sdks/python/test-suites/portable/py36/build.gradle
+++ b/sdks/python/test-suites/portable/py36/build.gradle
@@ -36,4 +36,6 @@
     dependsOn ':runners:flink:1.9:job-server:shadowJar'
     dependsOn portableWordCountFlinkRunnerBatch
     dependsOn portableWordCountFlinkRunnerStreaming
+    dependsOn ':runners:spark:job-server:shadowJar'
+    dependsOn portableWordCountSparkRunnerBatch
 }
diff --git a/sdks/python/test-suites/portable/py37/build.gradle b/sdks/python/test-suites/portable/py37/build.gradle
index 912b316..924de81 100644
--- a/sdks/python/test-suites/portable/py37/build.gradle
+++ b/sdks/python/test-suites/portable/py37/build.gradle
@@ -36,4 +36,6 @@
     dependsOn ':runners:flink:1.9:job-server:shadowJar'
     dependsOn portableWordCountFlinkRunnerBatch
     dependsOn portableWordCountFlinkRunnerStreaming
+    dependsOn ':runners:spark:job-server:shadowJar'
+    dependsOn portableWordCountSparkRunnerBatch
 }
diff --git a/website/Gemfile b/website/Gemfile
index 4a08725..1050303 100644
--- a/website/Gemfile
+++ b/website/Gemfile
@@ -20,7 +20,7 @@
 
 source 'https://rubygems.org'
 
-gem 'jekyll', '3.2'
+gem 'jekyll', '3.6.3'
 
 # Jekyll plugins
 group :jekyll_plugins do
diff --git a/website/Gemfile.lock b/website/Gemfile.lock
index e94f132..9db2ebe 100644
--- a/website/Gemfile.lock
+++ b/website/Gemfile.lock
@@ -13,7 +13,7 @@
     concurrent-ruby (1.1.4)
     ethon (0.11.0)
       ffi (>= 1.3.0)
-    ffi (1.9.25)
+    ffi (1.11.1)
     forwardable-extended (2.6.0)
     html-proofer (3.9.3)
       activesupport (>= 4.2, < 6.0)
@@ -26,15 +26,16 @@
       yell (~> 2.0)
     i18n (0.9.5)
       concurrent-ruby (~> 1.0)
-    jekyll (3.2.0)
+    jekyll (3.6.3)
+      addressable (~> 2.4)
       colorator (~> 1.0)
       jekyll-sass-converter (~> 1.0)
       jekyll-watch (~> 1.1)
-      kramdown (~> 1.3)
-      liquid (~> 3.0)
+      kramdown (~> 1.14)
+      liquid (~> 4.0)
       mercenary (~> 0.3.3)
       pathutil (~> 0.9)
-      rouge (~> 1.7)
+      rouge (>= 1.7, < 3)
       safe_yaml (~> 1.0)
     jekyll-redirect-from (0.11.0)
       jekyll (>= 2.0)
@@ -45,29 +46,27 @@
     jekyll_github_sample (0.3.1)
       activesupport (~> 4.0)
       jekyll (~> 3.0)
-    kramdown (1.16.2)
-    liquid (3.0.6)
-    listen (3.1.5)
-      rb-fsevent (~> 0.9, >= 0.9.4)
-      rb-inotify (~> 0.9, >= 0.9.7)
-      ruby_dep (~> 1.2)
+    kramdown (1.17.0)
+    liquid (4.0.3)
+    listen (3.2.0)
+      rb-fsevent (~> 0.10, >= 0.10.3)
+      rb-inotify (~> 0.9, >= 0.9.10)
     mercenary (0.3.6)
     mini_portile2 (2.3.0)
     minitest (5.11.3)
     nokogiri (1.8.5)
       mini_portile2 (~> 2.3.0)
     parallel (1.12.1)
-    pathutil (0.16.1)
+    pathutil (0.16.2)
       forwardable-extended (~> 2.6)
     public_suffix (3.0.3)
     rake (12.3.0)
-    rb-fsevent (0.10.2)
-    rb-inotify (0.9.10)
-      ffi (>= 0.5.0, < 2)
-    rouge (1.11.1)
-    ruby_dep (1.5.0)
-    safe_yaml (1.0.4)
-    sass (3.5.5)
+    rb-fsevent (0.10.3)
+    rb-inotify (0.10.0)
+      ffi (~> 1.0)
+    rouge (2.2.1)
+    safe_yaml (1.0.5)
+    sass (3.7.4)
       sass-listen (~> 4.0.0)
     sass-listen (4.0.0)
       rb-fsevent (~> 0.9, >= 0.9.4)
@@ -85,7 +84,7 @@
 DEPENDENCIES
   activesupport (< 5.0.0.0)
   html-proofer
-  jekyll (= 3.2)
+  jekyll (= 3.6.3)
   jekyll-redirect-from
   jekyll-sass-converter
   jekyll_github_sample
diff --git a/website/src/_includes/section-menu/documentation.html b/website/src/_includes/section-menu/documentation.html
index 529c264e..ed776ea 100644
--- a/website/src/_includes/section-menu/documentation.html
+++ b/website/src/_includes/section-menu/documentation.html
@@ -215,6 +215,7 @@
           <li><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/distinct/">Distinct</a></li>
           <li><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/groupbykey/">GroupByKey</a></li>
           <li><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/groupintobatches/">GroupIntoBatches</a></li>
+          <li><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/hllcount/">HllCount</a></li>
           <li><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/latest/">Latest</a></li>
           <li><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/max/">Max</a></li>
           <li><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/mean/">Mean</a></li>
diff --git a/website/src/documentation/programming-guide.md b/website/src/documentation/programming-guide.md
index d78b609..9d644ae 100644
--- a/website/src/documentation/programming-guide.md
+++ b/website/src/documentation/programming-guide.md
@@ -1932,8 +1932,8 @@
 suffix ".csv" in the given location:
 
 ```java
-p.apply(“ReadFromText”,
-    TextIO.read().from("protocol://my_bucket/path/to/input-*.csv");
+p.apply("ReadFromText",
+    TextIO.read().from("protocol://my_bucket/path/to/input-*.csv"));
 ```
 
 ```py
diff --git a/website/src/documentation/transforms/java/aggregation/approximateunique.md b/website/src/documentation/transforms/java/aggregation/approximateunique.md
index 9b3e6d0..448c0ee 100644
--- a/website/src/documentation/transforms/java/aggregation/approximateunique.md
+++ b/website/src/documentation/transforms/java/aggregation/approximateunique.md
@@ -35,6 +35,8 @@
 See [BEAM-7703](https://issues.apache.org/jira/browse/BEAM-7703) for updates.
 
 ## Related transforms 
+* [HllCount]({{ site.baseurl }}/documentation/transforms/java/aggregation/hllcount)
+  estimates the number of distinct elements and creates re-aggregatable sketches using the HyperLogLog++ algorithm.
 * [Count]({{ site.baseurl }}/documentation/transforms/java/aggregation/count)
   counts the number of elements within each aggregation.
 * [Distinct]({{ site.baseurl }}/documentation/transforms/java/aggregation/distinct)
\ No newline at end of file
diff --git a/website/src/documentation/transforms/java/aggregation/hllcount.md b/website/src/documentation/transforms/java/aggregation/hllcount.md
new file mode 100644
index 0000000..506a8dc
--- /dev/null
+++ b/website/src/documentation/transforms/java/aggregation/hllcount.md
@@ -0,0 +1,77 @@
+---
+layout: section
+title: "HllCount"
+permalink: /documentation/transforms/java/aggregation/hllcount/
+section_menu: section-menu/documentation.html
+---
+<!--
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+# HllCount
+<table align="left">
+    <a target="_blank" class="button"
+        href="https://beam.apache.org/releases/javadoc/current/index.html?org/apache/beam/sdk/extensions/zetasketch/HllCount.html">
+      <img src="https://beam.apache.org/images/logos/sdks/java.png" width="20px" height="20px"
+           alt="Javadoc" />
+     Javadoc
+    </a>
+</table>
+<br>
+
+Estimates the number of distinct elements in a data stream using the
+[HyperLogLog++ algorithm](http://static.googleusercontent.com/media/research.google.com/en/us/pubs/archive/40671.pdf).
+The transforms to create sketches, merge them, and extract estimates from them are:
+
+* `HllCount.Init` aggregates inputs into HLL++ sketches.
+* `HllCount.MergePartial` merges HLL++ sketches into a new sketch.
+* `HllCount.Extract` extracts the estimated count of distinct elements from HLL++ sketches.
+
+You can read more about what a sketch is at https://github.com/google/zetasketch.
+
+## Examples
+**Example 1**: creates a long-type sketch for a `PCollection<Long>` with a custom precision:
+```java
+ PCollection<Long> input = ...;
+ int p = ...;
+ PCollection<byte[]> sketch = input.apply(HllCount.Init.forLongs().withPrecision(p).globally());
+```
+
+**Example 2**: creates a bytes-type sketch for a `PCollection<KV<String, byte[]>>`:
+```java
+ PCollection<KV<String, byte[]>> input = ...;
+ PCollection<KV<String, byte[]>> sketch = input.apply(HllCount.Init.forBytes().perKey());
+```
+
+**Example 3**: merges existing sketches in a `PCollection<byte[]>` into a new sketch,
+which summarizes the union of the inputs that were aggregated in the merged sketches:
+```java
+ PCollection<byte[]> sketches = ...;
+ PCollection<byte[]> mergedSketch = sketches.apply(HllCount.MergePartial.globally());
+```
+
+**Example 4**: estimates the count of distinct elements in a `PCollection<String>`:
+```java
+ PCollection<String> input = ...;
+ PCollection<Long> countDistinct =
+     input.apply(HllCount.Init.forStrings().globally()).apply(HllCount.Extract.globally());
+```
+
+**Example 5**: extracts the estimated count of distinct elements from an existing sketch:
+```java
+ PCollection<byte[]> sketch = ...;
+ PCollection<Long> countDistinct = sketch.apply(HllCount.Extract.globally());
+```
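+
+**Example 6**: an end-to-end sketch chaining `Init` and `Extract` to count distinct
+users per site. This is a non-authoritative sketch: it assumes the
+`beam-sdks-java-extensions-zetasketch` module is on the classpath, and it uses
+`Create` on a `Pipeline` named `pipeline` for illustrative in-memory input:
+```java
+ // Illustrative input: KV<site, userId> visit records.
+ PCollection<KV<String, String>> visits =
+     pipeline.apply(Create.of(
+         KV.of("siteA", "user1"), KV.of("siteA", "user2"), KV.of("siteB", "user1")));
+ PCollection<KV<String, Long>> distinctUsersPerSite =
+     visits
+         .apply(HllCount.Init.forStrings().perKey()) // one sketch per site key
+         .apply(HllCount.Extract.perKey()); // estimated distinct users per key
+```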
+
+## Related transforms
+* [ApproximateUnique]({{ site.baseurl }}/documentation/transforms/java/aggregation/approximateunique)
+  estimates the number of distinct elements or values in key-value pairs (but does not expose sketches; also less accurate than `HllCount`).
\ No newline at end of file
diff --git a/website/src/documentation/transforms/java/index.md b/website/src/documentation/transforms/java/index.md
index b36e305..71b3721 100644
--- a/website/src/documentation/transforms/java/index.md
+++ b/website/src/documentation/transforms/java/index.md
@@ -58,6 +58,7 @@
   <tr><td><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/groupbykey">GroupByKey</a></td><td>Takes a keyed collection of elements and produces a collection where each element 
   consists of a key and all values associated with that key.</td></tr>
   <tr><td><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/groupintobatches">GroupIntoBatches</a></td><td>Batches values associated with keys into <code>Iterable</code> batches of some size. Each batch contains elements associated with a specific key.</td></tr>
+  <tr><td><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/hllcount">HllCount</a></td><td>Estimates the number of distinct elements and creates re-aggregatable sketches using the HyperLogLog++ algorithm.</td></tr>
   <tr><td><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/latest">Latest</a></td><td>Selects the latest element within each aggregation according to the implicit timestamp.</td></tr>
   <tr><td><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/max">Max</a></td><td>Outputs the maximum element within each aggregation.</td></tr>  
   <tr><td><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/mean">Mean</a></td><td>Computes the average within each aggregation.</td></tr>
diff --git a/website/src/roadmap/connectors-multi-sdk.md b/website/src/roadmap/connectors-multi-sdk.md
index 5ab6877..a0ccea3 100644
--- a/website/src/roadmap/connectors-multi-sdk.md
+++ b/website/src/roadmap/connectors-multi-sdk.md
@@ -28,18 +28,73 @@
 
 # Cross-language transforms
 
-As an added benefit of Beam portability efforts, in the future, we’ll be
-able to utilize Beam transforms across languages. This has many benefits.
-For example.
+_Last updated in November 2019._
 
-* Beam pipelines written using Python and Go SDKs will be able to utilize
-the vast selection of connectors that are currently available for Java SDK.
-* Java SDK will be able to utilize connectors for systems that only offer a
-Python API.
-* Go SDK, will be able to utilize connectors currently available for Java and
-Python SDKs.
-* Connector authors will be able to implement new Beam connectors using a 
-language of choice and utilize these connectors from other languages reducing
-the maintenance and support efforts.
+As an added benefit of the Beam portability effort, we are able to utilize Beam transforms across SDKs. This has many benefits.
 
-See [Beam portability framework roadmap](https://beam.apache.org/roadmap/portability/) for more details.
+* Connector sharing across SDKs. For example,
+  + Beam pipelines written using the Python and Go SDKs will be able to utilize the vast selection of connectors that are currently implemented for the Java SDK.
+  + The Java SDK will be able to utilize connectors for systems that only offer a Python API.
+  + The Go SDK will be able to utilize connectors currently available for the Java and Python SDKs.
+* Ease of developing and maintaining Beam transforms - with cross-language transforms, transform authors will be able to implement new Beam transforms in a
+language of their choice and utilize those transforms from other languages, reducing maintenance and support overheads.
+* [Beam SQL](https://beam.apache.org/documentation/dsls/sql/overview/), which is currently only available to the Java SDK, will become available to the Python and Go SDKs.
+* [Beam TFX transforms](https://www.tensorflow.org/tfx/transform/get_started), which are currently only available to Beam Python SDK pipelines, will become available to the Java and Go SDKs.
+
+## Completed and Ongoing Efforts
+
+Many efforts related to cross-language transforms are currently in flux. Some of the completed and ongoing efforts are given below.
+
+### Cross-language transforms API and expansion service
+
+Work related to developing/updating the cross-language transforms API for the Java/Python/Go SDKs, and to the cross-language transform expansion services. A minimal sketch of starting a Java expansion service is given after the status list below.
+
+* Basic API for Java SDK - Completed
+* Basic API for Python SDK - Completed
+* Basic API for Go SDK - Not started
+* Basic cross-language transform expansion service for Java and Python SDKs - Completed
+* Artifact staging - In progress - [email thread](https://lists.apache.org/thread.html/6fcee7047f53cf1c0636fb65367ef70842016d57effe2e5795c4137d@%3Cdev.beam.apache.org%3E), [doc](https://docs.google.com/document/d/1XaiNekAY2sptuQRIXpjGAyaYdSc-wlJ-VKjl04c8N48/edit#heading=h.900gc947qrw8)
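+
+As a minimal, non-authoritative sketch of the Java side (assuming the
+`beam-sdks-java-expansion-service` artifact is on the classpath; the wrapper
+class name and the port are illustrative), an expansion service can be started
+so that pipelines written in other SDKs can request the expansion of Java
+transforms:
+
+```java
+import org.apache.beam.sdk.expansion.service.ExpansionService;
+
+public class ServeJavaTransforms {
+  public static void main(String[] args) throws Exception {
+    // Starts a gRPC expansion service; the port to listen on is passed as the
+    // first argument.
+    ExpansionService.main(new String[] {"8097"});
+  }
+}
+```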
+
+### Support for Flink runner
+
+Work related to making cross-language transforms available for the Flink runner.
+
+* Basic support for executing cross-language transforms on the portable Flink runner - Completed
+
+### Support for Dataflow runner
+
+Work related to making cross-language transforms available for the Dataflow runner.
+
+* Basic support for executing cross-language transforms on the Dataflow runner - In progress
+  + This work requires updates to the Dataflow service's job submission and job execution logic, which are currently being developed at Google.
+
+### Support for Direct runner
+
+Work related to making cross-language transforms available on the Direct runner.
+
+* Basic support for executing cross-language transforms on the portable Direct runner - Not started
+
+### Connector/transform support
+
+Ongoing and planned work related to making existing connectors/transforms available to other SDKs through the cross-language transforms framework.
+
+* Java KafkaIO - In progress - [BEAM-7029](https://issues.apache.org/jira/browse/BEAM-7029)
+* Java PubSubIO - In progress - [BEAM-7738](https://issues.apache.org/jira/browse/BEAM-7738)
+
+### Portable Beam schema
+
+Portable Beam schema support will provide a generalized mechanism for serializing and transferring data across language boundaries, which will be especially useful for pipelines that employ cross-language transforms. A sketch of the Java `Row` representation that this work targets is given after the status item below.
+
+* Make row coder a standard coder and implement it in Python - In progress - [BEAM-7886](https://issues.apache.org/jira/browse/BEAM-7886)
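+
+For context, a hedged sketch of the schema'd row representation in the Java SDK
+(using `Schema` and `Row` from the core Java SDK; the field names and values
+are illustrative). A standard, cross-SDK row coder would let rows like this
+cross language boundaries without custom serialization:
+
+```java
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.values.Row;
+
+// Define a two-field schema, then build a Row conforming to it.
+Schema schema =
+    Schema.builder().addStringField("word").addInt64Field("count").build();
+Row row = Row.withSchema(schema).addValues("beam", 42L).build();
+```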
+
+### Integration/Performance testing
+
+* Add an integration test suite for cross-language transforms on the Flink runner - In progress - [BEAM-6683](https://issues.apache.org/jira/browse/BEAM-6683)
+
+### Documentation
+
+Work related to adding documentation on cross-language transforms to the Beam website.
+
+* Document cross-language transforms API for Java/Python - Not started
+* Document API for making existing transforms available as cross-language transforms for Java/Python - Not started
+
diff --git a/website/src/roadmap/index.md b/website/src/roadmap/index.md
index 21aee01..6c1a1d6 100644
--- a/website/src/roadmap/index.md
+++ b/website/src/roadmap/index.md
@@ -32,18 +32,20 @@
 
 ## Portability Framework
 
-Portability is the primary Beam vision: running pipelines authors with _any SDK_
+Portability is the primary Beam vision: running pipelines authored with _any SDK_
 on _any runner_. This is a cross-cutting effort across Java, Python, and Go, 
-and every Beam runner.
+and every Beam runner. Portability is currently supported on the
+[Flink]({{site.baseurl}}/documentation/runners/flink/)
+and [Spark]({{site.baseurl}}/documentation/runners/spark/) runners.
 
 See the details on the [Portability Roadmap]({{site.baseurl}}/roadmap/portability/)
 
-## Python on Flink
+## Cross-language transforms
 
-A major highlight of the portability effort is the effort in running Python pipelines
-the Flink runner.
-
-You can [follow the instructions to try it out]({{site.baseurl}}/roadmap/portability/#python-on-flink)
+As a benefit of the portability effort, we are able to utilize Beam transforms across SDKs.
+Examples include using Java connectors and Beam SQL from Python or Go pipelines,
+or Beam TFX transforms from Java and Go pipelines.
+For details see [Roadmap for multi-SDK efforts]({{ site.baseurl }}/roadmap/connectors-multi-sdk/).
 
 ## Go SDK