Adds support for SDF in ULR and the Java SDK.
diff --git a/model/fn-execution/src/main/proto/beam_fn_api.proto b/model/fn-execution/src/main/proto/beam_fn_api.proto
index b769552..a4017d8 100644
--- a/model/fn-execution/src/main/proto/beam_fn_api.proto
+++ b/model/fn-execution/src/main/proto/beam_fn_api.proto
@@ -216,13 +216,22 @@
     google.protobuf.DoubleValue fraction_of_work = 5;
   }
 
+  // An an Application should be scheduled after a delay.
+  message DelayedApplication {
+    // (Required) The delay in seconds.
+    double delay_sec = 1;
+
+    // (Required) The application that should be scheduled.
+    Application application = 2;
+  }
+
   // Root applications that should replace the current bundle.
   repeated Application primary_roots = 1;
 
   // Root applications that have been removed from the current bundle and
   // have to be executed in a separate bundle (e.g. in parallel on a different
   // worker, or after the current bundle completes, etc.)
-  repeated Application residual_roots = 2;
+  repeated DelayedApplication residual_roots = 2;
 }
 
 // A request to process a given bundle.
diff --git a/model/pipeline/src/main/proto/beam_runner_api.proto b/model/pipeline/src/main/proto/beam_runner_api.proto
index 44e3f42..a58769a 100644
--- a/model/pipeline/src/main/proto/beam_runner_api.proto
+++ b/model/pipeline/src/main/proto/beam_runner_api.proto
@@ -263,9 +263,27 @@
   }
   // Payload for all of these: ParDoPayload containing the user's SDF
   enum SplittableParDoComponents {
+    // Pairs the input element with its initial restriction.
+    // Input: element; output: KV(element, restriction).
     PAIR_WITH_RESTRICTION = 0 [(beam_urn) = "beam:transform:sdf_pair_with_restriction:v1"];
+
+    // Splits the restriction inside an element/restriction pair.
+    // Input: KV(element, restriction); output: KV(element, restriction).
     SPLIT_RESTRICTION = 1 [(beam_urn) = "beam:transform:sdf_split_restriction:v1"];
+
+    // Applies the DoFn to every element/restriction pair in a uniquely keyed
+    // collection, in a splittable fashion.
+    // Input: KV(bytes, KV(element, restriction)); output: DoFn's output.
+    // The first "bytes" is an opaque unique key using the standard bytes coder.
+    // Typically a runner would rewrite this into a runner-specific grouping
+    // operation supporting state and timers, followed by PROCESS_ELEMENTS,
+    // with some runner-specific glue code in between.
     PROCESS_KEYED_ELEMENTS = 2 [(beam_urn) = "beam:transform:sdf_process_keyed_elements:v1"];
+
+    // Like PROCESS_KEYED_ELEMENTS, but without the unique key - just elements
+    // and restrictions.
+    // Input: KV(element, restriction); output: DoFn's output.
+    PROCESS_ELEMENTS = 3 [(beam_urn) = "beam:transform:sdf_process_elements:v1"];
   }
 }
 
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
index 8db73df..8db3156 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
@@ -470,14 +470,14 @@
     if (doFn instanceof ProcessFn) {
 
       @SuppressWarnings("unchecked")
-      StateInternalsFactory<String> stateInternalsFactory =
-          (StateInternalsFactory<String>) this.currentKeyStateInternals.getFactory();
+      StateInternalsFactory<byte[]> stateInternalsFactory =
+          (StateInternalsFactory<byte[]>) this.currentKeyStateInternals.getFactory();
 
       @SuppressWarnings({ "rawtypes", "unchecked" })
       ProcessFn<InputT, OutputT, Object, RestrictionTracker<Object, Object>>
         splittableDoFn = (ProcessFn) doFn;
       splittableDoFn.setStateInternalsFactory(stateInternalsFactory);
-      TimerInternalsFactory<String> timerInternalsFactory = key -> currentKeyTimerInternals;
+      TimerInternalsFactory<byte[]> timerInternalsFactory = key -> currentKeyTimerInternals;
       splittableDoFn.setTimerInternalsFactory(timerInternalsFactory);
       splittableDoFn.setProcessElementInvoker(
           new OutputAndTimeBoundedSplittableProcessElementInvoker<>(
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/Environments.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/Environments.java
index f57eb99..00554f7 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/Environments.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/Environments.java
@@ -40,6 +40,7 @@
       ImmutableMap.<String, EnvironmentIdExtractor>builder()
           .put(PTransformTranslation.COMBINE_TRANSFORM_URN, Environments::combineExtractor)
           .put(PTransformTranslation.PAR_DO_TRANSFORM_URN, Environments::parDoExtractor)
+          .put(PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN, Environments::parDoExtractor)
           .put(PTransformTranslation.READ_TRANSFORM_URN, Environments::readExtractor)
           .put(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN, Environments::windowExtractor)
           .build();
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java
index b1d6015..d3a6281 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java
@@ -85,6 +85,8 @@
       getUrn(StandardPTransforms.Composites.WRITE_FILES);
   public static final String SPLITTABLE_PROCESS_KEYED_URN =
       getUrn(SplittableParDoComponents.PROCESS_KEYED_ELEMENTS);
+  public static final String SPLITTABLE_PROCESS_ELEMENTS_URN =
+      getUrn(SplittableParDoComponents.PROCESS_ELEMENTS);
 
   private static final Map<Class<? extends PTransform>, TransformPayloadTranslator>
       KNOWN_PAYLOAD_TRANSLATORS = loadTransformPayloadTranslators();
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RehydratedComponents.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RehydratedComponents.java
index a8c6bb6..c00b63d 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RehydratedComponents.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RehydratedComponents.java
@@ -172,4 +172,8 @@
   public Environment getEnvironment(String environmentId) {
     return components.getEnvironmentsOrThrow(environmentId);
   }
+
+  public Components getComponents() {
+    return components;
+  }
 }
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java
index 8bd3960..f8a43e9 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java
@@ -23,6 +23,7 @@
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Maps;
 import java.io.IOException;
+import java.nio.charset.StandardCharsets;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
@@ -148,7 +149,7 @@
             .invokeGetRestrictionCoder(input.getPipeline().getCoderRegistry());
     Coder<KV<InputT, RestrictionT>> splitCoder = KvCoder.of(input.getCoder(), restrictionCoder);
 
-    PCollection<KV<String, KV<InputT, RestrictionT>>> keyedRestrictions =
+    PCollection<KV<byte[], KV<InputT, RestrictionT>>> keyedRestrictions =
         input
             .apply(
                 "Pair with initial restriction",
@@ -198,7 +199,7 @@
    * {@link KV KVs} keyed with arbitrary but globally unique keys.
    */
   public static class ProcessKeyedElements<InputT, OutputT, RestrictionT>
-      extends PTransform<PCollection<KV<String, KV<InputT, RestrictionT>>>, PCollectionTuple> {
+      extends PTransform<PCollection<KV<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple> {
     private final DoFn<InputT, OutputT> fn;
     private final Coder<InputT> elementCoder;
     private final Coder<RestrictionT> restrictionCoder;
@@ -269,7 +270,7 @@
     }
 
     @Override
-    public PCollectionTuple expand(PCollection<KV<String, KV<InputT, RestrictionT>>> input) {
+    public PCollectionTuple expand(PCollection<KV<byte[], KV<InputT, RestrictionT>>> input) {
       return createPrimitiveOutputFor(
           input, fn, mainOutputTag, additionalOutputTags, outputTagsToCoders, windowingStrategy);
     }
@@ -395,10 +396,10 @@
    * collection is effectively the same elements as input, but the per-key state and timers are now
    * effectively per-element.
    */
-  private static class RandomUniqueKeyFn<T> implements SerializableFunction<T, String> {
+  private static class RandomUniqueKeyFn<T> implements SerializableFunction<T, byte[]> {
     @Override
-    public String apply(T input) {
-      return UUID.randomUUID().toString();
+    public byte[] apply(T input) {
+      return UUID.randomUUID().toString().getBytes(StandardCharsets.UTF_8);
     }
   }
 
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/PipelineValidator.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/PipelineValidator.java
index 4e913f9..25129dc 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/PipelineValidator.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/PipelineValidator.java
@@ -248,7 +248,7 @@
   }
 
   private static void validateExecutableStage(
-      String id, PTransform transform, Components outerComponentsIgnored) throws Exception {
+      String id, PTransform transform, Components outerComponents) throws Exception {
     ExecutableStagePayload payload =
         ExecutableStagePayload.parseFrom(transform.getSpec().getPayload());
 
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/ProcessFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/ProcessFnRunner.java
index 8c360ef..f560b51 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/ProcessFnRunner.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/ProcessFnRunner.java
@@ -39,13 +39,13 @@
  */
 public class ProcessFnRunner<InputT, OutputT, RestrictionT>
     implements PushbackSideInputDoFnRunner<
-        KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> {
-  private final DoFnRunner<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> underlying;
+        KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> {
+  private final DoFnRunner<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> underlying;
   private final Collection<PCollectionView<?>> views;
   private final ReadyCheckingSideInputReader sideInputReader;
 
   public ProcessFnRunner(
-      DoFnRunner<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> underlying,
+      DoFnRunner<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> underlying,
       Collection<PCollectionView<?>> views,
       ReadyCheckingSideInputReader sideInputReader) {
     this.underlying = underlying;
@@ -54,7 +54,7 @@
   }
 
   @Override
-  public DoFn<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> getFn() {
+  public DoFn<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> getFn() {
     return underlying.getFn();
   }
 
@@ -64,9 +64,9 @@
   }
 
   @Override
-  public Iterable<WindowedValue<KeyedWorkItem<String, KV<InputT, RestrictionT>>>>
+  public Iterable<WindowedValue<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>>>
       processElementInReadyWindows(
-          WindowedValue<KeyedWorkItem<String, KV<InputT, RestrictionT>>> windowedKWI) {
+          WindowedValue<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> windowedKWI) {
     checkTrivialOuterWindows(windowedKWI);
     BoundedWindow window = getUnderlyingWindow(windowedKWI.getValue());
     if (!isReady(window)) {
@@ -88,7 +88,7 @@
   }
 
   private static <T> void checkTrivialOuterWindows(
-      WindowedValue<KeyedWorkItem<String, T>> windowedKWI) {
+      WindowedValue<KeyedWorkItem<byte[], T>> windowedKWI) {
     // In practice it will be in 0 or 1 windows (ValueInEmptyWindows or ValueInGlobalWindow)
     Collection<? extends BoundedWindow> outerWindows = windowedKWI.getWindows();
     if (!outerWindows.isEmpty()) {
@@ -104,7 +104,7 @@
     }
   }
 
-  private static <T> BoundedWindow getUnderlyingWindow(KeyedWorkItem<String, T> kwi) {
+  private static <T> BoundedWindow getUnderlyingWindow(KeyedWorkItem<byte[], T> kwi) {
     if (Iterables.isEmpty(kwi.elementsIterable())) {
       // ProcessFn sets only a single timer.
       TimerData timer = Iterables.getOnlyElement(kwi.timersIterable());
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java
index 45f8b4b..fe40afd 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java
@@ -28,9 +28,9 @@
 import org.apache.beam.runners.core.construction.ReplacementOutputs;
 import org.apache.beam.runners.core.construction.SplittableParDo;
 import org.apache.beam.runners.core.construction.SplittableParDo.ProcessKeyedElements;
+import org.apache.beam.sdk.coders.ByteArrayCoder;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
-import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.runners.AppliedPTransform;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
@@ -100,14 +100,14 @@
   /** Overrides a {@link ProcessKeyedElements} into {@link SplittableProcessViaKeyedWorkItems}. */
   public static class OverrideFactory<InputT, OutputT, RestrictionT>
       implements PTransformOverrideFactory<
-          PCollection<KV<String, KV<InputT, RestrictionT>>>, PCollectionTuple,
+          PCollection<KV<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple,
           ProcessKeyedElements<InputT, OutputT, RestrictionT>> {
     @Override
     public PTransformReplacement<
-            PCollection<KV<String, KV<InputT, RestrictionT>>>, PCollectionTuple>
+            PCollection<KV<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple>
         getReplacementTransform(
             AppliedPTransform<
-                    PCollection<KV<String, KV<InputT, RestrictionT>>>, PCollectionTuple,
+                    PCollection<KV<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple,
                     ProcessKeyedElements<InputT, OutputT, RestrictionT>>
                 transform) {
       return PTransformReplacement.of(
@@ -127,7 +127,7 @@
    * method for a splittable {@link DoFn}.
    */
   public static class SplittableProcessViaKeyedWorkItems<InputT, OutputT, RestrictionT>
-      extends PTransform<PCollection<KV<String, KV<InputT, RestrictionT>>>, PCollectionTuple> {
+      extends PTransform<PCollection<KV<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple> {
     private final ProcessKeyedElements<InputT, OutputT, RestrictionT> original;
 
     public SplittableProcessViaKeyedWorkItems(
@@ -136,13 +136,13 @@
     }
 
     @Override
-    public PCollectionTuple expand(PCollection<KV<String, KV<InputT, RestrictionT>>> input) {
+    public PCollectionTuple expand(PCollection<KV<byte[], KV<InputT, RestrictionT>>> input) {
       return input
           .apply(new GBKIntoKeyedWorkItems<>())
           .setCoder(
               KeyedWorkItemCoder.of(
-                  StringUtf8Coder.of(),
-                  ((KvCoder<String, KV<InputT, RestrictionT>>) input.getCoder()).getValueCoder(),
+                  ByteArrayCoder.of(),
+                  ((KvCoder<byte[], KV<InputT, RestrictionT>>) input.getCoder()).getValueCoder(),
                   input.getWindowingStrategy().getWindowFn().windowCoder()))
           .apply(new ProcessElements<>(original));
     }
@@ -152,7 +152,7 @@
   public static class ProcessElements<
           InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker<RestrictionT, ?>>
       extends PTransform<
-          PCollection<KeyedWorkItem<String, KV<InputT, RestrictionT>>>, PCollectionTuple> {
+          PCollection<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple> {
     private final ProcessKeyedElements<InputT, OutputT, RestrictionT> original;
 
     public ProcessElements(ProcessKeyedElements<InputT, OutputT, RestrictionT> original) {
@@ -186,7 +186,7 @@
 
     @Override
     public PCollectionTuple expand(
-        PCollection<KeyedWorkItem<String, KV<InputT, RestrictionT>>> input) {
+        PCollection<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> input) {
       return ProcessKeyedElements.createPrimitiveOutputFor(
           input,
           original.getFn(),
@@ -212,7 +212,7 @@
   @VisibleForTesting
   public static class ProcessFn<
           InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker<RestrictionT, ?>>
-      extends DoFn<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> {
+      extends DoFn<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> {
     /**
      * The state cell containing a watermark hold for the output of this {@link DoFn}. The hold is
      * acquired during the first {@link DoFn.ProcessElement} call for each element and restriction,
@@ -245,8 +245,8 @@
     private final Coder<RestrictionT> restrictionCoder;
     private final WindowingStrategy<InputT, ?> inputWindowingStrategy;
 
-    private transient @Nullable StateInternalsFactory<String> stateInternalsFactory;
-    private transient @Nullable TimerInternalsFactory<String> timerInternalsFactory;
+    private transient @Nullable StateInternalsFactory<byte[]> stateInternalsFactory;
+    private transient @Nullable TimerInternalsFactory<byte[]> timerInternalsFactory;
     private transient @Nullable SplittableProcessElementInvoker<
             InputT, OutputT, RestrictionT, TrackerT>
         processElementInvoker;
@@ -270,11 +270,11 @@
       this.restrictionTag = StateTags.value("restriction", restrictionCoder);
     }
 
-    public void setStateInternalsFactory(StateInternalsFactory<String> stateInternalsFactory) {
+    public void setStateInternalsFactory(StateInternalsFactory<byte[]> stateInternalsFactory) {
       this.stateInternalsFactory = stateInternalsFactory;
     }
 
-    public void setTimerInternalsFactory(TimerInternalsFactory<String> timerInternalsFactory) {
+    public void setTimerInternalsFactory(TimerInternalsFactory<byte[]> timerInternalsFactory) {
       this.timerInternalsFactory = timerInternalsFactory;
     }
 
@@ -322,7 +322,7 @@
 
     @ProcessElement
     public void processElement(final ProcessContext c) {
-      String key = c.element().key();
+      byte[] key = c.element().key();
       StateInternals stateInternals = stateInternalsFactory.stateInternalsForKey(key);
       TimerInternals timerInternals = timerInternalsFactory.timerInternalsForKey(key);
 
diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java
index 617e557..c0b70d4 100644
--- a/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java
+++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java
@@ -31,6 +31,7 @@
 import static org.junit.Assert.assertTrue;
 
 import java.io.Serializable;
+import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
@@ -124,7 +125,7 @@
           PositionT,
           TrackerT extends RestrictionTracker<RestrictionT, PositionT>>
       implements AutoCloseable {
-    private final DoFnTester<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> tester;
+    private final DoFnTester<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> tester;
     private Instant currentProcessingTime;
 
     private InMemoryTimerInternals timerInternals;
@@ -200,7 +201,8 @@
 
     void startElement(WindowedValue<KV<InputT, RestrictionT>> windowedValue) throws Exception {
       tester.processElement(
-          KeyedWorkItems.elementsWorkItem("key", Collections.singletonList(windowedValue)));
+          KeyedWorkItems.elementsWorkItem(
+              "key".getBytes(StandardCharsets.UTF_8), Collections.singletonList(windowedValue)));
     }
 
     /**
@@ -219,7 +221,8 @@
       if (timers.isEmpty()) {
         return false;
       }
-      tester.processElement(KeyedWorkItems.timersWorkItem("key", timers));
+      tester.processElement(
+          KeyedWorkItems.timersWorkItem("key".getBytes(StandardCharsets.UTF_8), timers));
       return true;
     }
 
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
index 6fc14a2..00ab794 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
@@ -61,7 +61,7 @@
  * in the direct runner. Currently overrides applications of <a
  * href="https://s.apache.org/splittable-do-fn">Splittable DoFn</a>.
  */
-class ParDoMultiOverrideFactory<InputT, OutputT>
+public class ParDoMultiOverrideFactory<InputT, OutputT>
     implements PTransformOverrideFactory<
         PCollection<? extends InputT>, PCollectionTuple,
         PTransform<PCollection<? extends InputT>, PCollectionTuple>> {
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java
index 58adb94..0b7cf33 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java
@@ -53,7 +53,7 @@
         PositionT,
         TrackerT extends RestrictionTracker<RestrictionT, PositionT>>
     implements TransformEvaluatorFactory {
-  private final ParDoEvaluatorFactory<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT>
+  private final ParDoEvaluatorFactory<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT>
       delegateFactory;
   private final ScheduledExecutorService ses;
   private final EvaluationContext evaluationContext;
@@ -107,9 +107,9 @@
   }
 
   @SuppressWarnings({"unchecked", "rawtypes"})
-  private TransformEvaluator<KeyedWorkItem<String, KV<InputT, RestrictionT>>> createEvaluator(
+  private TransformEvaluator<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> createEvaluator(
       AppliedPTransform<
-              PCollection<KeyedWorkItem<String, KV<InputT, RestrictionT>>>, PCollectionTuple,
+              PCollection<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple,
               ProcessElements<InputT, OutputT, RestrictionT, TrackerT>>
           application,
       CommittedBundle<InputT> inputBundle)
@@ -118,15 +118,15 @@
         application.getTransform();
 
     final DoFnLifecycleManagerRemovingTransformEvaluator
-      <KeyedWorkItem<String, KV<InputT, RestrictionT>>> evaluator =
+      <KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> evaluator =
       delegateFactory.createEvaluator(
         (AppliedPTransform) application,
-        (PCollection<KeyedWorkItem<String, KV<InputT, RestrictionT>>>) inputBundle.getPCollection(),
+        (PCollection<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>>) inputBundle.getPCollection(),
         inputBundle.getKey(),
         application.getTransform().getSideInputs(),
         application.getTransform().getMainOutputTag(),
         application.getTransform().getAdditionalOutputTags().getAll());
-    final ParDoEvaluator<KeyedWorkItem<String, KV<InputT, RestrictionT>>> pde =
+    final ParDoEvaluator<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> pde =
       evaluator.getParDoEvaluator();
     final ProcessFn<InputT, OutputT, RestrictionT, TrackerT> processFn =
       (ProcessFn<InputT, OutputT, RestrictionT, TrackerT>)
@@ -176,7 +176,7 @@
   }
 
   private static <InputT, OutputT, RestrictionT>
-      ParDoEvaluator.DoFnRunnerFactory<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT>
+      ParDoEvaluator.DoFnRunnerFactory<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT>
           processFnRunnerFactory() {
     return (options,
         fn,
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/CopyOnAccessInMemoryStateInternals.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/CopyOnAccessInMemoryStateInternals.java
index ab612a6..d115cd5 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/CopyOnAccessInMemoryStateInternals.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/CopyOnAccessInMemoryStateInternals.java
@@ -95,7 +95,7 @@
    *
    * @return this table
    */
-  public CopyOnAccessInMemoryStateInternals commit() {
+  public CopyOnAccessInMemoryStateInternals<K> commit() {
     table.commit();
     return this;
   }
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/ImmutableListBundleFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/ImmutableListBundleFactory.java
index 8f9127b..a01f1cd 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/ImmutableListBundleFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/ImmutableListBundleFactory.java
@@ -21,6 +21,7 @@
 import static com.google.common.base.Preconditions.checkState;
 
 import com.google.auto.value.AutoValue;
+import com.google.common.base.MoreObjects;
 import com.google.common.collect.ImmutableList;
 import java.util.Iterator;
 import javax.annotation.Nonnull;
@@ -117,6 +118,13 @@
       return CommittedImmutableListBundle.create(
           pcollection, key, committedElements, minSoFar, synchronizedCompletionTime);
     }
+
+    @Override
+    public String toString() {
+      return MoreObjects.toStringHelper(this)
+          .add("elements", elements.build())
+          .toString();
+    }
   }
 
   @AutoValue
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/ReferenceRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/ReferenceRunner.java
index 1e575f2..e6748ee 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/ReferenceRunner.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/ReferenceRunner.java
@@ -19,17 +19,23 @@
 package org.apache.beam.runners.direct.portable;
 
 import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
 import static com.google.common.collect.Iterables.getOnlyElement;
 import static org.apache.beam.runners.core.construction.SyntheticComponents.uniqueId;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
 import com.google.protobuf.Struct;
 import java.io.File;
+import java.util.List;
+import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
+import java.util.stream.Collectors;
 import javax.annotation.Nullable;
 import org.apache.beam.model.fnexecution.v1.ProvisionApi.ProvisionInfo;
 import org.apache.beam.model.fnexecution.v1.ProvisionApi.Resources;
@@ -41,17 +47,20 @@
 import org.apache.beam.model.pipeline.v1.RunnerApi.MessageWithComponents;
 import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
 import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
+import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform.Builder;
 import org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline;
 import org.apache.beam.model.pipeline.v1.RunnerApi.SdkFunctionSpec;
 import org.apache.beam.runners.core.construction.ModelCoders;
 import org.apache.beam.runners.core.construction.ModelCoders.KvCoderComponents;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.graph.ExecutableStage;
 import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
 import org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode;
 import org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode;
 import org.apache.beam.runners.core.construction.graph.PipelineValidator;
 import org.apache.beam.runners.core.construction.graph.ProtoOverrides;
 import org.apache.beam.runners.core.construction.graph.ProtoOverrides.TransformReplacement;
+import org.apache.beam.runners.core.construction.graph.QueryablePipeline;
 import org.apache.beam.runners.direct.ExecutableGraph;
 import org.apache.beam.runners.direct.portable.artifact.LocalFileSystemArtifactRetrievalService;
 import org.apache.beam.runners.direct.portable.artifact.UnsupportedArtifactRetrievalService;
@@ -73,6 +82,7 @@
 import org.apache.beam.runners.fnexecution.logging.Slf4jLogWriter;
 import org.apache.beam.runners.fnexecution.provisioning.StaticGrpcProvisionService;
 import org.apache.beam.runners.fnexecution.state.GrpcStateService;
+import org.apache.beam.runners.fnexecution.wire.LengthPrefixUnknownCoders;
 import org.apache.beam.sdk.fn.IdGenerators;
 import org.apache.beam.sdk.fn.stream.OutboundObserverFactory;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
@@ -106,15 +116,18 @@
 
   private RunnerApi.Pipeline executable(RunnerApi.Pipeline original) {
     RunnerApi.Pipeline p = original;
-
+    PipelineValidator.validate(p);
     p =
         ProtoOverrides.updateTransform(
-            PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN,
+            PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN,
             p,
-            new PortableGroupByKeyReplacer());
-    PipelineValidator.validate(p);
-
+            new SplittableProcessKeyedReplacer());
+    p =
+        ProtoOverrides.updateTransform(
+            PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN, p, new PortableGroupByKeyReplacer());
     p = GreedyPipelineFuser.fuse(p).toPipeline();
+
+    p = foldFeedSDFIntoExecutableStage(p);
     PipelineValidator.validate(p);
 
     return p;
@@ -325,6 +338,219 @@
     }
   }
 
+  /**
+   * Replaces the {@link PTransformTranslation#SPLITTABLE_PROCESS_KEYED_URN} with a {@link
+   * DirectGroupByKey#DIRECT_GBKO_URN} (construct keyed work items) followed by a {@link
+   * SplittableRemoteStageEvaluatorFactory#FEED_SDF_URN} (convert the keyed work items to
+   * element/restriction pairs that later go into {@link
+   * PTransformTranslation#SPLITTABLE_PROCESS_ELEMENTS_URN}).
+   */
+  @VisibleForTesting
+  static class SplittableProcessKeyedReplacer implements TransformReplacement {
+    @Override
+    public MessageWithComponents getReplacement(String spkId, ComponentsOrBuilder components) {
+      PTransform spk = components.getTransformsOrThrow(spkId);
+      checkArgument(
+          PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN.equals(spk.getSpec().getUrn()),
+          "URN must be %s, got %s",
+          PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN,
+          spk.getSpec().getUrn());
+
+      Components.Builder newComponents = Components.newBuilder();
+      newComponents.putAllCoders(components.getCodersMap());
+
+      Builder newPTransform = spk.toBuilder();
+
+      String inputId = getOnlyElement(spk.getInputsMap().values());
+      PCollection input = components.getPcollectionsOrThrow(inputId);
+
+      // This is a Coder<KV<String, KV<ElementT, RestrictionT>>>
+      Coder inputCoder = components.getCodersOrThrow(input.getCoderId());
+      KvCoderComponents kvComponents = ModelCoders.getKvCoderComponents(inputCoder);
+      String windowCoderId =
+          components
+              .getWindowingStrategiesOrThrow(input.getWindowingStrategyId())
+              .getWindowCoderId();
+
+      // === Construct a raw GBK returning KeyedWorkItem's ===
+      String kwiCollectionId =
+          uniqueId(String.format("%s.kwi", spkId), components::containsPcollections);
+      {
+        // This coder isn't actually required for the pipeline to function properly - the KWIs can
+        // be passed around as pure java objects with no coding of the values, but it approximates a
+        // full pipeline.
+        Coder kwiCoder =
+            Coder.newBuilder()
+                .setSpec(
+                    SdkFunctionSpec.newBuilder()
+                        .setSpec(FunctionSpec.newBuilder().setUrn("beam:direct:keyedworkitem:v1")))
+                .addAllComponentCoderIds(
+                    ImmutableList.of(
+                        kvComponents.keyCoderId(), kvComponents.valueCoderId(), windowCoderId))
+                .build();
+        String kwiCoderId =
+            uniqueId(
+                String.format(
+                    "keyed_work_item(%s:%s)",
+                    kvComponents.keyCoderId(), kvComponents.valueCoderId()),
+                components::containsCoders);
+
+        PCollection kwiCollection =
+            input.toBuilder().setUniqueName(kwiCollectionId).setCoderId(kwiCoderId).build();
+        String rawGbkId =
+            uniqueId(String.format("%s/RawGBK", spkId), components::containsTransforms);
+        PTransform rawGbk =
+            PTransform.newBuilder()
+                .setUniqueName(String.format("%s/RawGBK", spk.getUniqueName()))
+                .putAllInputs(spk.getInputsMap())
+                .setSpec(FunctionSpec.newBuilder().setUrn(DirectGroupByKey.DIRECT_GBKO_URN))
+                .putOutputs("output", kwiCollectionId)
+                .build();
+
+        newComponents
+            .putCoders(kwiCoderId, kwiCoder)
+            .putPcollections(kwiCollectionId, kwiCollection)
+            .putTransforms(rawGbkId, rawGbk);
+        newPTransform.addSubtransforms(rawGbkId);
+      }
+
+      // === Construct a "Feed SDF" operation that converts KWI to KV<ElementT, RestrictionT> ===
+      String feedSDFCollectionId =
+          uniqueId(String.format("%s.feed", spkId), components::containsPcollections);
+      {
+        String feedSDFCoderId =
+            uniqueId(String.format("%s/FeedSDF-wire", spkId), components::containsCoders);
+        String elementRestrictionCoderId = kvComponents.valueCoderId();
+        MessageWithComponents feedSDFCoder =
+            LengthPrefixUnknownCoders.forCoder(
+                elementRestrictionCoderId, newComponents.build(), false);
+
+        PCollection feedSDFCollection =
+            input.toBuilder().setUniqueName(feedSDFCollectionId).setCoderId(feedSDFCoderId).build();
+        String feedSDFId =
+            uniqueId(String.format("%s/FeedSDF", spkId), components::containsTransforms);
+        PTransform feedSDF =
+            PTransform.newBuilder()
+                .setUniqueName(String.format("%s/FeedSDF", spk.getUniqueName()))
+                .putInputs("input", kwiCollectionId)
+                .setSpec(
+                    FunctionSpec.newBuilder()
+                        .setUrn(SplittableRemoteStageEvaluatorFactory.FEED_SDF_URN))
+                .putOutputs("output", feedSDFCollectionId)
+                .build();
+
+        newComponents
+            .putCoders(feedSDFCoderId, feedSDFCoder.getCoder())
+            .putAllCoders(feedSDFCoder.getComponents().getCodersMap())
+            .putPcollections(feedSDFCollectionId, feedSDFCollection)
+            .putTransforms(feedSDFId, feedSDF);
+        newPTransform.addSubtransforms(feedSDFId);
+      }
+
+      // === Construct the SPLITTABLE_PROCESS_ELEMENTS operation
+      {
+        String runSDFId =
+            uniqueId(String.format("%s/RunSDF", spkId), components::containsTransforms);
+        PTransform runSDF =
+            PTransform.newBuilder()
+                .setUniqueName(String.format("%s/RunSDF", spk.getUniqueName()))
+                .putInputs("input", feedSDFCollectionId)
+                .setSpec(
+                    FunctionSpec.newBuilder()
+                        .setUrn(PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN)
+                        .setPayload(spk.getSpec().getPayload()))
+                .putAllOutputs(spk.getOutputsMap())
+                .build();
+        newComponents.putTransforms(runSDFId, runSDF);
+        newPTransform.addSubtransforms(runSDFId);
+      }
+
+      return MessageWithComponents.newBuilder()
+          .setPtransform(newPTransform.build())
+          .setComponents(newComponents)
+          .build();
+    }
+  }
+
+  /**
+   * Finds FEED_SDF nodes followed by an ExecutableStage and replaces them by a single {@link
+   * SplittableRemoteStageEvaluatorFactory#URN} stage that feeds the ExecutableStage knowing that
+   * the first instruction in the stage is an SDF.
+   */
+  private static Pipeline foldFeedSDFIntoExecutableStage(Pipeline p) {
+    Pipeline.Builder newPipeline = p.toBuilder();
+    Components.Builder newPipelineComponents = newPipeline.getComponentsBuilder();
+
+    QueryablePipeline q = QueryablePipeline.forPipeline(p);
+    String feedSdfUrn = SplittableRemoteStageEvaluatorFactory.FEED_SDF_URN;
+    List<PTransformNode> feedSDFNodes =
+        q.getTransforms()
+            .stream()
+            .filter(node -> node.getTransform().getSpec().getUrn().equals(feedSdfUrn))
+            .collect(Collectors.toList());
+    Map<String, PTransformNode> stageToFeeder = Maps.newHashMap();
+    for (PTransformNode node : feedSDFNodes) {
+      PCollectionNode output = Iterables.getOnlyElement(q.getOutputPCollections(node));
+      PTransformNode consumer = Iterables.getOnlyElement(q.getPerElementConsumers(output));
+      String consumerUrn = consumer.getTransform().getSpec().getUrn();
+      checkState(
+          consumerUrn.equals(ExecutableStage.URN),
+          "Expected all FeedSDF nodes to be consumed by an ExecutableStage, "
+              + "but %s is consumed by %s which is %s",
+          node.getId(),
+          consumer.getId(),
+          consumerUrn);
+      stageToFeeder.put(consumer.getId(), node);
+    }
+
+    // Copy over root transforms except for the excluded FEED_SDF transforms.
+    Set<String> feedSDFIds =
+        feedSDFNodes.stream().map(PTransformNode::getId).collect(Collectors.toSet());
+    newPipeline.clearRootTransformIds();
+    for (String rootId : p.getRootTransformIdsList()) {
+      if (!feedSDFIds.contains(rootId)) {
+        newPipeline.addRootTransformIds(rootId);
+      }
+    }
+    // Copy over all transforms, except FEED_SDF transforms are skipped, and ExecutableStage's
+    // feeding from them are replaced.
+    for (PTransformNode node : q.getTransforms()) {
+      if (feedSDFNodes.contains(node)) {
+        // These transforms are skipped and handled separately.
+        continue;
+      }
+      if (!stageToFeeder.containsKey(node.getId())) {
+        // This transform is unchanged
+        newPipelineComponents.putTransforms(node.getId(), node.getTransform());
+        continue;
+      }
+      // "node" is an ExecutableStage transform feeding from an SDF.
+      PTransformNode feedSDFNode = stageToFeeder.get(node.getId());
+      PCollectionNode rawGBKOutput =
+          Iterables.getOnlyElement(q.getPerElementInputPCollections(feedSDFNode));
+
+      // Replace the ExecutableStage transform.
+      newPipelineComponents.putTransforms(
+          node.getId(),
+          node.getTransform()
+              .toBuilder()
+              .mergeSpec(
+                  // Change URN from ExecutableStage.URN to URN of the ULR's splittable executable
+                  // stage evaluator.
+                  FunctionSpec.newBuilder()
+                      .setUrn(SplittableRemoteStageEvaluatorFactory.URN)
+                      .build())
+              .putInputs(
+                  // The splittable executable stage now reads from the raw GBK, instead of
+                  // from the now non-existent FEED_SDF.
+                  Iterables.getOnlyElement(node.getTransform().getInputsMap().keySet()),
+                  rawGBKOutput.getId())
+              .build());
+    }
+
+    return newPipeline.build();
+  }
+
   private enum EnvironmentType {
     DOCKER,
     IN_PROCESS
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/SplittableRemoteStageEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/SplittableRemoteStageEvaluatorFactory.java
new file mode 100644
index 0000000..fb424c5
--- /dev/null
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/SplittableRemoteStageEvaluatorFactory.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct.portable;
+
+import com.google.common.collect.Iterables;
+import java.util.ArrayList;
+import java.util.Collection;
+import javax.annotation.Nullable;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleProgressResponse;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
+import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload;
+import org.apache.beam.runners.core.KeyedWorkItem;
+import org.apache.beam.runners.core.construction.graph.ExecutableStage;
+import org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode;
+import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
+import org.apache.beam.runners.fnexecution.control.JobBundleFactory;
+import org.apache.beam.runners.fnexecution.control.RemoteBundle;
+import org.apache.beam.runners.fnexecution.splittabledofn.SDFFeederViaStateAndTimers;
+import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
+import org.apache.beam.runners.fnexecution.wire.WireCoders;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
+import org.apache.beam.sdk.values.KV;
+
+/**
+ * The {@link TransformEvaluatorFactory} for {@link #URN}, which reads from a {@link
+ * DirectGroupByKey#DIRECT_GBKO_URN} and feeds the data, using state and timers, to a {@link
+ * ExecutableStage} whose first instruction is an SDF.
+ */
+class SplittableRemoteStageEvaluatorFactory implements TransformEvaluatorFactory {
+  public static final String URN = "urn:beam:directrunner:transforms:splittable_remote_stage:v1";
+
+  // A fictional transform that transforms from KWI<unique key, KV<element, restriction>>
+  // to simply KV<element, restriction> taken by the SDF inside the ExecutableStage.
+  public static final String FEED_SDF_URN = "urn:beam:directrunner:transforms:feed_sdf:v1";
+
+  private final BundleFactory bundleFactory;
+  private final JobBundleFactory jobBundleFactory;
+  private final StepStateAndTimers.Provider stp;
+
+  SplittableRemoteStageEvaluatorFactory(
+      BundleFactory bundleFactory,
+      JobBundleFactory jobBundleFactory,
+      StepStateAndTimers.Provider stepStateAndTimers) {
+    this.bundleFactory = bundleFactory;
+    this.jobBundleFactory = jobBundleFactory;
+    this.stp = stepStateAndTimers;
+  }
+
+  @Nullable
+  @Override
+  public <InputT> TransformEvaluator<InputT> forApplication(
+      PTransformNode application, CommittedBundle<?> inputBundle) throws Exception {
+    return new SplittableRemoteStageEvaluator(
+        bundleFactory,
+        jobBundleFactory,
+        stp.forStepAndKey(application, inputBundle.getKey()),
+        application);
+  }
+
+  @Override
+  public void cleanup() throws Exception {
+    jobBundleFactory.close();
+  }
+
+  private static class SplittableRemoteStageEvaluator<InputT, RestrictionT>
+      implements TransformEvaluator<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> {
+    private final PTransformNode transform;
+    private final ExecutableStage stage;
+
+    private final CopyOnAccessInMemoryStateInternals<byte[]> stateInternals;
+    private final DirectTimerInternals timerInternals;
+    private final RemoteBundle<KV<InputT, RestrictionT>> bundle;
+    private final Collection<UncommittedBundle<?>> outputs;
+
+    private final SDFFeederViaStateAndTimers<InputT, RestrictionT> feeder;
+
+    private SplittableRemoteStageEvaluator(
+        BundleFactory bundleFactory,
+        JobBundleFactory jobBundleFactory,
+        StepStateAndTimers<byte[]> stp,
+        PTransformNode transform)
+        throws Exception {
+      this.stateInternals = stp.stateInternals();
+      this.timerInternals = stp.timerInternals();
+      this.transform = transform;
+      this.stage =
+          ExecutableStage.fromPayload(
+              ExecutableStagePayload.parseFrom(transform.getTransform().getSpec().getPayload()));
+      this.outputs = new ArrayList<>();
+      this.bundle =
+          jobBundleFactory
+              .<KV<InputT, RestrictionT>>forStage(stage)
+              .getBundle(
+                  BundleFactoryOutputReceiverFactory.create(
+                      bundleFactory, stage.getComponents(), outputs::add),
+                  StateRequestHandler.unsupported(),
+                  new BundleProgressHandler() {
+                    @Override
+                    public void onProgress(ProcessBundleProgressResponse progress) {
+                      if (progress.hasSplit()) {
+                        feeder.split(progress.getSplit());
+                      }
+                    }
+
+                    @Override
+                    public void onCompleted(ProcessBundleResponse response) {
+                      if (response.hasSplit()) {
+                        feeder.split(response.getSplit());
+                      }
+                    }
+                  });
+
+      FullWindowedValueCoder<KV<InputT, RestrictionT>> windowedValueCoder =
+          (FullWindowedValueCoder<KV<InputT, RestrictionT>>)
+              WireCoders.<KV<InputT, RestrictionT>>instantiateRunnerWireCoder(
+                  stage.getInputPCollection(), stage.getComponents());
+      KvCoder<InputT, RestrictionT> kvCoder =
+          ((KvCoder<InputT, RestrictionT>) windowedValueCoder.getValueCoder());
+
+      this.feeder =
+          new SDFFeederViaStateAndTimers<>(
+              stateInternals,
+              timerInternals,
+              kvCoder.getKeyCoder(),
+              kvCoder.getValueCoder(),
+              (Coder<BoundedWindow>) windowedValueCoder.getWindowCoder());
+    }
+
+    @Override
+    public void processElement(
+        WindowedValue<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> windowedWorkItem)
+        throws Exception {
+      KeyedWorkItem<byte[], KV<InputT, RestrictionT>> kwi = windowedWorkItem.getValue();
+      WindowedValue<KV<InputT, RestrictionT>> elementRestriction =
+          Iterables.getOnlyElement(kwi.elementsIterable(), null);
+      if (elementRestriction != null) {
+        feeder.seed(elementRestriction);
+      } else {
+        elementRestriction = feeder.resume(Iterables.getOnlyElement(kwi.timersIterable()));
+      }
+      bundle.getInputReceiver().accept(elementRestriction);
+    }
+
+    @Override
+    public TransformResult<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> finishBundle()
+        throws Exception {
+      bundle.close();
+      feeder.commit();
+      CopyOnAccessInMemoryStateInternals<byte[]> state = stateInternals.commit();
+      StepTransformResult.Builder<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> result =
+          StepTransformResult.withHold(transform, state.getEarliestWatermarkHold());
+      return result
+          .addOutput(outputs)
+          .withState(state)
+          .withTimerUpdate(timerInternals.getTimerUpdate())
+          .build();
+    }
+  }
+}
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/TransformEvaluatorRegistry.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/TransformEvaluatorRegistry.java
index a27bc07..4a502c7 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/TransformEvaluatorRegistry.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/TransformEvaluatorRegistry.java
@@ -67,6 +67,10 @@
             .put(
                 ExecutableStage.URN,
                 new RemoteStageEvaluatorFactory(bundleFactory, jobBundleFactory))
+            .put(
+                SplittableRemoteStageEvaluatorFactory.URN,
+                new SplittableRemoteStageEvaluatorFactory(
+                    bundleFactory, jobBundleFactory, stepStateAndTimers))
             .build());
   }
 
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/portable/ReferenceRunnerTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/portable/ReferenceRunnerTest.java
index 2599a89..6855fca 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/portable/ReferenceRunnerTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/portable/ReferenceRunnerTest.java
@@ -18,6 +18,8 @@
 
 package org.apache.beam.runners.direct.portable;
 
+import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.resume;
+import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.stop;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertThat;
@@ -25,19 +27,28 @@
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Iterables;
 import java.io.Serializable;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.Set;
 import org.apache.beam.runners.core.construction.JavaReadViaImpulse;
+import org.apache.beam.runners.core.construction.PTransformMatchers;
 import org.apache.beam.runners.core.construction.PipelineOptionsTranslation;
 import org.apache.beam.runners.core.construction.PipelineTranslation;
+import org.apache.beam.runners.direct.ParDoMultiOverrideFactory;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.io.range.OffsetRange;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.runners.PTransformOverride;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Reshuffle;
+import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.values.KV;
@@ -65,12 +76,12 @@
                 ParDo.of(
                         new DoFn<Integer, KV<String, Integer>>() {
                           @ProcessElement
-                          public void process(@Element Integer e,
-                                              MultiOutputReceiver r) {
+                          public void process(@Element Integer e, MultiOutputReceiver r) {
                             for (int i = 0; i < e; i++) {
-                              r.get(food).outputWithTimestamp(
-                                  KV.of("foo", e),
-                                  new Instant(0).plus(Duration.standardHours(i)));
+                              r.get(food)
+                                  .outputWithTimestamp(
+                                      KV.of("foo", e),
+                                      new Instant(0).plus(Duration.standardHours(i)));
                             }
                             r.get(originals).output(e);
                           }
@@ -86,10 +97,10 @@
                 ParDo.of(
                     new DoFn<KV<String, Iterable<Integer>>, KV<String, Set<Integer>>>() {
                       @ProcessElement
-                      public void process(@Element KV<String, Iterable<Integer>> e,
-                                          OutputReceiver<KV<String, Set<Integer>>> r) {
-                        r.output(
-                            KV.of(e.getKey(), ImmutableSet.copyOf(e.getValue())));
+                      public void process(
+                          @Element KV<String, Iterable<Integer>> e,
+                          OutputReceiver<KV<String, Set<Integer>>> r) {
+                        r.output(KV.of(e.getKey(), ImmutableSet.copyOf(e.getValue())));
                       }
                     }));
 
@@ -136,4 +147,66 @@
             PipelineOptionsTranslation.toProto(PipelineOptionsFactory.create()));
     runner.execute();
   }
+
+  static class PairStringWithIndexToLength extends DoFn<String, KV<String, Integer>> {
+    @ProcessElement
+    public ProcessContinuation process(ProcessContext c, OffsetRangeTracker tracker) {
+      for (long i = tracker.currentRestriction().getFrom(), numIterations = 0;
+          tracker.tryClaim(i);
+          ++i, ++numIterations) {
+        c.output(KV.of(c.element(), (int) i));
+        if (numIterations % 3 == 0) {
+          return resume();
+        }
+      }
+      return stop();
+    }
+
+    @GetInitialRestriction
+    public OffsetRange getInitialRange(String element) {
+      return new OffsetRange(0, element.length());
+    }
+
+    @SplitRestriction
+    public void splitRange(
+        String element, OffsetRange range, OutputReceiver<OffsetRange> receiver) {
+      long middle = (range.getFrom() + range.getTo()) / 2;
+      receiver.output(new OffsetRange(range.getFrom(), middle));
+      receiver.output(new OffsetRange(middle, range.getTo()));
+    }
+  }
+
+  @Test
+  public void testSDF() throws Exception {
+    Pipeline p = Pipeline.create();
+
+    PCollection<KV<String, Integer>> res =
+        p.apply(Create.of("a", "bb", "ccccc"))
+            .apply(ParDo.of(new PairStringWithIndexToLength()))
+            .setCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()));
+
+    PAssert.that(res)
+        .containsInAnyOrder(
+            Arrays.asList(
+                KV.of("a", 0),
+                KV.of("bb", 0),
+                KV.of("bb", 1),
+                KV.of("ccccc", 0),
+                KV.of("ccccc", 1),
+                KV.of("ccccc", 2),
+                KV.of("ccccc", 3),
+                KV.of("ccccc", 4)));
+
+    p.replaceAll(
+        Arrays.asList(
+            JavaReadViaImpulse.boundedOverride(),
+            PTransformOverride.of(
+                PTransformMatchers.splittableParDo(), new ParDoMultiOverrideFactory())));
+
+    ReferenceRunner runner =
+        ReferenceRunner.forInProcessPipeline(
+            PipelineTranslation.toProto(p),
+            PipelineOptionsTranslation.toProto(PipelineOptionsFactory.create()));
+    runner.execute();
+  }
 }
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java
index df008a8..be9abc3 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java
@@ -60,14 +60,14 @@
  */
 public class SplittableDoFnOperator<
         InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker<RestrictionT, ?>>
-    extends DoFnOperator<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> {
+    extends DoFnOperator<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> {
 
   private transient ScheduledExecutorService executorService;
 
   public SplittableDoFnOperator(
-      DoFn<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> doFn,
+      DoFn<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> doFn,
       String stepName,
-      Coder<WindowedValue<KeyedWorkItem<String, KV<InputT, RestrictionT>>>> inputCoder,
+      Coder<WindowedValue<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>>> inputCoder,
       TupleTag<OutputT> mainOutputTag,
       List<TupleTag<?>> additionalOutputTags,
       OutputManagerFactory<OutputT> outputManagerFactory,
@@ -76,7 +76,7 @@
       Collection<PCollectionView<?>> sideInputs,
       PipelineOptions options,
       Coder<?> keyCoder,
-      KeySelector<WindowedValue<KeyedWorkItem<String, KV<InputT, RestrictionT>>>, ?> keySelector) {
+      KeySelector<WindowedValue<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>>, ?> keySelector) {
     super(
         doFn,
         stepName,
@@ -94,8 +94,8 @@
 
   @Override
   protected DoFnRunner<
-      KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> createWrappingDoFnRunner(
-          DoFnRunner<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> wrappedRunner) {
+      KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> createWrappingDoFnRunner(
+          DoFnRunner<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> wrappedRunner) {
     // don't wrap in anything because we don't need state cleanup because ProcessFn does
     // all that
     return wrappedRunner;
@@ -162,7 +162,7 @@
     doFnRunner.processElement(
         WindowedValue.valueInGlobalWindow(
             KeyedWorkItems.timersWorkItem(
-                (String) keyedStateInternals.getKey(),
+                (byte[]) keyedStateInternals.getKey(),
                 Collections.singletonList(timer.getNamespace()))));
   }
 
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/splittabledofn/SDFFeederViaStateAndTimers.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/splittabledofn/SDFFeederViaStateAndTimers.java
new file mode 100644
index 0000000..9af2605
--- /dev/null
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/splittabledofn/SDFFeederViaStateAndTimers.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.splittabledofn;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.common.collect.Iterables;
+import com.google.protobuf.ByteString;
+import java.io.IOException;
+import java.util.List;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.DelayedApplication;
+import org.apache.beam.runners.core.StateInternals;
+import org.apache.beam.runners.core.StateNamespace;
+import org.apache.beam.runners.core.StateNamespaces;
+import org.apache.beam.runners.core.StateTag;
+import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.runners.core.TimerInternals;
+import org.apache.beam.runners.core.TimerInternals.TimerData;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.state.TimeDomain;
+import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.state.WatermarkHoldState;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.util.CoderUtils;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
+import org.apache.beam.sdk.values.KV;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+
+/**
+ * Helper class for feeding element/restricton pairs into a {@link
+ * PTransformTranslation#SPLITTABLE_PROCESS_ELEMENTS_URN} transform, implementing checkpointing
+ * only, by using state and timers for storing the last element/restriction pair, similarly to
+ * {@link org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems.ProcessFn} but in a portable
+ * fashion.
+ */
+public class SDFFeederViaStateAndTimers<InputT, RestrictionT> {
+  private final Coder<BoundedWindow> windowCoder;
+  private final Coder<WindowedValue<KV<InputT, RestrictionT>>> elementRestrictionWireCoder;
+
+  private final StateInternals stateInternals;
+  private final TimerInternals timerInternals;
+
+  private StateNamespace stateNamespace;
+
+  private final StateTag<ValueState<WindowedValue<KV<InputT, RestrictionT>>>> seedTag;
+  private ValueState<WindowedValue<KV<InputT, RestrictionT>>> seedState;
+
+  private final StateTag<ValueState<RestrictionT>> restrictionTag;
+  private ValueState<RestrictionT> restrictionState;
+
+  private StateTag<WatermarkHoldState> watermarkHoldTag =
+      StateTags.makeSystemTagInternal(
+          StateTags.<GlobalWindow>watermarkStateInternal("hold", TimestampCombiner.LATEST));
+  private WatermarkHoldState holdState;
+
+  private Instant inputTimestamp;
+  private BundleSplit split;
+
+  /** Initializes the feeder. */
+  public SDFFeederViaStateAndTimers(
+      StateInternals stateInternals,
+      TimerInternals timerInternals,
+      Coder<InputT> elementWireCoder,
+      Coder<RestrictionT> restrictionWireCoder,
+      Coder<BoundedWindow> windowCoder) {
+    this.stateInternals = stateInternals;
+    this.timerInternals = timerInternals;
+    this.windowCoder = windowCoder;
+    this.elementRestrictionWireCoder =
+        FullWindowedValueCoder.of(KvCoder.of(elementWireCoder, restrictionWireCoder), windowCoder);
+    this.seedTag = StateTags.value("seed", elementRestrictionWireCoder);
+    this.restrictionTag = StateTags.value("restriction", restrictionWireCoder);
+  }
+
+  /** Passes the initial element/restriction pair. */
+  public void seed(WindowedValue<KV<InputT, RestrictionT>> elementRestriction) {
+    initState(
+        StateNamespaces.window(
+            windowCoder, Iterables.getOnlyElement(elementRestriction.getWindows())));
+    seedState.write(elementRestriction);
+    inputTimestamp = elementRestriction.getTimestamp();
+  }
+
+  /**
+   * Resumes from a timer and returns the current element/restriction pair (with an up-to-date value
+   * of the restriction).
+   */
+  public WindowedValue<KV<InputT, RestrictionT>> resume(TimerData timer) {
+    initState(timer.getNamespace());
+    WindowedValue<KV<InputT, RestrictionT>> seed = seedState.read();
+    inputTimestamp = seed.getTimestamp();
+    return seed.withValue(KV.of(seed.getValue().getKey(), restrictionState.read()));
+  }
+
+  /**
+   * Commits the state and timers: clears both if no checkpoint happened, or adjusts the restriction
+   * and sets a wake-up timer if a checkpoint happened.
+   */
+  public void commit() throws IOException {
+    if (split == null) {
+      // No split - the call terminated.
+      seedState.clear();
+      restrictionState.clear();
+      holdState.clear();
+      return;
+    }
+
+    // For now can only happen on the first instruction which is SPLITTABLE_PROCESS_ELEMENTS.
+    List<DelayedApplication> residuals = split.getResidualRootsList();
+    checkArgument(residuals.size() == 1, "More than 1 residual is unsupported for now");
+    DelayedApplication residual = residuals.get(0);
+
+    ByteString encodedResidual = residual.getApplication().getElement();
+    WindowedValue<KV<InputT, RestrictionT>> decodedResidual =
+        CoderUtils.decodeFromByteArray(elementRestrictionWireCoder, encodedResidual.toByteArray());
+
+    restrictionState.write(decodedResidual.getValue().getValue());
+
+    Instant watermarkHold =
+        residual.getApplication().getOutputWatermarksMap().isEmpty()
+            ? inputTimestamp
+            : new Instant(
+                Iterables.getOnlyElement(
+                    residual.getApplication().getOutputWatermarksMap().values()));
+    checkArgument(
+        !watermarkHold.isBefore(inputTimestamp),
+        "Watermark hold %s can not be before input timestamp %s",
+        watermarkHold,
+        inputTimestamp);
+    holdState.add(watermarkHold);
+
+    Duration resumeDelay = new Duration((long) (1000L * residual.getDelaySec()));
+    Instant wakeupTime = timerInternals.currentProcessingTime().plus(resumeDelay);
+
+    // Set a timer to continue processing this element.
+    timerInternals.setTimer(
+        stateNamespace, "sdfContinuation", wakeupTime, TimeDomain.PROCESSING_TIME);
+  }
+
+  /** Signals that a split happened. */
+  public void split(BundleSplit split) {
+    checkState(
+        this.split == null,
+        "At most 1 split supported, however got new split %s in addition to existing %s",
+        split,
+        this.split);
+    this.split = split;
+  }
+
+  private void initState(StateNamespace ns) {
+    stateNamespace = ns;
+    seedState = stateInternals.state(ns, seedTag);
+    restrictionState = stateInternals.state(ns, restrictionTag);
+    holdState = stateInternals.state(ns, watermarkHoldTag);
+  }
+}
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/splittabledofn/package-info.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/splittabledofn/package-info.java
new file mode 100644
index 0000000..e50f14d
--- /dev/null
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/splittabledofn/package-info.java
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Utilities for a Beam runner to interact with a remotely running splittable DoFn. */
+package org.apache.beam.runners.fnexecution.splittabledofn;
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
index db25581..b2f8195 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
@@ -28,6 +28,7 @@
 import java.util.Map;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
 import org.apache.beam.fn.harness.data.BeamFnDataClient;
 import org.apache.beam.fn.harness.data.MultiplexingFnDataReceiver;
 import org.apache.beam.fn.harness.state.BeamFnStateClient;
@@ -93,7 +94,8 @@
         Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
         Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
         Consumer<ThrowingRunnable> addStartFunction,
-        Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
+        Consumer<ThrowingRunnable> addFinishFunction,
+        BundleSplitListener splitListener) throws IOException {
 
       BeamFnApi.Target target = BeamFnApi.Target.newBuilder()
           .setPrimitiveTransformReference(pTransformId)
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java
index 34d704d..a4f6e54 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java
@@ -27,6 +27,7 @@
 import java.util.Map;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
 import org.apache.beam.fn.harness.data.BeamFnDataClient;
 import org.apache.beam.fn.harness.state.BeamFnStateClient;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi;
@@ -84,7 +85,8 @@
         Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
         Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
         Consumer<ThrowingRunnable> addStartFunction,
-        Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
+        Consumer<ThrowingRunnable> addFinishFunction,
+        BundleSplitListener splitListener) throws IOException {
       BeamFnApi.Target target = BeamFnApi.Target.newBuilder()
           .setPrimitiveTransformReference(pTransformId)
           .setName(getOnlyElement(pTransform.getInputsMap().keySet()))
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java
index 5df178f..185fc11 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java
@@ -28,6 +28,7 @@
 import java.util.Map;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
 import org.apache.beam.fn.harness.control.ProcessBundleHandler;
 import org.apache.beam.fn.harness.data.BeamFnDataClient;
 import org.apache.beam.fn.harness.state.BeamFnStateClient;
@@ -81,7 +82,8 @@
         Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
         Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
         Consumer<ThrowingRunnable> addStartFunction,
-        Consumer<ThrowingRunnable> addFinishFunction) {
+        Consumer<ThrowingRunnable> addFinishFunction,
+        BundleSplitListener splitListener) {
 
       ImmutableList.Builder<FnDataReceiver<WindowedValue<?>>> consumers = ImmutableList.builder();
       for (String pCollectionId : pTransform.getOutputsMap().values()) {
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/DoFnPTransformRunnerFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/DoFnPTransformRunnerFactory.java
new file mode 100644
index 0000000..9cbe34c
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/DoFnPTransformRunnerFactory.java
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import com.google.common.collect.ImmutableListMultimap;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.ListMultimap;
+import com.google.common.collect.Multimap;
+import com.google.common.collect.Sets;
+import java.io.IOException;
+import java.util.Map;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
+import org.apache.beam.fn.harness.data.BeamFnDataClient;
+import org.apache.beam.fn.harness.state.BeamFnStateClient;
+import org.apache.beam.fn.harness.state.SideInputSpec;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
+import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
+import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
+import org.apache.beam.runners.core.construction.PCollectionViewTranslation;
+import org.apache.beam.runners.core.construction.ParDoTranslation;
+import org.apache.beam.runners.core.construction.RehydratedComponents;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
+import org.apache.beam.sdk.fn.function.ThrowingRunnable;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.Materializations;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.WindowingStrategy;
+
+/** A {@link PTransformRunnerFactory} for transforms invoking a {@link DoFn}. */
+abstract class DoFnPTransformRunnerFactory<
+        TransformInputT,
+        FnInputT,
+        OutputT,
+        RunnerT extends DoFnPTransformRunnerFactory.DoFnPTransformRunner<TransformInputT>>
+    implements PTransformRunnerFactory<RunnerT> {
+  interface DoFnPTransformRunner<T> {
+    void startBundle() throws Exception;
+
+    void processElement(WindowedValue<T> input) throws Exception;
+
+    void finishBundle() throws Exception;
+  }
+
+  @Override
+  public final RunnerT createRunnerForPTransform(
+      PipelineOptions pipelineOptions,
+      BeamFnDataClient beamFnDataClient,
+      BeamFnStateClient beamFnStateClient,
+      String ptransformId,
+      PTransform pTransform,
+      Supplier<String> processBundleInstructionId,
+      Map<String, PCollection> pCollections,
+      Map<String, RunnerApi.Coder> coders,
+      Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
+      Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
+      Consumer<ThrowingRunnable> addStartFunction,
+      Consumer<ThrowingRunnable> addFinishFunction,
+      BundleSplitListener splitListener) {
+    Context<FnInputT, OutputT> context =
+        new Context<>(
+            pipelineOptions,
+            beamFnStateClient,
+            ptransformId,
+            pTransform,
+            processBundleInstructionId,
+            pCollections,
+            coders,
+            windowingStrategies,
+            pCollectionIdsToConsumers,
+            splitListener);
+
+    RunnerT runner = createRunner(context);
+
+    // Register the appropriate handlers.
+    addStartFunction.accept(runner::startBundle);
+    Iterable<String> mainInput =
+        Sets.difference(
+            pTransform.getInputsMap().keySet(), context.parDoPayload.getSideInputsMap().keySet());
+    for (String localInputName : mainInput) {
+      pCollectionIdsToConsumers.put(
+          pTransform.getInputsOrThrow(localInputName),
+          (FnDataReceiver) (FnDataReceiver<WindowedValue<TransformInputT>>) runner::processElement);
+    }
+    addFinishFunction.accept(runner::finishBundle);
+    return runner;
+  }
+
+  abstract RunnerT createRunner(Context<FnInputT, OutputT> context);
+
+  static class Context<InputT, OutputT> {
+    final PipelineOptions pipelineOptions;
+    final BeamFnStateClient beamFnStateClient;
+    final String ptransformId;
+    final PTransform pTransform;
+    final Supplier<String> processBundleInstructionId;
+    final RehydratedComponents rehydratedComponents;
+    final DoFn<InputT, OutputT> doFn;
+    final TupleTag<OutputT> mainOutputTag;
+    final Coder<?> inputCoder;
+    final Coder<?> keyCoder;
+    final Coder<? extends BoundedWindow> windowCoder;
+    final WindowingStrategy<InputT, ?> windowingStrategy;
+    final Map<TupleTag<?>, SideInputSpec> tagToSideInputSpecMap;
+    final ParDoPayload parDoPayload;
+    final ListMultimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> tagToConsumer;
+    final BundleSplitListener splitListener;
+
+    Context(
+        PipelineOptions pipelineOptions,
+        BeamFnStateClient beamFnStateClient,
+        String ptransformId,
+        PTransform pTransform,
+        Supplier<String> processBundleInstructionId,
+        Map<String, PCollection> pCollections,
+        Map<String, RunnerApi.Coder> coders,
+        Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
+        Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
+        BundleSplitListener splitListener) {
+      this.pipelineOptions = pipelineOptions;
+      this.beamFnStateClient = beamFnStateClient;
+      this.ptransformId = ptransformId;
+      this.pTransform = pTransform;
+      this.processBundleInstructionId = processBundleInstructionId;
+      ImmutableMap.Builder<TupleTag<?>, SideInputSpec> tagToSideInputSpecMapBuilder =
+          ImmutableMap.builder();
+      try {
+        rehydratedComponents =
+            RehydratedComponents.forComponents(
+                RunnerApi.Components.newBuilder()
+                    .putAllCoders(coders)
+                    .putAllWindowingStrategies(windowingStrategies)
+                    .build());
+        parDoPayload = ParDoPayload.parseFrom(pTransform.getSpec().getPayload());
+        doFn = (DoFn) ParDoTranslation.getDoFn(parDoPayload);
+        mainOutputTag = (TupleTag) ParDoTranslation.getMainOutputTag(parDoPayload);
+        String mainInputTag =
+            Iterables.getOnlyElement(
+                Sets.difference(
+                    pTransform.getInputsMap().keySet(), parDoPayload.getSideInputsMap().keySet()));
+        PCollection mainInput = pCollections.get(pTransform.getInputsOrThrow(mainInputTag));
+        inputCoder = rehydratedComponents.getCoder(mainInput.getCoderId());
+        if (inputCoder instanceof KvCoder
+            // TODO: Stop passing windowed value coders within PCollections.
+            || (inputCoder instanceof WindowedValue.WindowedValueCoder
+                && (((WindowedValueCoder) inputCoder).getValueCoder() instanceof KvCoder))) {
+          this.keyCoder =
+              inputCoder instanceof WindowedValueCoder
+                  ? ((KvCoder) ((WindowedValueCoder) inputCoder).getValueCoder()).getKeyCoder()
+                  : ((KvCoder) inputCoder).getKeyCoder();
+        } else {
+          this.keyCoder = null;
+        }
+
+        windowingStrategy =
+            (WindowingStrategy)
+                rehydratedComponents.getWindowingStrategy(mainInput.getWindowingStrategyId());
+        windowCoder = windowingStrategy.getWindowFn().windowCoder();
+
+        // Build the map from tag id to side input specification
+        for (Map.Entry<String, RunnerApi.SideInput> entry :
+            parDoPayload.getSideInputsMap().entrySet()) {
+          String sideInputTag = entry.getKey();
+          RunnerApi.SideInput sideInput = entry.getValue();
+          checkArgument(
+              Materializations.MULTIMAP_MATERIALIZATION_URN.equals(
+                  sideInput.getAccessPattern().getUrn()),
+              "This SDK is only capable of dealing with %s materializations "
+                  + "but was asked to handle %s for PCollectionView with tag %s.",
+              Materializations.MULTIMAP_MATERIALIZATION_URN,
+              sideInput.getAccessPattern().getUrn(),
+              sideInputTag);
+
+          PCollection sideInputPCollection =
+              pCollections.get(pTransform.getInputsOrThrow(sideInputTag));
+          WindowingStrategy sideInputWindowingStrategy =
+              rehydratedComponents.getWindowingStrategy(
+                  sideInputPCollection.getWindowingStrategyId());
+          tagToSideInputSpecMapBuilder.put(
+              new TupleTag<>(entry.getKey()),
+              SideInputSpec.create(
+                  rehydratedComponents.getCoder(sideInputPCollection.getCoderId()),
+                  sideInputWindowingStrategy.getWindowFn().windowCoder(),
+                  PCollectionViewTranslation.viewFnFromProto(entry.getValue().getViewFn()),
+                  PCollectionViewTranslation.windowMappingFnFromProto(
+                      entry.getValue().getWindowMappingFn())));
+        }
+      } catch (IOException exn) {
+        throw new IllegalArgumentException("Malformed ParDoPayload", exn);
+      }
+
+      ImmutableListMultimap.Builder<TupleTag<?>, FnDataReceiver<WindowedValue<?>>>
+          tagToConsumerBuilder = ImmutableListMultimap.builder();
+      for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) {
+        tagToConsumerBuilder.putAll(
+            new TupleTag<>(entry.getKey()), pCollectionIdsToConsumers.get(entry.getValue()));
+      }
+      tagToConsumer = tagToConsumerBuilder.build();
+      tagToSideInputSpecMap = tagToSideInputSpecMapBuilder.build();
+      this.splitListener = splitListener;
+    }
+  }
+}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java
index 52d9435..6ac860e 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java
@@ -27,6 +27,7 @@
 import java.util.Map;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
 import org.apache.beam.fn.harness.data.BeamFnDataClient;
 import org.apache.beam.fn.harness.data.MultiplexingFnDataReceiver;
 import org.apache.beam.fn.harness.state.BeamFnStateClient;
@@ -69,7 +70,8 @@
         Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
         Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
         Consumer<ThrowingRunnable> addStartFunction,
-        Consumer<ThrowingRunnable> addFinishFunction)
+        Consumer<ThrowingRunnable> addFinishFunction,
+        BundleSplitListener splitListener)
         throws IOException {
 
       // Give each input a MultiplexingFnDataReceiver to all outputs of the flatten.
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
index 9778aac..4cfb5d9 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
@@ -17,75 +17,28 @@
  */
 package org.apache.beam.fn.harness;
 
-import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Preconditions.checkNotNull;
-import static com.google.common.base.Preconditions.checkState;
 
 import com.google.auto.service.AutoService;
-import com.google.auto.value.AutoValue;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableListMultimap;
 import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.ImmutableSet;
-import com.google.common.collect.Iterables;
-import com.google.common.collect.ListMultimap;
-import com.google.common.collect.Multimap;
-import com.google.common.collect.Sets;
-import com.google.protobuf.ByteString;
-import com.google.protobuf.InvalidProtocolBufferException;
-import java.io.IOException;
-import java.util.ArrayList;
 import java.util.Collection;
-import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Map;
-import java.util.Set;
-import java.util.function.Consumer;
-import java.util.function.Function;
-import java.util.function.Supplier;
-import org.apache.beam.fn.harness.data.BeamFnDataClient;
-import org.apache.beam.fn.harness.state.BagUserState;
-import org.apache.beam.fn.harness.state.BeamFnStateClient;
-import org.apache.beam.fn.harness.state.MultimapSideInput;
-import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
-import org.apache.beam.model.pipeline.v1.RunnerApi;
-import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
-import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
-import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
+import org.apache.beam.fn.harness.DoFnPTransformRunnerFactory.Context;
+import org.apache.beam.fn.harness.state.FnApiStateAccessor;
 import org.apache.beam.runners.core.DoFnRunner;
-import org.apache.beam.runners.core.construction.PCollectionViewTranslation;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
-import org.apache.beam.runners.core.construction.ParDoTranslation;
-import org.apache.beam.runners.core.construction.RehydratedComponents;
 import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
-import org.apache.beam.sdk.fn.function.ThrowingRunnable;
 import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.state.BagState;
-import org.apache.beam.sdk.state.CombiningState;
-import org.apache.beam.sdk.state.MapState;
-import org.apache.beam.sdk.state.ReadableState;
-import org.apache.beam.sdk.state.ReadableStates;
-import org.apache.beam.sdk.state.SetState;
 import org.apache.beam.sdk.state.State;
-import org.apache.beam.sdk.state.StateBinder;
-import org.apache.beam.sdk.state.StateContext;
 import org.apache.beam.sdk.state.StateSpec;
 import org.apache.beam.sdk.state.TimeDomain;
 import org.apache.beam.sdk.state.Timer;
-import org.apache.beam.sdk.state.ValueState;
-import org.apache.beam.sdk.state.WatermarkHoldState;
-import org.apache.beam.sdk.transforms.Combine.CombineFn;
-import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
 import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.DoFn.FinishBundleContext;
 import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver;
 import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
-import org.apache.beam.sdk.transforms.DoFn.StartBundleContext;
 import org.apache.beam.sdk.transforms.DoFnOutputReceivers;
-import org.apache.beam.sdk.transforms.Materializations;
-import org.apache.beam.sdk.transforms.ViewFn;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
@@ -94,399 +47,145 @@
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
-import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
-import org.apache.beam.sdk.transforms.windowing.WindowMappingFn;
-import org.apache.beam.sdk.util.CombineFnUtil;
-import org.apache.beam.sdk.util.DoFnInfo;
-import org.apache.beam.sdk.util.SerializableUtils;
 import org.apache.beam.sdk.util.UserCodeException;
 import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
-import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
-import org.apache.beam.sdk.values.WindowingStrategy;
 import org.joda.time.Instant;
 
 /**
- * A {@link DoFnRunner} specific to integrating with the Fn Api. This is to remove the layers
- * of abstraction caused by StateInternals/TimerInternals since they model state and timer
- * concepts differently.
+ * A {@link DoFnRunner} specific to integrating with the Fn Api. This is to remove the layers of
+ * abstraction caused by StateInternals/TimerInternals since they model state and timer concepts
+ * differently.
  */
-public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, OutputT> {
-
-  private final ProcessBundleContext processContext;
-  private final FinishBundleContext finishBundleContext;
-  private StartBundleContext startBundleContext;
-
-  /**
-   * A registrar which provides a factory to handle Java {@link DoFn}s.
-   */
+public class FnApiDoFnRunner<InputT, OutputT>
+    implements DoFnPTransformRunnerFactory.DoFnPTransformRunner<InputT> {
+  /** A registrar which provides a factory to handle Java {@link DoFn}s. */
   @AutoService(PTransformRunnerFactory.Registrar.class)
-  public static class Registrar implements
-      PTransformRunnerFactory.Registrar {
-
+  public static class Registrar implements PTransformRunnerFactory.Registrar {
     @Override
     public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() {
-      return ImmutableMap.of(
-          PTransformTranslation.PAR_DO_TRANSFORM_URN, new NewFactory(),
-          ParDoTranslation.CUSTOM_JAVA_DO_FN_URN, new Factory());
+      return ImmutableMap.of(PTransformTranslation.PAR_DO_TRANSFORM_URN, new Factory());
     }
   }
 
-  /** A factory for {@link FnApiDoFnRunner}. */
   static class Factory<InputT, OutputT>
-      implements PTransformRunnerFactory<DoFnRunner<InputT, OutputT>> {
-
+      extends DoFnPTransformRunnerFactory<
+          InputT, InputT, OutputT, FnApiDoFnRunner<InputT, OutputT>> {
     @Override
-    public DoFnRunner<InputT, OutputT> createRunnerForPTransform(
-        PipelineOptions pipelineOptions,
-        BeamFnDataClient beamFnDataClient,
-        BeamFnStateClient beamFnStateClient,
-        String ptransformId,
-        PTransform pTransform,
-        Supplier<String> processBundleInstructionId,
-        Map<String, PCollection> pCollections,
-        Map<String, RunnerApi.Coder> coders,
-        Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
-        Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
-        Consumer<ThrowingRunnable> addStartFunction,
-        Consumer<ThrowingRunnable> addFinishFunction) {
-
-      // For every output PCollection, create a map from output name to Consumer
-      ImmutableListMultimap.Builder<TupleTag<?>, FnDataReceiver<WindowedValue<?>>>
-          tagToOutputMapBuilder = ImmutableListMultimap.builder();
-      for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) {
-        tagToOutputMapBuilder.putAll(
-            new TupleTag<>(entry.getKey()),
-            pCollectionIdsToConsumers.get(entry.getValue()));
-      }
-      ListMultimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> tagToOutputMap =
-          tagToOutputMapBuilder.build();
-
-      // Get the DoFnInfo from the serialized blob.
-      ByteString serializedFn = pTransform.getSpec().getPayload();
-      @SuppressWarnings({"unchecked", "rawtypes"})
-      DoFnInfo<InputT, OutputT> doFnInfo = (DoFnInfo) SerializableUtils.deserializeFromByteArray(
-          serializedFn.toByteArray(), "DoFnInfo");
-
-      @SuppressWarnings({"unchecked", "rawtypes"})
-      DoFnRunner<InputT, OutputT> runner =
-          new FnApiDoFnRunner<>(
-              pipelineOptions,
-              beamFnStateClient,
-              ptransformId,
-              processBundleInstructionId,
-              doFnInfo.getDoFn(),
-              doFnInfo.getInputCoder(),
-              (Collection<FnDataReceiver<WindowedValue<OutputT>>>)
-                  (Collection) tagToOutputMap.get(doFnInfo.getMainOutput()),
-              tagToOutputMap,
-              ImmutableMap.of(),
-              doFnInfo.getWindowingStrategy());
-
-      registerHandlers(
-          runner,
-          pTransform,
-          ImmutableSet.of(),
-          addStartFunction,
-          addFinishFunction,
-          pCollectionIdsToConsumers);
-      return runner;
+    public FnApiDoFnRunner<InputT, OutputT> createRunner(Context<InputT, OutputT> context) {
+      return new FnApiDoFnRunner<>(context);
     }
   }
 
-  static class NewFactory<InputT, OutputT>
-      implements PTransformRunnerFactory<DoFnRunner<InputT, OutputT>> {
-
-    @Override
-    public DoFnRunner<InputT, OutputT> createRunnerForPTransform(
-        PipelineOptions pipelineOptions,
-        BeamFnDataClient beamFnDataClient,
-        BeamFnStateClient beamFnStateClient,
-        String ptransformId,
-        RunnerApi.PTransform pTransform,
-        Supplier<String> processBundleInstructionId,
-        Map<String, RunnerApi.PCollection> pCollections,
-        Map<String, RunnerApi.Coder> coders,
-        Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
-        Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
-        Consumer<ThrowingRunnable> addStartFunction,
-        Consumer<ThrowingRunnable> addFinishFunction) {
-
-      DoFn<InputT, OutputT> doFn;
-      TupleTag<OutputT> mainOutputTag;
-      Coder<InputT> inputCoder;
-      WindowingStrategy<InputT, ?> windowingStrategy;
-
-      ImmutableMap.Builder<TupleTag<?>, SideInputSpec> tagToSideInputSpecMap =
-          ImmutableMap.builder();
-      ParDoPayload parDoPayload;
-      try {
-        RehydratedComponents rehydratedComponents = RehydratedComponents.forComponents(
-            RunnerApi.Components.newBuilder()
-                .putAllCoders(coders).putAllWindowingStrategies(windowingStrategies).build());
-        parDoPayload = ParDoPayload.parseFrom(pTransform.getSpec().getPayload());
-        doFn = (DoFn) ParDoTranslation.getDoFn(parDoPayload);
-        mainOutputTag = (TupleTag) ParDoTranslation.getMainOutputTag(parDoPayload);
-        String mainInputTag = Iterables.getOnlyElement(Sets.difference(
-            pTransform.getInputsMap().keySet(), parDoPayload.getSideInputsMap().keySet()));
-        RunnerApi.PCollection mainInput =
-            pCollections.get(pTransform.getInputsOrThrow(mainInputTag));
-        inputCoder = (Coder<InputT>) rehydratedComponents.getCoder(
-            mainInput.getCoderId());
-        windowingStrategy = (WindowingStrategy) rehydratedComponents.getWindowingStrategy(
-            mainInput.getWindowingStrategyId());
-
-        // Build the map from tag id to side input specification
-        for (Map.Entry<String, RunnerApi.SideInput> entry
-            : parDoPayload.getSideInputsMap().entrySet()) {
-          String sideInputTag = entry.getKey();
-          RunnerApi.SideInput sideInput = entry.getValue();
-          checkArgument(
-              Materializations.MULTIMAP_MATERIALIZATION_URN.equals(
-                  sideInput.getAccessPattern().getUrn()),
-              "This SDK is only capable of dealing with %s materializations "
-                  + "but was asked to handle %s for PCollectionView with tag %s.",
-              Materializations.MULTIMAP_MATERIALIZATION_URN,
-              sideInput.getAccessPattern().getUrn(),
-              sideInputTag);
-
-          RunnerApi.PCollection sideInputPCollection =
-              pCollections.get(pTransform.getInputsOrThrow(sideInputTag));
-          WindowingStrategy sideInputWindowingStrategy =
-              rehydratedComponents.getWindowingStrategy(
-                  sideInputPCollection.getWindowingStrategyId());
-          tagToSideInputSpecMap.put(
-              new TupleTag<>(entry.getKey()),
-              SideInputSpec.create(
-                  rehydratedComponents.getCoder(sideInputPCollection.getCoderId()),
-                  sideInputWindowingStrategy.getWindowFn().windowCoder(),
-                  PCollectionViewTranslation.viewFnFromProto(entry.getValue().getViewFn()),
-                  PCollectionViewTranslation.windowMappingFnFromProto(
-                      entry.getValue().getWindowMappingFn())));
-        }
-      } catch (InvalidProtocolBufferException exn) {
-        throw new IllegalArgumentException("Malformed ParDoPayload", exn);
-      } catch (IOException exn) {
-        throw new IllegalArgumentException("Malformed ParDoPayload", exn);
-      }
-
-      ImmutableListMultimap.Builder<TupleTag<?>, FnDataReceiver<WindowedValue<?>>>
-          tagToConsumerBuilder = ImmutableListMultimap.builder();
-      for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) {
-        tagToConsumerBuilder.putAll(
-            new TupleTag<>(entry.getKey()), pCollectionIdsToConsumers.get(entry.getValue()));
-      }
-      ListMultimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> tagToConsumer =
-          tagToConsumerBuilder.build();
-
-      @SuppressWarnings({"unchecked", "rawtypes"})
-      DoFnRunner<InputT, OutputT> runner = new FnApiDoFnRunner<>(
-          pipelineOptions,
-          beamFnStateClient,
-          ptransformId,
-          processBundleInstructionId,
-          doFn,
-          inputCoder,
-          (Collection<FnDataReceiver<WindowedValue<OutputT>>>) (Collection)
-              tagToConsumer.get(mainOutputTag),
-          tagToConsumer,
-          tagToSideInputSpecMap.build(),
-          windowingStrategy);
-      registerHandlers(
-          runner,
-          pTransform,
-          parDoPayload.getSideInputsMap().keySet(),
-          addStartFunction,
-          addFinishFunction,
-          pCollectionIdsToConsumers);
-      return runner;
-    }
-  }
-
-  private static <InputT, OutputT> void registerHandlers(
-      DoFnRunner<InputT, OutputT> runner,
-      RunnerApi.PTransform pTransform,
-      Set<String> sideInputLocalNames,
-      Consumer<ThrowingRunnable> addStartFunction,
-      Consumer<ThrowingRunnable> addFinishFunction,
-      Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers) {
-    // Register the appropriate handlers.
-    addStartFunction.accept(runner::startBundle);
-    for (String localInputName
-        : Sets.difference(pTransform.getInputsMap().keySet(), sideInputLocalNames)) {
-      pCollectionIdsToConsumers.put(
-          pTransform.getInputsOrThrow(localInputName),
-          (FnDataReceiver) (FnDataReceiver<WindowedValue<InputT>>) runner::processElement);
-    }
-    addFinishFunction.accept(runner::finishBundle);
-  }
-
   //////////////////////////////////////////////////////////////////////////////////////////////////
 
-  private final PipelineOptions pipelineOptions;
-  private final BeamFnStateClient beamFnStateClient;
-  private final String ptransformId;
-  private final Supplier<String> processBundleInstructionId;
-  private final DoFn<InputT, OutputT> doFn;
-  private final Coder<InputT> inputCoder;
+  private final Context<InputT, OutputT> context;
   private final Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers;
-  private final Multimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> outputMap;
-  private final Map<TupleTag<?>, SideInputSpec> sideInputSpecMap;
-  private final Map<StateKey, Object> stateKeyObjectCache;
-  private final WindowingStrategy windowingStrategy;
+  private FnApiStateAccessor stateAccessor;
   private final DoFnSignature doFnSignature;
   private final DoFnInvoker<InputT, OutputT> doFnInvoker;
-  private final StateBinder stateBinder;
-  private final Collection<ThrowingRunnable> stateFinalizers;
 
-  /**
-   * The lifetime of this member is only valid during {@link #processElement}
-   * and is null otherwise.
-   */
+  private final DoFn<InputT, OutputT>.StartBundleContext startBundleContext;
+  private final ProcessBundleContext processContext;
+  private final DoFn<InputT, OutputT>.FinishBundleContext finishBundleContext;
+
+  /** Only valid during {@link #processElement}, null otherwise. */
   private WindowedValue<InputT> currentElement;
 
-  /**
-   * The lifetime of this member is only valid during {@link #processElement}
-   * and is null otherwise.
-   */
   private BoundedWindow currentWindow;
 
-  /**
-   * The lifetime of this member is only valid during {@link #processElement}
-   * and only when processing a {@link KV} and is null otherwise.
-   */
-  private ByteString encodedCurrentKey;
-
-  /**
-   * The lifetime of this member is only valid during {@link #processElement}
-   * and is null otherwise.
-   */
-  private ByteString encodedCurrentWindow;
-
-  FnApiDoFnRunner(
-      PipelineOptions pipelineOptions,
-      BeamFnStateClient beamFnStateClient,
-      String ptransformId,
-      Supplier<String> processBundleInstructionId,
-      DoFn<InputT, OutputT> doFn,
-      Coder<InputT> inputCoder,
-      Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers,
-      Multimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> outputMap,
-      Map<TupleTag<?>, SideInputSpec> sideInputSpecMap,
-      WindowingStrategy windowingStrategy) {
-    this.pipelineOptions = pipelineOptions;
-    this.beamFnStateClient = beamFnStateClient;
-    this.ptransformId = ptransformId;
-    this.processBundleInstructionId = processBundleInstructionId;
-    this.doFn = doFn;
-    this.inputCoder = inputCoder;
-    this.mainOutputConsumers = mainOutputConsumers;
-    this.outputMap = outputMap;
-    this.sideInputSpecMap = sideInputSpecMap;
-    this.stateKeyObjectCache = new HashMap<>();
-    this.windowingStrategy = windowingStrategy;
-    this.doFnSignature = DoFnSignatures.signatureForDoFn(doFn);
-    this.doFnInvoker = DoFnInvokers.invokerFor(doFn);
+  FnApiDoFnRunner(Context<InputT, OutputT> context) {
+    this.context = context;
+    this.mainOutputConsumers =
+        (Collection<FnDataReceiver<WindowedValue<OutputT>>>)
+            (Collection) context.tagToConsumer.get(context.mainOutputTag);
+    this.doFnSignature = DoFnSignatures.signatureForDoFn(context.doFn);
+    this.doFnInvoker = DoFnInvokers.invokerFor(context.doFn);
     this.doFnInvoker.invokeSetup();
-    this.stateBinder = new BeamFnStateBinder();
-    this.stateFinalizers = new ArrayList<>();
 
-    this.startBundleContext = doFn.new StartBundleContext() {
-      @Override
-      public PipelineOptions getPipelineOptions() {
-        return pipelineOptions;
-      }
-    };
+    this.startBundleContext =
+        this.context.doFn.new StartBundleContext() {
+          @Override
+          public PipelineOptions getPipelineOptions() {
+            return context.pipelineOptions;
+          }
+        };
     this.processContext = new ProcessBundleContext();
-    finishBundleContext = doFn.new FinishBundleContext() {
-      @Override
-      public PipelineOptions getPipelineOptions() {
-        return pipelineOptions;
-      }
+    finishBundleContext =
+        this.context.doFn.new FinishBundleContext() {
+          @Override
+          public PipelineOptions getPipelineOptions() {
+            return context.pipelineOptions;
+          }
 
-      @Override
-      public void output(OutputT output, Instant timestamp, BoundedWindow window) {
-        outputTo(mainOutputConsumers,
-            WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
-      }
+          @Override
+          public void output(OutputT output, Instant timestamp, BoundedWindow window) {
+            outputTo(
+                mainOutputConsumers,
+                WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
+          }
 
-      @Override
-      public <T> void output(TupleTag<T> tag, T output, Instant timestamp, BoundedWindow window) {
-        Collection<FnDataReceiver<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag);
-        if (consumers == null) {
-          throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
-        }
-        outputTo(consumers,
-            WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
-      }
-    };
+          @Override
+          public <T> void output(
+              TupleTag<T> tag, T output, Instant timestamp, BoundedWindow window) {
+            Collection<FnDataReceiver<WindowedValue<T>>> consumers =
+                (Collection) context.tagToConsumer.get(tag);
+            if (consumers == null) {
+              throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
+            }
+            outputTo(consumers, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
+          }
+        };
   }
 
   @Override
   public void startBundle() {
+    this.stateAccessor =
+        new FnApiStateAccessor(
+            context.pipelineOptions,
+            context.ptransformId,
+            context.processBundleInstructionId,
+            context.tagToSideInputSpecMap,
+            context.beamFnStateClient,
+            context.keyCoder,
+            (Coder<BoundedWindow>) context.windowCoder);
+
     doFnInvoker.invokeStartBundle(startBundleContext);
   }
 
   @Override
   public void processElement(WindowedValue<InputT> elem) {
     currentElement = elem;
+    stateAccessor.setCurrentElement(elem);
     try {
       Iterator<BoundedWindow> windowIterator =
           (Iterator<BoundedWindow>) elem.getWindows().iterator();
       while (windowIterator.hasNext()) {
         currentWindow = windowIterator.next();
+        stateAccessor.setCurrentWindow(currentWindow);
         doFnInvoker.invokeProcessElement(processContext);
       }
     } finally {
       currentElement = null;
       currentWindow = null;
-      encodedCurrentKey = null;
-      encodedCurrentWindow = null;
+      stateAccessor.setCurrentElement(null);
+      stateAccessor.setCurrentWindow(null);
     }
   }
 
   @Override
-  public void onTimer(
-      String timerId,
-      BoundedWindow window,
-      Instant timestamp,
-      TimeDomain timeDomain) {
-    throw new UnsupportedOperationException("TODO: Add support for timers");
-  }
-
-  @Override
   public void finishBundle() {
     doFnInvoker.invokeFinishBundle(finishBundleContext);
 
-    // Persist all dirty state cells
-    try {
-      for (ThrowingRunnable runnable : stateFinalizers) {
-        runnable.run();
-      }
-    } catch (InterruptedException e) {
-      Thread.currentThread().interrupt();
-      throw new IllegalStateException(e);
-    } catch (Exception e) {
-      throw new IllegalStateException(e);
-    }
-
     // TODO: Support caching state data across bundle boundaries.
-    stateKeyObjectCache.clear();
+    this.stateAccessor.finalizeState();
+    this.stateAccessor = null;
   }
 
-  @Override
-  public DoFn<InputT, OutputT> getFn() {
-    return doFnInvoker.getFn();
-  }
-
-  /**
-   * Outputs the given element to the specified set of consumers wrapping any exceptions.
-   */
+  /** Outputs the given element to the specified set of consumers wrapping any exceptions. */
   private <T> void outputTo(
-      Collection<FnDataReceiver<WindowedValue<T>>> consumers,
-      WindowedValue<T> output) {
+      Collection<FnDataReceiver<WindowedValue<T>>> consumers, WindowedValue<T> output) {
     try {
       for (FnDataReceiver<WindowedValue<T>> consumer : consumers) {
         consumer.accept(output);
@@ -499,12 +198,11 @@
   /**
    * Provides arguments for a {@link DoFnInvoker} for {@link DoFn.ProcessElement @ProcessElement}.
    */
-  private class ProcessBundleContext
-      extends DoFn<InputT, OutputT>.ProcessContext
+  private class ProcessBundleContext extends DoFn<InputT, OutputT>.ProcessContext
       implements DoFnInvoker.ArgumentProvider<InputT, OutputT> {
 
     private ProcessBundleContext() {
-      doFn.super();
+      context.doFn.super();
     }
 
     @Override
@@ -577,11 +275,11 @@
       checkNotNull(stateDeclaration, "No state declaration found for %s", stateId);
       StateSpec<?> spec;
       try {
-        spec = (StateSpec<?>) stateDeclaration.field().get(doFn);
+        spec = (StateSpec<?>) stateDeclaration.field().get(context.doFn);
       } catch (IllegalAccessException e) {
         throw new RuntimeException(e);
       }
-      return spec.bind(stateId, stateBinder);
+      return spec.bind(stateId, stateAccessor);
     }
 
     @Override
@@ -591,60 +289,51 @@
 
     @Override
     public PipelineOptions getPipelineOptions() {
-      return pipelineOptions;
+      return context.pipelineOptions;
     }
 
     @Override
     public PipelineOptions pipelineOptions() {
-      return pipelineOptions;
+      return context.pipelineOptions;
     }
 
     @Override
     public void output(OutputT output) {
-      outputTo(mainOutputConsumers,
+      outputTo(
+          mainOutputConsumers,
           WindowedValue.of(
-              output,
-              currentElement.getTimestamp(),
-              currentWindow,
-              currentElement.getPane()));
+              output, currentElement.getTimestamp(), currentWindow, currentElement.getPane()));
     }
 
     @Override
     public void outputWithTimestamp(OutputT output, Instant timestamp) {
-      outputTo(mainOutputConsumers,
-          WindowedValue.of(
-              output,
-              timestamp,
-              currentWindow,
-              currentElement.getPane()));
+      outputTo(
+          mainOutputConsumers,
+          WindowedValue.of(output, timestamp, currentWindow, currentElement.getPane()));
     }
 
     @Override
     public <T> void output(TupleTag<T> tag, T output) {
-      Collection<FnDataReceiver<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag);
+      Collection<FnDataReceiver<WindowedValue<T>>> consumers =
+          (Collection) context.tagToConsumer.get(tag);
       if (consumers == null) {
         throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
       }
-      outputTo(consumers,
+      outputTo(
+          consumers,
           WindowedValue.of(
-              output,
-              currentElement.getTimestamp(),
-              currentWindow,
-              currentElement.getPane()));
+              output, currentElement.getTimestamp(), currentWindow, currentElement.getPane()));
     }
 
     @Override
     public <T> void outputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) {
-      Collection<FnDataReceiver<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag);
+      Collection<FnDataReceiver<WindowedValue<T>>> consumers =
+          (Collection) context.tagToConsumer.get(tag);
       if (consumers == null) {
         throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
       }
-      outputTo(consumers,
-          WindowedValue.of(
-              output,
-              timestamp,
-              currentWindow,
-              currentElement.getPane()));
+      outputTo(
+          consumers, WindowedValue.of(output, timestamp, currentWindow, currentElement.getPane()));
     }
 
     @Override
@@ -654,7 +343,7 @@
 
     @Override
     public <T> T sideInput(PCollectionView<T> view) {
-      return (T) bindSideInputView(view.getTagInternal());
+      return stateAccessor.get(view, currentWindow);
     }
 
     @Override
@@ -672,359 +361,4 @@
       throw new UnsupportedOperationException("TODO: Add support for SplittableDoFn");
     }
   }
-
-  /**
-   * A {@link StateBinder} that uses the Beam Fn State API to read and write user state.
-   *
-   * <p>TODO: Add support for {@link #bindMap} and {@link #bindSet}. Note that
-   * {@link #bindWatermark} should never be implemented.
-   */
-  private class BeamFnStateBinder implements StateBinder {
-    @Override
-    public <T> ValueState<T> bindValue(String id, StateSpec<ValueState<T>> spec, Coder<T> coder) {
-      return (ValueState<T>) stateKeyObjectCache.computeIfAbsent(
-          createBagUserStateKey(id),
-          new Function<StateKey, Object>() {
-            @Override
-            public Object apply(StateKey key) {
-              return new ValueState<T>() {
-                private final BagUserState<T> impl = createBagUserState(id, coder);
-
-                @Override
-                public void clear() {
-                  impl.clear();
-                }
-
-                @Override
-                public void write(T input) {
-                  impl.clear();
-                  impl.append(input);
-                }
-
-                @Override
-                public T read() {
-                  Iterator<T> value = impl.get().iterator();
-                  if (value.hasNext()) {
-                    return value.next();
-                  } else {
-                    return null;
-                  }
-                }
-
-                @Override
-                public ValueState<T> readLater() {
-                  // TODO: Support prefetching.
-                  return this;
-                }
-              };
-            }
-          });
-    }
-
-    @Override
-    public <T> BagState<T> bindBag(String id, StateSpec<BagState<T>> spec, Coder<T> elemCoder) {
-      return (BagState<T>) stateKeyObjectCache.computeIfAbsent(
-          createBagUserStateKey(id),
-          new Function<StateKey, Object>() {
-            @Override
-            public Object apply(StateKey key) {
-              return new BagState<T>() {
-                private final BagUserState<T> impl = createBagUserState(id, elemCoder);
-
-                @Override
-                public void add(T value) {
-                  impl.append(value);
-                }
-
-                @Override
-                public ReadableState<Boolean> isEmpty() {
-                  return ReadableStates.immediate(!impl.get().iterator().hasNext());
-                }
-
-                @Override
-                public Iterable<T> read() {
-                  return impl.get();
-                }
-
-                @Override
-                public BagState<T> readLater() {
-                  // TODO: Support prefetching.
-                  return this;
-                }
-
-                @Override
-                public void clear() {
-                  impl.clear();
-                }
-              };
-            }
-          });
-    }
-
-    @Override
-    public <T> SetState<T> bindSet(String id, StateSpec<SetState<T>> spec, Coder<T> elemCoder) {
-      throw new UnsupportedOperationException("TODO: Add support for a map state to the Fn API.");
-    }
-
-    @Override
-    public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(String id,
-        StateSpec<MapState<KeyT, ValueT>> spec, Coder<KeyT> mapKeyCoder,
-        Coder<ValueT> mapValueCoder) {
-      throw new UnsupportedOperationException("TODO: Add support for a map state to the Fn API.");
-    }
-
-    @Override
-    public <ElementT, AccumT, ResultT> CombiningState<ElementT, AccumT, ResultT> bindCombining(
-        String id,
-        StateSpec<CombiningState<ElementT, AccumT, ResultT>> spec, Coder<AccumT> accumCoder,
-        CombineFn<ElementT, AccumT, ResultT> combineFn) {
-      return (CombiningState<ElementT, AccumT, ResultT>) stateKeyObjectCache.computeIfAbsent(
-          createBagUserStateKey(id),
-          new Function<StateKey, Object>() {
-            @Override
-            public Object apply(StateKey key) {
-              // TODO: Support squashing accumulators depending on whether we know of all
-              // remote accumulators and local accumulators or just local accumulators.
-              return new CombiningState<ElementT, AccumT, ResultT>() {
-                private final BagUserState<AccumT> impl = createBagUserState(id, accumCoder);
-
-                @Override
-                public AccumT getAccum() {
-                  Iterator<AccumT> iterator = impl.get().iterator();
-                  if (iterator.hasNext()) {
-                    return iterator.next();
-                  }
-                  return combineFn.createAccumulator();
-                }
-
-                @Override
-                public void addAccum(AccumT accum) {
-                  Iterator<AccumT> iterator = impl.get().iterator();
-
-                  // Only merge if there was a prior value
-                  if (iterator.hasNext()) {
-                    accum = combineFn.mergeAccumulators(ImmutableList.of(iterator.next(), accum));
-                    // Since there was a prior value, we need to clear.
-                    impl.clear();
-                  }
-
-                  impl.append(accum);
-                }
-
-                @Override
-                public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
-                  return combineFn.mergeAccumulators(accumulators);
-                }
-
-                @Override
-                public CombiningState<ElementT, AccumT, ResultT> readLater() {
-                  return this;
-                }
-
-                @Override
-                public ResultT read() {
-                  Iterator<AccumT> iterator = impl.get().iterator();
-                  if (iterator.hasNext()) {
-                    return combineFn.extractOutput(iterator.next());
-                  }
-                  return combineFn.defaultValue();
-                }
-
-                @Override
-                public void add(ElementT value) {
-                  AccumT newAccumulator = combineFn.addInput(getAccum(), value);
-                  impl.clear();
-                  impl.append(newAccumulator);
-                }
-
-                @Override
-                public ReadableState<Boolean> isEmpty() {
-                  return ReadableStates.immediate(!impl.get().iterator().hasNext());
-                }
-
-                @Override
-                public void clear() {
-                  impl.clear();
-                }
-              };
-            }
-          });
-    }
-
-    @Override
-    public <ElementT, AccumT, ResultT> CombiningState<ElementT, AccumT, ResultT>
-    bindCombiningWithContext(
-        String id,
-        StateSpec<CombiningState<ElementT, AccumT, ResultT>> spec,
-        Coder<AccumT> accumCoder,
-        CombineFnWithContext<ElementT, AccumT, ResultT> combineFn) {
-      return (CombiningState<ElementT, AccumT, ResultT>) stateKeyObjectCache.computeIfAbsent(
-          createBagUserStateKey(id),
-          key -> bindCombining(id, spec, accumCoder, CombineFnUtil.bindContext(combineFn,
-              new StateContext<BoundedWindow>() {
-                @Override
-                public PipelineOptions getPipelineOptions() {
-                  return pipelineOptions;
-                }
-
-                @Override
-                public <T> T sideInput(PCollectionView<T> view) {
-                  return (T) bindSideInputView(view.getTagInternal());
-                }
-
-                @Override
-                public BoundedWindow window() {
-                  return currentWindow;
-                }
-              })));
-    }
-
-    /**
-     * @deprecated The Fn API has no plans to implement WatermarkHoldState as of this writing
-     * and is waiting on resolution of BEAM-2535.
-     */
-    @Override
-    @Deprecated
-    public WatermarkHoldState bindWatermark(String id, StateSpec<WatermarkHoldState> spec,
-        TimestampCombiner timestampCombiner) {
-      throw new UnsupportedOperationException("WatermarkHoldState is unsupported by the Fn API.");
-    }
-
-    private <T> BagUserState<T> createBagUserState(
-        String stateId, Coder<T> valueCoder) {
-      BagUserState<T> rval = new BagUserState<>(
-          beamFnStateClient,
-          processBundleInstructionId.get(),
-          ptransformId,
-          stateId,
-          encodedCurrentWindow,
-          encodedCurrentKey,
-          valueCoder);
-      stateFinalizers.add(rval::asyncClose);
-      return rval;
-    }
-  }
-
-  private StateKey createBagUserStateKey(String stateId) {
-    cacheEncodedKeyAndWindowForKeyedContext();
-    StateKey.Builder builder = StateKey.newBuilder();
-    builder.getBagUserStateBuilder()
-        .setWindow(encodedCurrentWindow)
-        .setKey(encodedCurrentKey)
-        .setPtransformId(ptransformId)
-        .setUserStateId(stateId);
-    return builder.build();
-  }
-
-  /**
-   * Memoizes an encoded key and window for the current element being processed saving on the
-   * encoding cost of the key and window across multiple state cells for the lifetime of
-   * {@link #processElement}.
-   *
-   * <p>This should only be called during {@link #processElement}.
-   */
-  private <K> void cacheEncodedKeyAndWindowForKeyedContext() {
-    if (encodedCurrentKey == null) {
-      checkState(currentElement.getValue() instanceof KV,
-          "Accessing state in unkeyed context. Current element is not a KV: %s.",
-          currentElement);
-      checkState(
-          // TODO: Stop passing windowed value coders within PCollections.
-          inputCoder instanceof KvCoder
-              || (inputCoder instanceof WindowedValueCoder
-              && (((WindowedValueCoder) inputCoder).getValueCoder() instanceof KvCoder)),
-          "Accessing state in unkeyed context. Keyed coder expected but found %s.",
-          inputCoder);
-
-      ByteString.Output encodedKeyOut = ByteString.newOutput();
-
-      Coder<K> keyCoder = inputCoder instanceof WindowedValueCoder
-          ? ((KvCoder<K, ?>) ((WindowedValueCoder) inputCoder).getValueCoder()).getKeyCoder()
-          : ((KvCoder<K, ?>) inputCoder).getKeyCoder();
-      try {
-        keyCoder.encode(((KV<K, ?>) currentElement.getValue()).getKey(), encodedKeyOut);
-      } catch (IOException e) {
-        throw new IllegalStateException(e);
-      }
-      encodedCurrentKey = encodedKeyOut.toByteString();
-    }
-
-    if (encodedCurrentWindow == null) {
-      ByteString.Output encodedWindowOut = ByteString.newOutput();
-      try {
-        windowingStrategy.getWindowFn().windowCoder().encode(currentWindow, encodedWindowOut);
-      } catch (IOException e) {
-        throw new IllegalStateException(e);
-      }
-      encodedCurrentWindow = encodedWindowOut.toByteString();
-    }
-  }
-
-  /**
-   * A specification for side inputs containing a value {@link Coder},
-   * the window {@link Coder}, {@link ViewFn}, and the {@link WindowMappingFn}.
-   * @param <W>
-   */
-  @AutoValue
-  abstract static class SideInputSpec<W extends BoundedWindow> {
-    static <W extends BoundedWindow> SideInputSpec create(
-        Coder<?> coder,
-        Coder<W> windowCoder,
-        ViewFn<?, ?> viewFn,
-        WindowMappingFn<W> windowMappingFn) {
-      return new AutoValue_FnApiDoFnRunner_SideInputSpec<>(
-          coder, windowCoder, viewFn, windowMappingFn);
-    }
-
-    abstract Coder<?> getCoder();
-
-    abstract Coder<W> getWindowCoder();
-
-    abstract ViewFn<?, ?> getViewFn();
-
-    abstract WindowMappingFn<W> getWindowMappingFn();
-  }
-
-  private <K, V> Object bindSideInputView(TupleTag<?> view) {
-    SideInputSpec sideInputSpec = sideInputSpecMap.get(view);
-    checkArgument(sideInputSpec != null,
-        "Attempting to access unknown side input %s.",
-        view);
-    KvCoder<K, V> kvCoder = (KvCoder) sideInputSpec.getCoder();
-
-    ByteString.Output encodedWindowOut = ByteString.newOutput();
-    try {
-      sideInputSpec.getWindowCoder().encode(
-          sideInputSpec.getWindowMappingFn().getSideInputWindow(currentWindow), encodedWindowOut);
-    } catch (IOException e) {
-      throw new IllegalStateException(e);
-    }
-    ByteString encodedWindow = encodedWindowOut.toByteString();
-
-    StateKey.Builder cacheKeyBuilder = StateKey.newBuilder();
-    cacheKeyBuilder.getMultimapSideInputBuilder()
-        .setPtransformId(ptransformId)
-        .setSideInputId(view.getId())
-        .setWindow(encodedWindow);
-    return stateKeyObjectCache.computeIfAbsent(
-        cacheKeyBuilder.build(),
-        key -> sideInputSpec.getViewFn().apply(createMultimapSideInput(
-            view.getId(), encodedWindow, kvCoder.getKeyCoder(), kvCoder.getValueCoder())));
-  }
-
-  private <K, V> MultimapSideInput<K, V> createMultimapSideInput(
-      String sideInputId,
-      ByteString encodedWindow,
-      Coder<K> keyCoder,
-      Coder<V> valueCoder) {
-
-    return new MultimapSideInput<>(
-        beamFnStateClient,
-        processBundleInstructionId.get(),
-        ptransformId,
-        sideInputId,
-        encodedWindow,
-        keyCoder,
-        valueCoder);
-  }
 }
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java
index 7819726..e2c61de 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java
@@ -27,6 +27,7 @@
 import java.util.Map;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
 import org.apache.beam.fn.harness.data.BeamFnDataClient;
 import org.apache.beam.fn.harness.data.MultiplexingFnDataReceiver;
 import org.apache.beam.fn.harness.state.BeamFnStateClient;
@@ -109,7 +110,8 @@
         Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
         Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
         Consumer<ThrowingRunnable> addStartFunction,
-        Consumer<ThrowingRunnable> addFinishFunction)
+        Consumer<ThrowingRunnable> addFinishFunction,
+        BundleSplitListener splitListener)
         throws IOException {
 
       Collection<FnDataReceiver<WindowedValue<OutputT>>> consumers =
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java
index c130d4d..efd4a38 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java
@@ -22,6 +22,7 @@
 import java.util.Map;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
 import org.apache.beam.fn.harness.data.BeamFnDataClient;
 import org.apache.beam.fn.harness.state.BeamFnStateClient;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
@@ -37,7 +38,6 @@
  * A factory able to instantiate an appropriate handler for a given PTransform.
  */
 public interface PTransformRunnerFactory<T> {
-
   /**
    * Creates and returns a handler for a given PTransform. Note that the handler must support
    * processing multiple bundles. The handler will be discarded if an error is thrown during element
@@ -60,6 +60,7 @@
    *     registered within this multimap.
    * @param addStartFunction A consumer to register a start bundle handler with.
    * @param addFinishFunction A consumer to register a finish bundle handler with.
+   * @param splitListener A listener to be invoked when the PTransform splits itself.
    */
   T createRunnerForPTransform(
       PipelineOptions pipelineOptions,
@@ -73,7 +74,8 @@
       Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
       Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
       Consumer<ThrowingRunnable> addStartFunction,
-      Consumer<ThrowingRunnable> addFinishFunction)
+      Consumer<ThrowingRunnable> addFinishFunction,
+      BundleSplitListener splitListener)
       throws IOException;
 
   /**
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/SplittableProcessElementsRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/SplittableProcessElementsRunner.java
new file mode 100644
index 0000000..c8cac40
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/SplittableProcessElementsRunner.java
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import com.google.auto.service.AutoService;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
+import com.google.protobuf.ByteString;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Map;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import org.apache.beam.fn.harness.DoFnPTransformRunnerFactory.Context;
+import org.apache.beam.fn.harness.state.FnApiStateAccessor;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.Application;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.DelayedApplication;
+import org.apache.beam.runners.core.OutputAndTimeBoundedSplittableProcessElementInvoker;
+import org.apache.beam.runners.core.OutputWindowedValue;
+import org.apache.beam.runners.core.SplittableProcessElementInvoker;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.util.UserCodeException;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.TupleTag;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+
+/** Runs the {@link PTransformTranslation#SPLITTABLE_PROCESS_ELEMENTS_URN} transform. */
+public class SplittableProcessElementsRunner<InputT, RestrictionT, OutputT>
+    implements DoFnPTransformRunnerFactory.DoFnPTransformRunner<KV<InputT, RestrictionT>> {
+  /** A registrar which provides a factory to handle Java {@link DoFn}s. */
+  @AutoService(PTransformRunnerFactory.Registrar.class)
+  public static class Registrar implements PTransformRunnerFactory.Registrar {
+    @Override
+    public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() {
+      return ImmutableMap.of(PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN, new Factory());
+    }
+  }
+
+  static class Factory<InputT, RestrictionT, OutputT>
+      extends DoFnPTransformRunnerFactory<
+          KV<InputT, RestrictionT>,
+          InputT,
+          OutputT,
+          SplittableProcessElementsRunner<InputT, RestrictionT, OutputT>> {
+
+    @Override
+    SplittableProcessElementsRunner<InputT, RestrictionT, OutputT> createRunner(
+        Context<InputT, OutputT> context) {
+      Coder<WindowedValue<KV<InputT, RestrictionT>>> windowedCoder =
+          FullWindowedValueCoder.of(
+              (Coder<KV<InputT, RestrictionT>>) context.inputCoder, context.windowCoder);
+
+      return new SplittableProcessElementsRunner<>(
+          context,
+          windowedCoder,
+          (Collection<FnDataReceiver<WindowedValue<OutputT>>>)
+              (Collection) context.tagToConsumer.get(context.mainOutputTag),
+          Iterables.getOnlyElement(context.pTransform.getInputsMap().keySet()));
+    }
+  }
+
+  //////////////////////////////////////////////////////////////////////////////////////////////////
+
+  private final Context<InputT, OutputT> context;
+  private final String mainInputId;
+  private final Coder<WindowedValue<KV<InputT, RestrictionT>>> inputCoder;
+  private final Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers;
+  private final DoFnInvoker<InputT, OutputT> doFnInvoker;
+  private final ScheduledExecutorService executor;
+
+  private FnApiStateAccessor stateAccessor;
+
+  private final DoFn<InputT, OutputT>.StartBundleContext startBundleContext;
+  private final DoFn<InputT, OutputT>.FinishBundleContext finishBundleContext;
+
+  SplittableProcessElementsRunner(
+      Context<InputT, OutputT> context,
+      Coder<WindowedValue<KV<InputT, RestrictionT>>> inputCoder,
+      Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers,
+      String mainInputId) {
+    this.context = context;
+    this.mainInputId = mainInputId;
+    this.inputCoder = inputCoder;
+    this.mainOutputConsumers = mainOutputConsumers;
+    this.doFnInvoker = DoFnInvokers.invokerFor(context.doFn);
+    this.doFnInvoker.invokeSetup();
+    this.executor = Executors.newSingleThreadScheduledExecutor();
+
+    this.startBundleContext =
+        context.doFn.new StartBundleContext() {
+          @Override
+          public PipelineOptions getPipelineOptions() {
+            return context.pipelineOptions;
+          }
+        };
+    this.finishBundleContext =
+        context.doFn.new FinishBundleContext() {
+          @Override
+          public PipelineOptions getPipelineOptions() {
+            return context.pipelineOptions;
+          }
+
+          @Override
+          public void output(OutputT output, Instant timestamp, BoundedWindow window) {
+            throw new UnsupportedOperationException();
+          }
+
+          @Override
+          public <T> void output(
+              TupleTag<T> tag, T output, Instant timestamp, BoundedWindow window) {
+            throw new UnsupportedOperationException();
+          }
+        };
+  }
+
+  @Override
+  public void startBundle() {
+    this.stateAccessor =
+        new FnApiStateAccessor(
+            context.pipelineOptions,
+            context.ptransformId,
+            context.processBundleInstructionId,
+            context.tagToSideInputSpecMap,
+            context.beamFnStateClient,
+            context.keyCoder,
+            (Coder<BoundedWindow>) context.windowCoder);
+    doFnInvoker.invokeStartBundle(startBundleContext);
+  }
+
+  @Override
+  public void processElement(WindowedValue<KV<InputT, RestrictionT>> elem) {
+    processElementTyped(elem);
+  }
+
+  private <PositionT, TrackerT extends RestrictionTracker<RestrictionT, PositionT>>
+      void processElementTyped(WindowedValue<KV<InputT, RestrictionT>> elem) {
+    checkArgument(
+        elem.getWindows().size() == 1,
+        "SPLITTABLE_PROCESS_ELEMENTS expects its input to be in 1 window, but got %s windows",
+        elem.getWindows().size());
+    this.stateAccessor.setCurrentWindow(elem.getWindows().iterator().next());
+    WindowedValue<InputT> element = elem.withValue(elem.getValue().getKey());
+    TrackerT tracker = doFnInvoker.invokeNewTracker(elem.getValue().getValue());
+    OutputAndTimeBoundedSplittableProcessElementInvoker<
+            InputT, OutputT, RestrictionT, PositionT, TrackerT>
+        processElementInvoker =
+            new OutputAndTimeBoundedSplittableProcessElementInvoker<>(
+                context.doFn,
+                context.pipelineOptions,
+                new OutputWindowedValue<OutputT>() {
+                  @Override
+                  public void outputWindowedValue(
+                      OutputT output,
+                      Instant timestamp,
+                      Collection<? extends BoundedWindow> windows,
+                      PaneInfo pane) {
+                    outputTo(
+                        mainOutputConsumers, WindowedValue.of(output, timestamp, windows, pane));
+                  }
+
+                  @Override
+                  public <AdditionalOutputT> void outputWindowedValue(
+                      TupleTag<AdditionalOutputT> tag,
+                      AdditionalOutputT output,
+                      Instant timestamp,
+                      Collection<? extends BoundedWindow> windows,
+                      PaneInfo pane) {
+                    Collection<FnDataReceiver<WindowedValue<AdditionalOutputT>>> consumers =
+                        (Collection) context.tagToConsumer.get(tag);
+                    if (consumers == null) {
+                      throw new IllegalArgumentException(
+                          String.format("Unknown output tag %s", tag));
+                    }
+                    outputTo(consumers, WindowedValue.of(output, timestamp, windows, pane));
+                  }
+                },
+                stateAccessor,
+                executor,
+                10000,
+                Duration.standardSeconds(10));
+
+    SplittableProcessElementInvoker<InputT, OutputT, RestrictionT, TrackerT>.Result result =
+        processElementInvoker.invokeProcessElement(doFnInvoker, element, tracker);
+    if (result.getContinuation().shouldResume()) {
+      WindowedValue<KV<InputT, RestrictionT>> primary =
+          element.withValue(KV.of(element.getValue(), tracker.currentRestriction()));
+      WindowedValue<KV<InputT, RestrictionT>> residual =
+          element.withValue(KV.of(element.getValue(), result.getResidualRestriction()));
+      ByteString.Output primaryBytes = ByteString.newOutput();
+      ByteString.Output residualBytes = ByteString.newOutput();
+      try {
+        inputCoder.encode(primary, primaryBytes);
+        inputCoder.encode(residual, residualBytes);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+      Application primaryApplication =
+          Application.newBuilder()
+              .setPtransformId(context.ptransformId)
+              .setInputId(mainInputId)
+              .setElement(primaryBytes.toByteString())
+              .build();
+      Application residualApplication =
+          Application.newBuilder()
+              .setPtransformId(context.ptransformId)
+              .setInputId(mainInputId)
+              .setElement(residualBytes.toByteString())
+              .build();
+      context.splitListener.split(
+          ImmutableList.of(primaryApplication),
+          ImmutableList.of(
+              DelayedApplication.newBuilder()
+                  .setApplication(residualApplication)
+                  .setDelaySec(0.001 * result.getContinuation().resumeDelay().getMillis())
+                  .build()));
+    }
+  }
+
+  @Override
+  public void finishBundle() {
+    doFnInvoker.invokeFinishBundle(finishBundleContext);
+  }
+
+  /** Outputs the given element to the specified set of consumers wrapping any exceptions. */
+  private <T> void outputTo(
+      Collection<FnDataReceiver<WindowedValue<T>>> consumers, WindowedValue<T> output) {
+    try {
+      for (FnDataReceiver<WindowedValue<T>> consumer : consumers) {
+        consumer.accept(output);
+      }
+    } catch (Throwable t) {
+      throw UserCodeException.wrap(t);
+    }
+  }
+}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java
new file mode 100644
index 0000000..db73447
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.fn.harness.control;
+
+import java.util.List;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.Application;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.DelayedApplication;
+
+/**
+ * Listens to splits happening to a single bundle. See <a
+ * href="https://s.apache.org/beam-breaking-fusion">Breaking the Fusion Barrier</a> for a
+ * discussion of the design.
+ */
+public interface BundleSplitListener {
+  /**
+   * Signals that the current bundle should be split into the given set of primary and residual
+   * roots.
+   *
+   * <p>Primary roots are the new decomposition of the bundle's work into transform applications
+   * that have happened or will happen as part of this bundle (modulo future splits). Residual roots
+   * are a decomposition of work that has been given away by the bundle, so the runner must delegate
+   * it for someone else to execute.
+   */
+  void split(List<Application> primaryRoots, List<DelayedApplication> residualRoots);
+}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
index 64d713e..ee907a8 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
@@ -46,8 +46,12 @@
 import org.apache.beam.fn.harness.state.BeamFnStateClient;
 import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.Application;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.DelayedApplication;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleRequest;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest.Builder;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse;
@@ -144,7 +148,8 @@
       SetMultimap<String, String> pCollectionIdsToConsumingPTransforms,
       ListMultimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
       Consumer<ThrowingRunnable> addStartFunction,
-      Consumer<ThrowingRunnable> addFinishFunction)
+      Consumer<ThrowingRunnable> addFinishFunction,
+      BundleSplitListener splitListener)
       throws IOException {
 
     // Recursively ensure that all consumers of the output PCollection have been created.
@@ -166,7 +171,8 @@
             pCollectionIdsToConsumingPTransforms,
             pCollectionIdsToConsumers,
             addStartFunction,
-            addFinishFunction);
+            addFinishFunction,
+            splitListener);
       }
     }
 
@@ -196,15 +202,12 @@
             processBundleDescriptor.getWindowingStrategiesMap(),
             pCollectionIdsToConsumers,
             addStartFunction,
-            addFinishFunction);
+            addFinishFunction,
+            splitListener);
   }
 
   public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.InstructionRequest request)
       throws Exception {
-    BeamFnApi.InstructionResponse.Builder response =
-        BeamFnApi.InstructionResponse.newBuilder()
-            .setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance());
-
     String bundleId = request.getProcessBundle().getProcessBundleDescriptorReference();
     BeamFnApi.ProcessBundleDescriptor bundleDescriptor =
         (BeamFnApi.ProcessBundleDescriptor) fnApiRegistry.apply(bundleId);
@@ -223,6 +226,8 @@
       }
     }
 
+    ProcessBundleResponse.Builder response = ProcessBundleResponse.newBuilder();
+
     // Instantiate a State API call handler depending on whether a State Api service descriptor
     // was specified.
     try (HandleStateCallsForBundle beamFnStateClient =
@@ -230,6 +235,23 @@
         ? new BlockTillStateCallsFinish(beamFnStateGrpcClientCache.forApiServiceDescriptor(
             bundleDescriptor.getStateApiServiceDescriptor()))
         : new FailAllStateCallsForBundle(request.getProcessBundle())) {
+      Multimap<String, Application> allPrimaries = ArrayListMultimap.create();
+      Multimap<String, DelayedApplication> allResiduals = ArrayListMultimap.create();
+      BundleSplitListener splitListener =
+          (List<Application> primaries, List<DelayedApplication> residuals) -> {
+            // Reset primaries and accumulate residuals.
+            Multimap<String, Application> newPrimaries = ArrayListMultimap.create();
+            for (Application primary : primaries) {
+              newPrimaries.put(primary.getPtransformId(), primary);
+            }
+            allPrimaries.clear();
+            allPrimaries.putAll(newPrimaries);
+
+            for (DelayedApplication residual : residuals) {
+              allResiduals.put(residual.getApplication().getPtransformId(), residual);
+            }
+          };
+
       // Create a BeamFnStateClient
       for (Map.Entry<String, RunnerApi.PTransform> entry
           : bundleDescriptor.getTransformsMap().entrySet()) {
@@ -251,7 +273,8 @@
             pCollectionIdsToConsumingPTransforms,
             pCollectionIdsToConsumers,
             startFunctions::add,
-            finishFunctions::add);
+            finishFunctions::add,
+            splitListener);
       }
 
       // Already in reverse topological order so we don't need to do anything.
@@ -265,9 +288,16 @@
         LOG.debug("Finishing function {}", finishFunction);
         finishFunction.run();
       }
+      if (!allPrimaries.isEmpty()) {
+        response.setSplit(
+            BundleSplit.newBuilder()
+                .addAllPrimaryRoots(allPrimaries.values())
+                .addAllResidualRoots(allResiduals.values())
+                .build());
+      }
     }
 
-    return response;
+    return BeamFnApi.InstructionResponse.newBuilder().setProcessBundle(response);
   }
 
   /**
@@ -355,7 +385,8 @@
         Map<String, WindowingStrategy> windowingStrategies,
         Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
         Consumer<ThrowingRunnable> addStartFunction,
-        Consumer<ThrowingRunnable> addFinishFunction) {
+        Consumer<ThrowingRunnable> addFinishFunction,
+        BundleSplitListener splitListener) {
       String message =
           String.format(
               "No factory registered for %s, known factories %s",
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
new file mode 100644
index 0000000..e7a043b
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
@@ -0,0 +1,460 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.fn.harness.state;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Maps;
+import com.google.protobuf.ByteString;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import javax.annotation.Nullable;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
+import org.apache.beam.runners.core.SideInputReader;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.fn.function.ThrowingRunnable;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.CombiningState;
+import org.apache.beam.sdk.state.MapState;
+import org.apache.beam.sdk.state.ReadableState;
+import org.apache.beam.sdk.state.ReadableStates;
+import org.apache.beam.sdk.state.SetState;
+import org.apache.beam.sdk.state.StateBinder;
+import org.apache.beam.sdk.state.StateContext;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.state.WatermarkHoldState;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.util.CombineFnUtil;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
+
+/** Provides access to side inputs and state via a {@link BeamFnStateClient}. */
+public class FnApiStateAccessor implements SideInputReader, StateBinder {
+  private final PipelineOptions pipelineOptions;
+  private final Map<StateKey, Object> stateKeyObjectCache;
+  private final Map<TupleTag<?>, SideInputSpec> sideInputSpecMap;
+  private final BeamFnStateClient beamFnStateClient;
+  private final String ptransformId;
+  private final Supplier<String> processBundleInstructionId;
+  private final Collection<ThrowingRunnable> stateFinalizers;
+
+  private final Coder<?> keyCoder;
+  private final Coder<BoundedWindow> windowCoder;
+
+  private WindowedValue<?> currentElement;
+  private BoundedWindow currentWindow;
+  private ByteString encodedCurrentKey;
+  private ByteString encodedCurrentWindow;
+
+  public FnApiStateAccessor(
+      PipelineOptions pipelineOptions,
+      String ptransformId,
+      Supplier<String> processBundleInstructionId,
+      Map<TupleTag<?>, SideInputSpec> sideInputSpecMap,
+      BeamFnStateClient beamFnStateClient,
+      Coder<?> keyCoder,
+      Coder<BoundedWindow> windowCoder) {
+    this.pipelineOptions = pipelineOptions;
+    this.stateKeyObjectCache = Maps.newHashMap();
+    this.sideInputSpecMap = sideInputSpecMap;
+    this.beamFnStateClient = beamFnStateClient;
+    this.ptransformId = ptransformId;
+    this.processBundleInstructionId = processBundleInstructionId;
+    this.stateFinalizers = new ArrayList<>();
+
+    this.keyCoder = keyCoder;
+    this.windowCoder = windowCoder;
+  }
+
+  public void setCurrentElement(WindowedValue<?> currentElement) {
+    this.currentElement = currentElement;
+    this.encodedCurrentKey = null;
+  }
+
+  public void setCurrentWindow(BoundedWindow currentWindow) {
+    this.currentWindow = currentWindow;
+    this.encodedCurrentWindow = null;
+  }
+
+  @Override
+  @Nullable
+  public <T> T get(PCollectionView<T> view, BoundedWindow window) {
+    TupleTag<?> tag = view.getTagInternal();
+
+    SideInputSpec sideInputSpec = sideInputSpecMap.get(tag);
+    checkArgument(sideInputSpec != null, "Attempting to access unknown side input %s.", view);
+    KvCoder<?, ?> kvCoder = (KvCoder) sideInputSpec.getCoder();
+
+    ByteString.Output encodedWindowOut = ByteString.newOutput();
+    try {
+      sideInputSpec
+          .getWindowCoder()
+          .encode(sideInputSpec.getWindowMappingFn().getSideInputWindow(window), encodedWindowOut);
+    } catch (IOException e) {
+      throw new IllegalStateException(e);
+    }
+    ByteString encodedWindow = encodedWindowOut.toByteString();
+
+    StateKey.Builder cacheKeyBuilder = StateKey.newBuilder();
+    cacheKeyBuilder
+        .getMultimapSideInputBuilder()
+        .setPtransformId(ptransformId)
+        .setSideInputId(tag.getId())
+        .setWindow(encodedWindow);
+    return (T)
+        stateKeyObjectCache.computeIfAbsent(
+            cacheKeyBuilder.build(),
+            key ->
+                sideInputSpec
+                    .getViewFn()
+                    .apply(
+                        new MultimapSideInput<>(
+                            beamFnStateClient,
+                            processBundleInstructionId.get(),
+                            ptransformId,
+                            tag.getId(),
+                            encodedWindow,
+                            kvCoder.getKeyCoder(),
+                            kvCoder.getValueCoder())));
+  }
+
+  @Override
+  public <T> boolean contains(PCollectionView<T> view) {
+    return sideInputSpecMap.containsKey(view.getTagInternal());
+  }
+
+  @Override
+  public boolean isEmpty() {
+    return sideInputSpecMap.isEmpty();
+  }
+
+  @Override
+  public <T> ValueState<T> bindValue(String id, StateSpec<ValueState<T>> spec, Coder<T> coder) {
+    return (ValueState<T>)
+        stateKeyObjectCache.computeIfAbsent(
+            createBagUserStateKey(id),
+            new Function<StateKey, Object>() {
+              @Override
+              public Object apply(StateKey key) {
+                return new ValueState<T>() {
+                  private final BagUserState<T> impl = createBagUserState(id, coder);
+
+                  @Override
+                  public void clear() {
+                    impl.clear();
+                  }
+
+                  @Override
+                  public void write(T input) {
+                    impl.clear();
+                    impl.append(input);
+                  }
+
+                  @Override
+                  public T read() {
+                    Iterator<T> value = impl.get().iterator();
+                    if (value.hasNext()) {
+                      return value.next();
+                    } else {
+                      return null;
+                    }
+                  }
+
+                  @Override
+                  public ValueState<T> readLater() {
+                    // TODO: Support prefetching.
+                    return this;
+                  }
+                };
+              }
+            });
+  }
+
+  @Override
+  public <T> BagState<T> bindBag(String id, StateSpec<BagState<T>> spec, Coder<T> elemCoder) {
+    return (BagState<T>)
+        stateKeyObjectCache.computeIfAbsent(
+            createBagUserStateKey(id),
+            new Function<StateKey, Object>() {
+              @Override
+              public Object apply(StateKey key) {
+                return new BagState<T>() {
+                  private final BagUserState<T> impl = createBagUserState(id, elemCoder);
+
+                  @Override
+                  public void add(T value) {
+                    impl.append(value);
+                  }
+
+                  @Override
+                  public ReadableState<Boolean> isEmpty() {
+                    return ReadableStates.immediate(!impl.get().iterator().hasNext());
+                  }
+
+                  @Override
+                  public Iterable<T> read() {
+                    return impl.get();
+                  }
+
+                  @Override
+                  public BagState<T> readLater() {
+                    // TODO: Support prefetching.
+                    return this;
+                  }
+
+                  @Override
+                  public void clear() {
+                    impl.clear();
+                  }
+                };
+              }
+            });
+  }
+
+  @Override
+  public <T> SetState<T> bindSet(String id, StateSpec<SetState<T>> spec, Coder<T> elemCoder) {
+    throw new UnsupportedOperationException("TODO: Add support for a map state to the Fn API.");
+  }
+
+  @Override
+  public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
+      String id,
+      StateSpec<MapState<KeyT, ValueT>> spec,
+      Coder<KeyT> mapKeyCoder,
+      Coder<ValueT> mapValueCoder) {
+    throw new UnsupportedOperationException("TODO: Add support for a map state to the Fn API.");
+  }
+
+  @Override
+  public <ElementT, AccumT, ResultT> CombiningState<ElementT, AccumT, ResultT> bindCombining(
+      String id,
+      StateSpec<CombiningState<ElementT, AccumT, ResultT>> spec,
+      Coder<AccumT> accumCoder,
+      CombineFn<ElementT, AccumT, ResultT> combineFn) {
+    return (CombiningState<ElementT, AccumT, ResultT>)
+        stateKeyObjectCache.computeIfAbsent(
+            createBagUserStateKey(id),
+            new Function<StateKey, Object>() {
+              @Override
+              public Object apply(StateKey key) {
+                // TODO: Support squashing accumulators depending on whether we know of all
+                // remote accumulators and local accumulators or just local accumulators.
+                return new CombiningState<ElementT, AccumT, ResultT>() {
+                  private final BagUserState<AccumT> impl = createBagUserState(id, accumCoder);
+
+                  @Override
+                  public AccumT getAccum() {
+                    Iterator<AccumT> iterator = impl.get().iterator();
+                    if (iterator.hasNext()) {
+                      return iterator.next();
+                    }
+                    return combineFn.createAccumulator();
+                  }
+
+                  @Override
+                  public void addAccum(AccumT accum) {
+                    Iterator<AccumT> iterator = impl.get().iterator();
+
+                    // Only merge if there was a prior value
+                    if (iterator.hasNext()) {
+                      accum = combineFn.mergeAccumulators(ImmutableList.of(iterator.next(), accum));
+                      // Since there was a prior value, we need to clear.
+                      impl.clear();
+                    }
+
+                    impl.append(accum);
+                  }
+
+                  @Override
+                  public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
+                    return combineFn.mergeAccumulators(accumulators);
+                  }
+
+                  @Override
+                  public CombiningState<ElementT, AccumT, ResultT> readLater() {
+                    return this;
+                  }
+
+                  @Override
+                  public ResultT read() {
+                    Iterator<AccumT> iterator = impl.get().iterator();
+                    if (iterator.hasNext()) {
+                      return combineFn.extractOutput(iterator.next());
+                    }
+                    return combineFn.defaultValue();
+                  }
+
+                  @Override
+                  public void add(ElementT value) {
+                    AccumT newAccumulator = combineFn.addInput(getAccum(), value);
+                    impl.clear();
+                    impl.append(newAccumulator);
+                  }
+
+                  @Override
+                  public ReadableState<Boolean> isEmpty() {
+                    return ReadableStates.immediate(!impl.get().iterator().hasNext());
+                  }
+
+                  @Override
+                  public void clear() {
+                    impl.clear();
+                  }
+                };
+              }
+            });
+  }
+
+  @Override
+  public <ElementT, AccumT, ResultT>
+      CombiningState<ElementT, AccumT, ResultT> bindCombiningWithContext(
+          String id,
+          StateSpec<CombiningState<ElementT, AccumT, ResultT>> spec,
+          Coder<AccumT> accumCoder,
+          CombineFnWithContext<ElementT, AccumT, ResultT> combineFn) {
+    return (CombiningState<ElementT, AccumT, ResultT>)
+        stateKeyObjectCache.computeIfAbsent(
+            createBagUserStateKey(id),
+            key ->
+                bindCombining(
+                    id,
+                    spec,
+                    accumCoder,
+                    CombineFnUtil.bindContext(
+                        combineFn,
+                        new StateContext<BoundedWindow>() {
+                          @Override
+                          public PipelineOptions getPipelineOptions() {
+                            return pipelineOptions;
+                          }
+
+                          @Override
+                          public <T> T sideInput(PCollectionView<T> view) {
+                            return get(view, currentWindow);
+                          }
+
+                          @Override
+                          public BoundedWindow window() {
+                            return currentWindow;
+                          }
+                        })));
+  }
+
+  /**
+   * @deprecated The Fn API has no plans to implement WatermarkHoldState as of this writing and is
+   *     waiting on resolution of BEAM-2535.
+   */
+  @Override
+  @Deprecated
+  public WatermarkHoldState bindWatermark(
+      String id, StateSpec<WatermarkHoldState> spec, TimestampCombiner timestampCombiner) {
+    throw new UnsupportedOperationException("WatermarkHoldState is unsupported by the Fn API.");
+  }
+
+  private <T> BagUserState<T> createBagUserState(String stateId, Coder<T> valueCoder) {
+    BagUserState<T> rval =
+        new BagUserState<>(
+            beamFnStateClient,
+            processBundleInstructionId.get(),
+            ptransformId,
+            stateId,
+            encodedCurrentWindow,
+            encodedCurrentKey,
+            valueCoder);
+    stateFinalizers.add(rval::asyncClose);
+    return rval;
+  }
+
+  private StateKey createBagUserStateKey(String stateId) {
+    cacheEncodedKeyAndWindowForKeyedContext();
+    StateKey.Builder builder = StateKey.newBuilder();
+    builder
+        .getBagUserStateBuilder()
+        .setWindow(encodedCurrentWindow)
+        .setKey(encodedCurrentKey)
+        .setPtransformId(ptransformId)
+        .setUserStateId(stateId);
+    return builder.build();
+  }
+
+  /**
+   * Memoizes an encoded key and window for the current element being processed saving on the
+   * encoding cost of the key and window across multiple state cells for the lifetime of {@link
+   * DoFn.ProcessElement}
+   *
+   * <p>This should only be called during {@link DoFn.ProcessElement}.
+   */
+  private <K> void cacheEncodedKeyAndWindowForKeyedContext() {
+    if (encodedCurrentKey == null) {
+      checkState(
+          currentElement.getValue() instanceof KV,
+          "Accessing state in unkeyed context. Current element is not a KV: %s.",
+          currentElement);
+      checkState(keyCoder != null, "Accessing state in unkeyed context, no key coder available");
+
+      ByteString.Output encodedKeyOut = ByteString.newOutput();
+      try {
+        ((Coder) keyCoder).encode(((KV<?, ?>) currentElement.getValue()).getKey(), encodedKeyOut);
+      } catch (IOException e) {
+        throw new IllegalStateException(e);
+      }
+      encodedCurrentKey = encodedKeyOut.toByteString();
+    }
+
+    if (encodedCurrentWindow == null) {
+      ByteString.Output encodedWindowOut = ByteString.newOutput();
+      try {
+        windowCoder.encode(currentWindow, encodedWindowOut);
+      } catch (IOException e) {
+        throw new IllegalStateException(e);
+      }
+      encodedCurrentWindow = encodedWindowOut.toByteString();
+    }
+  }
+
+  public void finalizeState() {
+    // Persist all dirty state cells
+    try {
+      for (ThrowingRunnable runnable : stateFinalizers) {
+        runnable.run();
+      }
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+      throw new IllegalStateException(e);
+    } catch (Exception e) {
+      throw new IllegalStateException(e);
+    }
+  }
+}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/SideInputSpec.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/SideInputSpec.java
new file mode 100644
index 0000000..8aec3a1
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/SideInputSpec.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.fn.harness.state;
+
+import com.google.auto.value.AutoValue;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.ViewFn;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.WindowMappingFn;
+
+/**
+ * A specification for side inputs containing a value {@link Coder}, the window {@link Coder},
+ * {@link ViewFn}, and the {@link WindowMappingFn}.
+ *
+ * @param <W>
+ */
+@AutoValue
+public abstract class SideInputSpec<W extends BoundedWindow> {
+  public static <W extends BoundedWindow> SideInputSpec create(
+      Coder<?> coder,
+      Coder<W> windowCoder,
+      ViewFn<?, ?> viewFn,
+      WindowMappingFn<W> windowMappingFn) {
+    return new AutoValue_SideInputSpec<>(
+        coder, windowCoder, viewFn, windowMappingFn);
+  }
+
+  abstract Coder<?> getCoder();
+
+  abstract Coder<W> getWindowCoder();
+
+  abstract ViewFn<?, ?> getViewFn();
+
+  abstract WindowMappingFn<W> getWindowMappingFn();
+}
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java
index d72753b..9871c37 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java
@@ -201,7 +201,8 @@
             null /* windowingStrategies */,
             receivers,
             null /* addStartFunction */,
-            null /* addFinishFunction */);
+            null, /* addFinishFunction */
+            null /* splitListener */);
 
     WindowedValue<Integer> value =
         WindowedValue.of(
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
index f138677..4b4c64f 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
@@ -152,7 +152,8 @@
         COMPONENTS.getWindowingStrategiesMap(),
         consumers,
         startFunctions::add,
-        finishFunctions::add);
+        finishFunctions::add,
+        null /* splitListener */);
 
     verifyZeroInteractions(mockBeamFnDataClient);
 
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java
index 12540c9..a70aedd 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java
@@ -141,7 +141,8 @@
         COMPONENTS.getWindowingStrategiesMap(),
         consumers,
         startFunctions::add,
-        finishFunctions::add);
+        finishFunctions::add,
+        null /* splitListener */);
 
     verifyZeroInteractions(mockBeamFnDataClient);
 
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java
index 044de10..e438bd5 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java
@@ -151,7 +151,8 @@
         Collections.emptyMap(),
         consumers,
         startFunctions::add,
-        finishFunctions::add);
+        finishFunctions::add,
+        null /* splitListener */);
 
     // This is testing a deprecated way of running sources and should be removed
     // once all source definitions are instead propagated along the input edge.
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java
index 9b1bd75..3b58517 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java
@@ -87,7 +87,8 @@
         Collections.emptyMap(),
         consumers,
         null /* addStartFunction */,
-        null /* addFinishFunction */);
+        null, /* addFinishFunction */
+        null /* splitListener */);
 
     mainOutputValues.clear();
     assertThat(consumers.keySet(), containsInAnyOrder(
@@ -149,7 +150,8 @@
             Collections.emptyMap(),
             consumers,
             null /* addStartFunction */,
-            null /* addFinishFunction */);
+            null, /* addFinishFunction */
+            null /* splitListener */);
 
     mainOutputValues.clear();
     assertThat(consumers.keySet(), containsInAnyOrder("inputATarget", "mainOutputTarget"));
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
index 85aa564..ee04cf8 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
@@ -18,8 +18,6 @@
 
 package org.apache.beam.fn.harness;
 
-import static com.google.common.base.Preconditions.checkState;
-import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWindow;
 import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsInAnyOrder;
@@ -27,29 +25,23 @@
 import static org.hamcrest.Matchers.hasSize;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertThat;
-import static org.junit.Assert.fail;
 
 import com.google.common.base.Suppliers;
 import com.google.common.collect.HashMultimap;
-import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Multimap;
 import com.google.protobuf.ByteString;
 import java.io.IOException;
+import java.io.Serializable;
 import java.util.ArrayList;
-import java.util.Collections;
 import java.util.List;
-import java.util.ServiceLoader;
-import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar;
 import org.apache.beam.fn.harness.state.FakeBeamFnStateClient;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
-import org.apache.beam.runners.core.construction.ParDoTranslation;
 import org.apache.beam.runners.core.construction.PipelineTranslation;
 import org.apache.beam.runners.core.construction.SdkComponents;
 import org.apache.beam.sdk.Pipeline;
-import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.fn.function.ThrowingRunnable;
@@ -60,8 +52,6 @@
 import org.apache.beam.sdk.state.StateSpecs;
 import org.apache.beam.sdk.state.ValueState;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
-import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
-import org.apache.beam.sdk.transforms.CombineWithContext.Context;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.ParDo;
@@ -73,15 +63,13 @@
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.util.CoderUtils;
-import org.apache.beam.sdk.util.DoFnInfo;
-import org.apache.beam.sdk.util.SerializableUtils;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
-import org.apache.beam.sdk.values.WindowingStrategy;
-import org.hamcrest.collection.IsMapContaining;
+import org.apache.beam.sdk.values.TupleTagList;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
 import org.junit.Test;
@@ -90,137 +78,10 @@
 
 /** Tests for {@link FnApiDoFnRunner}. */
 @RunWith(JUnit4.class)
-public class FnApiDoFnRunnerTest {
+public class FnApiDoFnRunnerTest implements Serializable {
 
   public static final String TEST_PTRANSFORM_ID = "pTransformId";
 
-  private static class TestDoFn extends DoFn<String, String> {
-    private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput");
-    private static final TupleTag<String> additionalOutput = new TupleTag<>("output");
-
-    private enum State {
-      NOT_SET_UP,
-      OUTSIDE_BUNDLE,
-      INSIDE_BUNDLE,
-    }
-
-    private State state = State.NOT_SET_UP;
-
-    private BoundedWindow window;
-
-    @Setup
-    public void setUp() {
-      checkState(State.NOT_SET_UP.equals(state), "Unexpected state: %s", state);
-      state = State.OUTSIDE_BUNDLE;
-    }
-
-    // No testing for TearDown - it's currently not supported by FnHarness.
-
-    @StartBundle
-    public void startBundle() {
-      checkState(State.OUTSIDE_BUNDLE.equals(state), "Unexpected state: %s", state);
-      state = State.INSIDE_BUNDLE;
-    }
-
-    @ProcessElement
-    public void processElement(ProcessContext context, BoundedWindow window) {
-      checkState(State.INSIDE_BUNDLE.equals(state), "Unexpected state: %s", state);
-      context.output("MainOutput" + context.element());
-      context.output(additionalOutput, "AdditionalOutput" + context.element());
-      this.window = window;
-    }
-
-    @FinishBundle
-    public void finishBundle(FinishBundleContext context) {
-      checkState(State.INSIDE_BUNDLE.equals(state), "Unexpected state: %s", state);
-      state = State.OUTSIDE_BUNDLE;
-      if (window != null) {
-        context.output("FinishBundle", window.maxTimestamp(), window);
-        window = null;
-      }
-    }
-  }
-
-  /**
-   * Create a DoFn that has 3 inputs (inputATarget1, inputATarget2, inputBTarget) and 2 outputs
-   * (mainOutput, output). Validate that inputs are fed to the {@link DoFn} and that outputs
-   * are directed to the correct consumers.
-   */
-  @Test
-  public void testCreatingAndProcessingDoFn() throws Exception {
-    String pTransformId = "pTransformId";
-
-    DoFnInfo<?, ?> doFnInfo = DoFnInfo.forFn(
-        new TestDoFn(),
-        WindowingStrategy.globalDefault(),
-        ImmutableList.of(),
-        StringUtf8Coder.of(),
-        TestDoFn.mainOutput);
-    RunnerApi.FunctionSpec functionSpec =
-        RunnerApi.FunctionSpec.newBuilder()
-            .setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN)
-            .setPayload(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo)))
-            .build();
-    RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder()
-        .setSpec(functionSpec)
-        .putInputs("inputA", "inputATarget")
-        .putInputs("inputB", "inputBTarget")
-        .putOutputs(TestDoFn.mainOutput.getId(), "mainOutputTarget")
-        .putOutputs(TestDoFn.additionalOutput.getId(), "additionalOutputTarget")
-        .build();
-
-    List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
-    List<WindowedValue<String>> additionalOutputValues = new ArrayList<>();
-    Multimap<String, FnDataReceiver<WindowedValue<?>>> consumers = HashMultimap.create();
-    consumers.put("mainOutputTarget",
-        (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
-    consumers.put("additionalOutputTarget",
-        (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) additionalOutputValues::add);
-    List<ThrowingRunnable> startFunctions = new ArrayList<>();
-    List<ThrowingRunnable> finishFunctions = new ArrayList<>();
-
-    new FnApiDoFnRunner.Factory<>().createRunnerForPTransform(
-        PipelineOptionsFactory.create(),
-        null /* beamFnDataClient */,
-        null /* beamFnStateClient */,
-        pTransformId,
-        pTransform,
-        Suppliers.ofInstance("57L")::get,
-        Collections.emptyMap(),
-        Collections.emptyMap(),
-        Collections.emptyMap(),
-        consumers,
-        startFunctions::add,
-        finishFunctions::add);
-
-    Iterables.getOnlyElement(startFunctions).run();
-    mainOutputValues.clear();
-
-    assertThat(consumers.keySet(), containsInAnyOrder(
-        "inputATarget", "inputBTarget", "mainOutputTarget", "additionalOutputTarget"));
-
-    Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A1"));
-    Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A2"));
-    Iterables.getOnlyElement(consumers.get("inputBTarget")).accept(valueInGlobalWindow("B"));
-    assertThat(mainOutputValues, contains(
-        valueInGlobalWindow("MainOutputA1"),
-        valueInGlobalWindow("MainOutputA2"),
-        valueInGlobalWindow("MainOutputB")));
-    assertThat(additionalOutputValues, contains(
-        valueInGlobalWindow("AdditionalOutputA1"),
-        valueInGlobalWindow("AdditionalOutputA2"),
-        valueInGlobalWindow("AdditionalOutputB")));
-    mainOutputValues.clear();
-    additionalOutputValues.clear();
-
-    Iterables.getOnlyElement(finishFunctions).run();
-    assertThat(
-        mainOutputValues,
-        contains(
-            timestampedValueInGlobalWindow("FinishBundle", GlobalWindow.INSTANCE.maxTimestamp())));
-    mainOutputValues.clear();
-  }
-
   private static class ConcatCombineFn extends CombineFn<String, String, String> {
     @Override
     public String createAccumulator() {
@@ -247,33 +108,6 @@
     }
   }
 
-  private static class ConcatCombineFnWithContext
-      extends CombineFnWithContext<String, String, String> {
-    @Override
-    public String createAccumulator(Context c) {
-      return "";
-    }
-
-    @Override
-    public String addInput(String accumulator, String input, Context c) {
-      return accumulator.concat(input);
-    }
-
-    @Override
-    public String mergeAccumulators(Iterable<String> accumulators, Context c) {
-      StringBuilder builder = new StringBuilder();
-      for (String value : accumulators) {
-        builder.append(value);
-      }
-      return builder.toString();
-    }
-
-    @Override
-    public String extractOutput(String accumulator, Context c) {
-      return accumulator;
-    }
-  }
-
   private static class TestStatefulDoFn extends DoFn<KV<String, String>, String> {
     private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput");
     private static final TupleTag<String> additionalOutput = new TupleTag<>("output");
@@ -287,17 +121,12 @@
     @StateId("combine")
     private final StateSpec<CombiningState<String, String, String>> combiningStateSpec =
         StateSpecs.combining(StringUtf8Coder.of(), new ConcatCombineFn());
-    @StateId("combineWithContext")
-    private final StateSpec<CombiningState<String, String, String>> combiningWithContextStateSpec =
-        StateSpecs.combining(StringUtf8Coder.of(), new ConcatCombineFnWithContext());
 
     @ProcessElement
     public void processElement(ProcessContext context,
         @StateId("value") ValueState<String> valueState,
         @StateId("bag") BagState<String> bagState,
-        @StateId("combine") CombiningState<String, String, String> combiningState,
-        @StateId("combineWithContext")
-            CombiningState<String, String, String> combiningWithContextState) {
+        @StateId("combine") CombiningState<String, String, String> combiningState) {
       context.output("value:" + valueState.read());
       valueState.write(context.element().getValue());
 
@@ -306,41 +135,34 @@
 
       context.output("combine:" + combiningState.read());
       combiningState.add(context.element().getValue());
-
-      context.output("combineWithContext:" + combiningWithContextState.read());
-      combiningWithContextState.add(context.element().getValue());
     }
   }
 
   @Test
   public void testUsingUserState() throws Exception {
-    DoFnInfo<?, ?> doFnInfo = DoFnInfo.forFn(
-        new TestStatefulDoFn(),
-        WindowingStrategy.globalDefault(),
-        ImmutableList.of(),
-        KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()),
-        new TupleTag<>("mainOutput"));
-    RunnerApi.FunctionSpec functionSpec =
-        RunnerApi.FunctionSpec.newBuilder()
-            .setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN)
-            .setPayload(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo)))
-            .build();
-    RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder()
-        .setSpec(functionSpec)
-        .putInputs("input", "inputTarget")
-        .putOutputs("mainOutput", "mainOutputTarget")
-        .build();
+    Pipeline p = Pipeline.create();
+    PCollection<KV<String, String>> valuePCollection =
+        p.apply(Create.of(KV.of("unused", "unused")));
+    PCollection<String> outputPCollection =
+        valuePCollection.apply(TEST_PTRANSFORM_ID, ParDo.of(new TestStatefulDoFn()));
+
+    SdkComponents sdkComponents = SdkComponents.create();
+    RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents);
+    String inputPCollectionId = sdkComponents.registerPCollection(valuePCollection);
+    String outputPCollectionId =
+        sdkComponents.registerPCollection(outputPCollection);
+    RunnerApi.PTransform pTransform = pProto.getComponents().getTransformsOrThrow(
+        pProto.getComponents().getTransformsOrThrow(TEST_PTRANSFORM_ID).getSubtransforms(0));
 
     FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of(
         bagUserStateKey("value", "X"), encode("X0"),
         bagUserStateKey("bag", "X"), encode("X0"),
-        bagUserStateKey("combine", "X"), encode("X0"),
-        bagUserStateKey("combineWithContext", "X"), encode("X0")
+        bagUserStateKey("combine", "X"), encode("X0")
     ));
 
     List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
     Multimap<String, FnDataReceiver<WindowedValue<?>>> consumers = HashMultimap.create();
-    consumers.put("mainOutputTarget",
+    consumers.put(outputPCollectionId,
         (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
     List<ThrowingRunnable> startFunctions = new ArrayList<>();
     List<ThrowingRunnable> finishFunctions = new ArrayList<>();
@@ -352,22 +174,23 @@
         TEST_PTRANSFORM_ID,
         pTransform,
         Suppliers.ofInstance("57L")::get,
-        Collections.emptyMap(),
-        Collections.emptyMap(),
-        Collections.emptyMap(),
+        pProto.getComponents().getPcollectionsMap(),
+        pProto.getComponents().getCodersMap(),
+        pProto.getComponents().getWindowingStrategiesMap(),
         consumers,
         startFunctions::add,
-        finishFunctions::add);
+        finishFunctions::add,
+        null /* splitListener */);
 
     Iterables.getOnlyElement(startFunctions).run();
     mainOutputValues.clear();
 
-    assertThat(consumers.keySet(), containsInAnyOrder("inputTarget", "mainOutputTarget"));
+    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
 
     // Ensure that bag user state that is initially empty or populated works.
     // Ensure that the key order does not matter when we traverse over KV pairs.
     FnDataReceiver<WindowedValue<?>> mainInput =
-        Iterables.getOnlyElement(consumers.get("inputTarget"));
+        Iterables.getOnlyElement(consumers.get(inputPCollectionId));
     mainInput.accept(valueInGlobalWindow(KV.of("X", "X1")));
     mainInput.accept(valueInGlobalWindow(KV.of("Y", "Y1")));
     mainInput.accept(valueInGlobalWindow(KV.of("X", "X2")));
@@ -376,19 +199,15 @@
         valueInGlobalWindow("value:X0"),
         valueInGlobalWindow("bag:[X0]"),
         valueInGlobalWindow("combine:X0"),
-        valueInGlobalWindow("combineWithContext:X0"),
         valueInGlobalWindow("value:null"),
         valueInGlobalWindow("bag:[]"),
         valueInGlobalWindow("combine:"),
-        valueInGlobalWindow("combineWithContext:"),
         valueInGlobalWindow("value:X1"),
         valueInGlobalWindow("bag:[X0, X1]"),
         valueInGlobalWindow("combine:X0X1"),
-        valueInGlobalWindow("combineWithContext:X0X1"),
         valueInGlobalWindow("value:Y1"),
         valueInGlobalWindow("bag:[Y1]"),
-        valueInGlobalWindow("combine:Y1"),
-        valueInGlobalWindow("combineWithContext:Y1")));
+        valueInGlobalWindow("combine:Y1")));
     mainOutputValues.clear();
 
     Iterables.getOnlyElement(finishFunctions).run();
@@ -399,11 +218,9 @@
             .put(bagUserStateKey("value", "X"), encode("X2"))
             .put(bagUserStateKey("bag", "X"), encode("X0", "X1", "X2"))
             .put(bagUserStateKey("combine", "X"), encode("X0X1X2"))
-            .put(bagUserStateKey("combineWithContext", "X"), encode("X0X1X2"))
             .put(bagUserStateKey("value", "Y"), encode("Y2"))
             .put(bagUserStateKey("bag", "Y"), encode("Y1", "Y2"))
             .put(bagUserStateKey("combine", "Y"), encode("Y1Y2"))
-            .put(bagUserStateKey("combineWithContext", "Y"), encode("Y1Y2"))
             .build(),
         fakeClient.getData());
     mainOutputValues.clear();
@@ -425,13 +242,17 @@
     private final PCollectionView<String> defaultSingletonSideInput;
     private final PCollectionView<String> singletonSideInput;
     private final PCollectionView<Iterable<String>> iterableSideInput;
+    private final TupleTag<String> additionalOutput;
+
     private TestSideInputDoFn(
         PCollectionView<String> defaultSingletonSideInput,
         PCollectionView<String> singletonSideInput,
-        PCollectionView<Iterable<String>> iterableSideInput) {
+        PCollectionView<Iterable<String>> iterableSideInput,
+        TupleTag<String> additionalOutput) {
       this.defaultSingletonSideInput = defaultSingletonSideInput;
       this.singletonSideInput = singletonSideInput;
       this.iterableSideInput = iterableSideInput;
+      this.additionalOutput = additionalOutput;
     }
 
     @ProcessElement
@@ -441,6 +262,7 @@
       for (String sideInputValue : context.sideInput(iterableSideInput)) {
         context.output(context.element() + ":" + sideInputValue);
       }
+      context.output(additionalOutput, context.element() + ":additional");
     }
   }
 
@@ -453,21 +275,31 @@
     PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
     PCollectionView<Iterable<String>> iterableSideInputView =
         valuePCollection.apply(View.asIterable());
-    PCollection<String> outputPCollection = valuePCollection.apply(TEST_PTRANSFORM_ID, ParDo.of(
-        new TestSideInputDoFn(
-            defaultSingletonSideInputView,
-            singletonSideInputView,
-            iterableSideInputView))
-        .withSideInputs(
-            defaultSingletonSideInputView, singletonSideInputView, iterableSideInputView));
+    TupleTag<String> mainOutput = new TupleTag<String>("main") {};
+    TupleTag<String> additionalOutput = new TupleTag<String>("additional") {};
+    PCollectionTuple outputPCollection =
+        valuePCollection.apply(
+            TEST_PTRANSFORM_ID,
+            ParDo.of(
+                    new TestSideInputDoFn(
+                        defaultSingletonSideInputView,
+                        singletonSideInputView,
+                        iterableSideInputView,
+                        additionalOutput))
+                .withSideInputs(
+                    defaultSingletonSideInputView, singletonSideInputView, iterableSideInputView)
+                .withOutputTags(mainOutput, TupleTagList.of(additionalOutput)));
 
     SdkComponents sdkComponents = SdkComponents.create();
     RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents);
     String inputPCollectionId = sdkComponents.registerPCollection(valuePCollection);
-    String outputPCollectionId = sdkComponents.registerPCollection(outputPCollection);
+    String outputPCollectionId =
+        sdkComponents.registerPCollection(outputPCollection.get(mainOutput));
+    String additionalPCollectionId =
+        sdkComponents.registerPCollection(outputPCollection.get(additionalOutput));
 
-    RunnerApi.PTransform pTransform = pProto.getComponents().getTransformsOrThrow(
-        pProto.getComponents().getTransformsOrThrow(TEST_PTRANSFORM_ID).getSubtransforms(0));
+    RunnerApi.PTransform pTransform =
+        pProto.getComponents().getTransformsOrThrow(TEST_PTRANSFORM_ID);
 
     ImmutableMap<StateKey, ByteString> stateData = ImmutableMap.of(
         multimapSideInputKey(singletonSideInputView.getTagInternal().getId(), ByteString.EMPTY),
@@ -478,13 +310,16 @@
     FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
 
     List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
+    List<WindowedValue<String>> additionalOutputValues = new ArrayList<>();
     Multimap<String, FnDataReceiver<WindowedValue<?>>> consumers = HashMultimap.create();
-    consumers.put(Iterables.getOnlyElement(pTransform.getOutputsMap().values()),
+    consumers.put(outputPCollectionId,
         (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
+    consumers.put(additionalPCollectionId,
+        (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) additionalOutputValues::add);
     List<ThrowingRunnable> startFunctions = new ArrayList<>();
     List<ThrowingRunnable> finishFunctions = new ArrayList<>();
 
-    new FnApiDoFnRunner.NewFactory<>().createRunnerForPTransform(
+    new FnApiDoFnRunner.Factory<>().createRunnerForPTransform(
         PipelineOptionsFactory.create(),
         null /* beamFnDataClient */,
         fakeClient,
@@ -496,12 +331,15 @@
         pProto.getComponents().getWindowingStrategiesMap(),
         consumers,
         startFunctions::add,
-        finishFunctions::add);
+        finishFunctions::add,
+        null /* splitListener */);
 
     Iterables.getOnlyElement(startFunctions).run();
     mainOutputValues.clear();
 
-    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+    assertThat(
+        consumers.keySet(),
+        containsInAnyOrder(inputPCollectionId, outputPCollectionId, additionalPCollectionId));
 
     // Ensure that bag user state that is initially empty or populated works.
     // Ensure that the bagUserStateKey order does not matter when we traverse over KV pairs.
@@ -520,6 +358,9 @@
         valueInGlobalWindow("Y:iterableValue1"),
         valueInGlobalWindow("Y:iterableValue2"),
         valueInGlobalWindow("Y:iterableValue3")));
+    assertThat(
+        additionalOutputValues,
+        contains(valueInGlobalWindow("X:additional"), valueInGlobalWindow("Y:additional")));
     mainOutputValues.clear();
 
     Iterables.getOnlyElement(finishFunctions).run();
@@ -589,7 +430,7 @@
     List<ThrowingRunnable> startFunctions = new ArrayList<>();
     List<ThrowingRunnable> finishFunctions = new ArrayList<>();
 
-    new FnApiDoFnRunner.NewFactory<>().createRunnerForPTransform(
+    new FnApiDoFnRunner.Factory<>().createRunnerForPTransform(
         PipelineOptionsFactory.create(),
         null /* beamFnDataClient */,
         fakeClient,
@@ -601,7 +442,8 @@
         pProto.getComponents().getWindowingStrategiesMap(),
         consumers,
         startFunctions::add,
-        finishFunctions::add);
+        finishFunctions::add,
+        null /* splitListener */);
 
     Iterables.getOnlyElement(startFunctions).run();
     mainOutputValues.clear();
@@ -658,17 +500,4 @@
     }
     return out.toByteString();
   }
-
-  @Test
-  public void testRegistration() {
-    for (Registrar registrar :
-        ServiceLoader.load(Registrar.class)) {
-      if (registrar instanceof FnApiDoFnRunner.Registrar) {
-        assertThat(registrar.getPTransformRunnerFactories(),
-            IsMapContaining.hasKey(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN));
-        return;
-      }
-    }
-    fail("Expected registrar not found.");
-  }
 }
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java
index 906bdeb..93ca808 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java
@@ -82,7 +82,8 @@
             Collections.emptyMap(),
             consumers,
             startFunctions::add,
-            finishFunctions::add);
+            finishFunctions::add,
+            null /* splitListener */);
 
     assertThat(startFunctions, empty());
     assertThat(finishFunctions, empty());
@@ -116,7 +117,8 @@
             Collections.emptyMap(),
             consumers,
             startFunctions::add,
-            finishFunctions::add);
+            finishFunctions::add,
+            null /* splitListener */);
 
     assertThat(startFunctions, empty());
     assertThat(finishFunctions, empty());
@@ -150,7 +152,8 @@
             Collections.emptyMap(),
             consumers,
             startFunctions::add,
-            finishFunctions::add);
+            finishFunctions::add,
+            null /* splitListener */);
 
     assertThat(startFunctions, empty());
     assertThat(finishFunctions, empty());
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
index 48d56fd..eafef99 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
@@ -50,6 +50,7 @@
 import org.apache.beam.model.pipeline.v1.RunnerApi.Coder;
 import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
 import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
+import org.apache.beam.model.pipeline.v1.RunnerApi.WindowingStrategy;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.fn.function.ThrowingConsumer;
 import org.apache.beam.sdk.fn.function.ThrowingRunnable;
@@ -115,7 +116,8 @@
             windowingStrategies,
             pCollectionIdsToConsumers,
             addStartFunction,
-            addFinishFunction) -> {
+            addFinishFunction,
+            splitListener) -> {
           assertThat(processBundleInstructionId.get(), equalTo("999L"));
 
           transformsProcessed.add(pTransform);
@@ -176,7 +178,8 @@
                     windowingStrategies,
                     pCollectionIdsToConsumers,
                     addStartFunction,
-                    addFinishFunction) -> {
+                    addFinishFunction,
+                    splitListener) -> {
                   thrown.expect(IllegalStateException.class);
                   thrown.expectMessage("TestException");
                   throw new IllegalStateException("TestException");
@@ -217,7 +220,8 @@
                         windowingStrategies,
                         pCollectionIdsToConsumers,
                         addStartFunction,
-                        addFinishFunction) -> {
+                        addFinishFunction,
+                        splitListener) -> {
                       thrown.expect(IllegalStateException.class);
                       thrown.expectMessage("TestException");
                       addStartFunction.accept(ProcessBundleHandlerTest::throwException);
@@ -261,7 +265,8 @@
                         windowingStrategies,
                         pCollectionIdsToConsumers,
                         addStartFunction,
-                        addFinishFunction) -> {
+                        addFinishFunction,
+                        splitListener) -> {
                       thrown.expect(IllegalStateException.class);
                       thrown.expectMessage("TestException");
                       addFinishFunction.accept(ProcessBundleHandlerTest::throwException);
@@ -335,10 +340,11 @@
               Supplier<String> processBundleInstructionId,
               Map<String, PCollection> pCollections,
               Map<String, Coder> coders,
-              Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
+              Map<String, WindowingStrategy> windowingStrategies,
               Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
               Consumer<ThrowingRunnable> addStartFunction,
-              Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
+              Consumer<ThrowingRunnable> addFinishFunction,
+              BundleSplitListener splitListener) throws IOException {
             addStartFunction.accept(() -> doStateCalls(beamFnStateClient));
             return null;
           }
@@ -386,10 +392,11 @@
               Supplier<String> processBundleInstructionId,
               Map<String, PCollection> pCollections,
               Map<String, Coder> coders,
-              Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
+              Map<String, WindowingStrategy> windowingStrategies,
               Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
               Consumer<ThrowingRunnable> addStartFunction,
-              Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
+              Consumer<ThrowingRunnable> addFinishFunction,
+              BundleSplitListener splitListener) throws IOException {
             addStartFunction.accept(() -> doStateCalls(beamFnStateClient));
             return null;
           }